├── .gitignore
├── README.md
├── build.sh
├── config
│   ├── datasets
│   │   ├── eval
│   │   │   ├── AMBER_desc.yaml
│   │   │   ├── CLEVR_val1k_type_opt.yaml
│   │   │   ├── EmbSpa_test_opt.yaml
│   │   │   ├── GQA.yaml
│   │   │   ├── MMBench_DEV_opt.yaml
│   │   │   ├── POPE_adversarial.yaml
│   │   │   ├── SEED_opt.yaml
│   │   │   ├── VSR_TEST.yaml
│   │   │   ├── VStar.yaml
│   │   │   ├── Wino.yaml
│   │   │   ├── clevr_ref.yaml
│   │   │   ├── refcocog_test-u.yaml
│   │   │   └── refcocog_val-u.yaml
│   │   ├── stage1_alignment.yaml
│   │   ├── stage2_interleaved.yaml
│   │   └── stage3_instruct.yaml
│   ├── deepspeed
│   │   ├── config_zero2.json
│   │   ├── config_zero2_offload.json
│   │   ├── config_zero3.json
│   │   └── config_zero3_offload.json
│   └── experiments
│       ├── stage1_alignment.yaml
│       ├── stage2_interleaved.yaml
│       └── stage3_instruct.yaml
├── constants.py
├── conversation.py
├── eval
│   ├── Evaluation.md
│   ├── __init__.py
│   ├── commands
│   │   ├── run_metric.sh
│   │   └── test_all_benchmark.sh
│   ├── eval_tools
│   │   ├── clevr.py
│   │   ├── convert_res_to_amber.py
│   │   ├── convert_res_to_gqa.py
│   │   ├── convert_res_to_gqa_llava.py
│   │   ├── convert_res_to_mmbench.py
│   │   ├── gqa.py
│   │   ├── m4c_evaluator.py
│   │   ├── mmbench.py
│   │   ├── mp3d.py
│   │   ├── pope.py
│   │   ├── refcoco.py
│   │   ├── seed.py
│   │   ├── vsr.py
│   │   └── vstar.py
│   ├── evaluate_benchmark.py
│   └── merge_benchmark.py
├── figs
│   ├── model_arch.png
│   └── sample_input.jpg
├── locals
│   └── datasets
│       ├── __init__.py
│       ├── dataloader.py
│       ├── eval
│       │   ├── qa.py
│       │   └── short_qa.py
│       ├── image_caption
│       │   ├── cc3m.py
│       │   ├── coco.py
│       │   ├── flicker30k.py
│       │   ├── fusecap.py
│       │   ├── grit.py
│       │   ├── lvis_gpt4v.py
│       │   └── mmc4.py
│       ├── image_edit
│       │   ├── edit_dataset
│       │   │   └── zipped_datasets.py
│       │   ├── low_level
│       │   │   ├── clwd.py
│       │   │   ├── gopro.py
│       │   │   ├── reds.py
│       │   │   └── sidd.py
│       │   ├── prompt
│       │   │   ├── color_list_train_small.txt
│       │   │   ├── prompt_deblur.txt
│       │   │   ├── prompt_denoise.txt
│       │   │   ├── prompt_dewatermark.txt
│       │   │   ├── prompt_pose.txt
│       │   │   └── prompt_seg.txt
│       │   └── seg
│       │       ├── coco_stuff.py
│       │       ├── grefcoco.py
│       │       ├── grefcoco_seg.py
│       │       ├── refcoco.py
│       │       └── refcoco_seg.py
│       ├── multimodal_tasks
│       │   ├── cot_qa.py
│       │   ├── llava_R.py
│       │   ├── llava_academic.py
│       │   ├── lvis_instruct4v.py
│       │   ├── m3it.py
│       │   ├── object_detect.py
│       │   ├── refcoco.py
│       │   ├── sharegpt4v.py
│       │   ├── single_image_base.py
│       │   ├── spatial.py
│       │   └── svit.py
│       ├── preprocessor.py
│       ├── prompts
│       │   ├── prompt_captioning_long.txt
│       │   ├── prompt_captioning_short.txt
│       │   ├── prompt_kosmosg.txt
│       │   ├── prompt_ref_i2t.txt
│       │   ├── prompt_ref_t2i.txt
│       │   └── prompt_txt2img.txt
│       ├── text
│       │   ├── sharegpt.py
│       │   ├── text_data_base.py
│       │   ├── txt_cot.py
│       │   └── ultrachat.py
│       ├── text2image
│       │   ├── kosmosg.py
│       │   └── midjourney.py
│       └── utils
│           ├── box_utils.py
│           └── zip_manager.py
├── model
│   ├── front_projector
│   │   ├── Qformer.py
│   │   └── builder.py
│   ├── language_model
│   │   ├── volcano_base.py
│   │   ├── volcano_llama.py
│   │   └── volcano_mistral.py
│   ├── load_model.py
│   └── vision_encoder
│       ├── builder.py
│       ├── clip_encoder.py
│       ├── eva_vit.py
│       └── eva_vit_emu.py
├── requirements.txt
├── train_volcano.py
├── utils
│   ├── __init__.py
│   ├── __init__.pyc
│   ├── count_line.py
│   ├── eval_util.py
│   ├── format_utils.py
│   ├── gqa_inpaint.py
│   ├── llava_flash_attn.py
│   ├── logger.py
│   ├── logger.pyc
│   ├── time_check.py
│   └── util.py
└── vocot_trainer.py
/.gitignore:
--------------------------------------------------------------------------------
1 | *__pycache__
2 | .vscode
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # 🌋VoCoT: Unleashing Visually Grounded Multi-Step Reasoning in Large Multi-Modal Models [[Paper](https://arxiv.org/abs/2405.16919)]
2 |
3 |
4 |
5 |
6 | Exploring visually grounded CoT reasoning in large multi-modal models with VoCoT!
7 |
8 | [Zejun Li*](https://github.com/Junction4Nako), [Ruipu Luo*](https://github.com/RupertLuo), Jiwen Zhang, Minghui Qiu, Zhongyu Wei (*Equal Contribution)
9 |
10 | We propose Visually grounded object-centric Chain-of-Thoughts (VoCoT) to support effective and reliable multi-step reasoning in large multi-modal models. For more details, please refer to our paper.
11 |
12 | In this repository, we will release:
13 | - The constructed VoCoT-Instruct data, which can be used to train your multi-modal models to reason in the VoCoT format.
14 | - The VolCano model, which is able to perform multi-step reasoning in the VoCoT format.
15 | - Training scripts utilized to train VolCano.
16 | - Evaluation datasets and scripts used in our paper.
17 |
18 | ## Contents
19 | - [Getting Started](#getting-started)
20 | - [Data](#data)
21 | - [Model Weights](#model-weights)
22 | - [Train](#train)
23 | - [Evaluation](#evaluation)
24 |
25 | ## Getting Started
26 |
27 | ### Install
28 |
29 | 1. Clone this repository
30 | ```bash
31 | git clone https://github.com/RupertLuo/VoCoT.git
32 | cd VoCoT
33 | ```
34 |
35 | 2. Install packages
36 | ```bash
37 | conda create -n vocot python=3.9 -y
38 | conda activate vocot
39 | pip3 install --upgrade pip
40 | pip3 install -r requirements.txt
41 | pip3 install flash-attn --no-build-isolation
42 | sudo apt-get install python3-tk -y
43 | ```
44 |
45 | ### Quick Start
46 |
47 | ```python
48 | from model.load_model import load_model, infer
49 | from PIL import Image
50 |
51 | # loading the model
52 | model_path = 'luoruipu1/Volcano-7b'
53 | model, preprocessor = load_model(model_path, precision='fp16')
54 |
55 | # perform reasoning, activate VoCoT by passing cot=True
56 | input_image = Image.open('figs/sample_input.jpg')
57 | response_1 = infer(model, preprocessor, input_image, 'Is there an event "the cat is below the bed" in this image?', cot=True)
58 | response_2 = infer(model, preprocessor, input_image, 'Why is the cat on the bed?', cot=True)
59 | response_3 = infer(model, preprocessor, input_image, 'Describe the image.', cot=True)
60 | print('response 1: ', response_1[0])
61 | print('response 2: ', response_2[0])
62 | print('response 3: ', response_3[0])
63 | ```
64 | Notice: in the default setting, the output coordinates refer to the image after it has been expanded (padded) to a square, not to the original image.
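If you need boxes in the original image's coordinate system, you can map them back by undoing the centered expand-to-square padding. The sketch below is not part of the repository; it assumes the predicted box is normalized to [0, 1] over the padded square image and that the padding is centered (adjust if your outputs are absolute pixel coordinates):

```python
# Minimal sketch (assumptions: boxes normalized to [0, 1], centered square padding).
def square_box_to_original(box, orig_w, orig_h):
    side = max(orig_w, orig_h)
    pad_x = (side - orig_w) / 2  # horizontal padding added by expand-to-square
    pad_y = (side - orig_h) / 2  # vertical padding added by expand-to-square
    x0, y0, x1, y1 = (v * side for v in box)
    return (x0 - pad_x, y0 - pad_y, x1 - pad_x, y1 - pad_y)

# Example for a 640x480 input image:
# square_box_to_original((0.25, 0.30, 0.55, 0.70), 640, 480)
```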
65 |
66 | ## Data
67 |
68 | For users who want to use VoCoT-Instruct to train their own models, we provide an integrated JSON file, [VoCoT-Instruct-80K](https://huggingface.co/datasets/luoruipu1/VoCoT/blob/main/VoCoT-80K_integrated.json), following the conversation format of LLaVA (note that all coordinates are for images that have been expanded to squares).
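The sketch below shows how such a record can be inspected; the field names (`image`, `conversations`, `from`, `value`) follow the standard LLaVA conversation schema referenced above and are assumptions here, not a specification of the released file:

```python
import json

# Peek at one VoCoT-Instruct record (field names assumed to follow the LLaVA schema).
with open('VoCoT-80K_integrated.json') as f:
    data = json.load(f)

sample = data[0]
print('image:', sample.get('image'))
for turn in sample.get('conversations', []):
    print(f"[{turn.get('from')}] {turn.get('value', '')[:200]}")
```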
69 |
70 | If you would like to follow the training of VolCano in this paper, please use the separate JSON files in [raw_data](https://huggingface.co/datasets/luoruipu1/VoCoT/tree/main/raw_data) for efficient dataset management.
71 |
72 | For the corresponding images, please visit [GQA](https://cs.stanford.edu/people/dorarad/gqa/about.html), [COCO](https://cocodataset.org/), and [LVIS](https://www.lvisdataset.org/) to download the images.
73 |
74 | ## Model Weights
75 |
76 | The VolCano model is built on Mistral-Instruct-v0.2-7B and CLIP ViT-L/14 (336px); the connection module and the trained weights are released [here](https://huggingface.co/luoruipu1/Volcano-7b/tree/main). The architecture is illustrated below and details are included in our paper.
77 |
78 |
79 | 
80 | The architecture of VolCano.
81 |
82 |
83 | ## Train
84 |
85 | In each stage, prepare the data and the pre-trained checkpoints by following the instructions below:
86 |
87 | ### Prepare the Data
88 |
89 | The datasets are managed with yaml files in [config/datasets](./config/datasets/). In each yaml file, make sure every path containing "PATH/TO" is replaced with your correct local path.
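As a convenience (not part of the repository), a quick check like the following flags any placeholders you missed after editing a config:

```python
# Flag any "PATH/TO" placeholders still left in a dataset yaml.
from pathlib import Path

text = Path('config/datasets/stage1_alignment.yaml').read_text()
for lineno, line in enumerate(text.splitlines(), 1):
    if 'PATH/TO' in line:
        print(f'line {lineno} still needs a real path: {line.strip()}')
```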
90 |
91 | For datasets that are either introduced, modified, or filtered in this paper, e.g., VoCoT-Instruct and the subsets of GRIT and MMC4, we provide the metadata [here](https://huggingface.co/datasets/luoruipu1/VoCoT/tree/main/pretrain).
92 |
93 | For public datasets like RefCOCO, LLaVA, and ALLaVA, please refer to the corresponding websites to obtain the data.
94 | If it is not clear where to obtain the dataset, feel free to contact us.
95 |
96 | ### Conduct Training
97 |
98 | We manage the experiment settings with yaml files in [config/experiments](./config/experiments/). In each yaml file, replace the paths containing "PATH/TO" with your correct paths.
99 |
100 | Each stage can be launched with the following command (replace the yaml path with the config of the corresponding stage):
101 |
102 | ```bash
103 | # modify the torchrun config to fit your machine
104 | torchrun --nproc_per_node=8 train_volcano.py --conf config/experiments/stage1_alignment.yaml
105 | ```
106 |
107 | ## Evaluation
108 | Please see [Evaluation](./eval/Evaluation.md) for details about the evaluation datasets and evaluation scripts.
109 |
110 | ## Acknowledgement
111 | We thank the following open-source projects, which we referenced during the development of VoCoT.
112 |
113 | - [LLaVA](https://github.com/haotian-liu/LLaVA): our codebase is built on LLaVA, and LLaVA-Instruct is used during training.
114 | - [Shikra](https://github.com/shikras/shikra): we follow Shikra for the construction of the GQA Type-1 VoCoT-Instruct data.
115 | - [InstructDiffusion](https://github.com/cientgu/InstructDiffusion): we referenced the InstructDiffusion codebase for managing multiple datasets.
116 |
117 |
118 | ## Citation
119 |
120 | If you find our code, data, or model useful in your work, please consider citing us:
121 | ```bibtex
122 | @article{li2024vocot,
123 | title={VoCoT: Unleashing Visually Grounded Multi-Step Reasoning in Large Multi-Modal Models},
124 | author={Li, Zejun and Luo, Ruipu and Zhang, Jiwen and Qiu, Minghui and Wei, Zhongyu},
125 | journal={arXiv preprint arXiv:2405.16919},
126 | year={2024}
127 | }
128 | ```
129 |
--------------------------------------------------------------------------------
/build.sh:
--------------------------------------------------------------------------------
1 | pip3 install --upgrade pip
2 | pip3 install -r requirements.txt
3 | pip3 install flash-attn --no-build-isolation
4 | sudo apt-get install python3-tk -y
--------------------------------------------------------------------------------
/config/datasets/eval/AMBER_desc.yaml:
--------------------------------------------------------------------------------
1 | - mmvet:
2 | target: locals.datasets.eval.short_qa.AMBERDataset
3 | params:
4 | path: PATH/TO/AMBER/AMBER/data/query/query_generative.json
5 | base_path: PATH/TO/AMBER/image
6 | describe: True
--------------------------------------------------------------------------------
/config/datasets/eval/CLEVR_val1k_type_opt.yaml:
--------------------------------------------------------------------------------
1 | - mmvet:
2 | target: locals.datasets.eval.qa.CLEVRDataset
3 | params:
4 | path: PATH/TO/CLEVR/CLEVR_v1.0/clevr_val_4choices_1k_per_type.json
5 | image_dir: PATH/TO/CLEVR_v1.0/images/val/
--------------------------------------------------------------------------------
/config/datasets/eval/EmbSpa_test_opt.yaml:
--------------------------------------------------------------------------------
1 | - mmvet:
2 | target: locals.datasets.eval.qa.EmbSpatialDataset
3 | params:
4 | path: PATH/TO/embspatial_bench_v1.json
5 | image_dir: PATH/TO/emb_spatial/
--------------------------------------------------------------------------------
/config/datasets/eval/GQA.yaml:
--------------------------------------------------------------------------------
1 | - gqa:
2 | target: locals.datasets.eval.short_qa.GQADataset
3 | params:
4 | path: PATH/TO/GQA/testdev_balanced_questions.json
5 | base_path: PATH/TO/GQA/images/
--------------------------------------------------------------------------------
/config/datasets/eval/MMBench_DEV_opt.yaml:
--------------------------------------------------------------------------------
1 | - mmvet:
2 | target: locals.datasets.eval.qa.MMBenchOptDataset
3 | params:
4 | path: PATH/TO/mmbench/MMBench_DEV_EN_legacy.tsv
5 | single_pred_prompt: False
--------------------------------------------------------------------------------
/config/datasets/eval/POPE_adversarial.yaml:
--------------------------------------------------------------------------------
1 | - pope_adv:
2 | target: locals.datasets.eval.short_qa.POPEDataset
3 | params:
4 | path: PATH/TO/coco_pope_adversarial.jsonl
5 | base_path: PATH/TO/COCO2015/images/
--------------------------------------------------------------------------------
/config/datasets/eval/SEED_opt.yaml:
--------------------------------------------------------------------------------
1 | - mmvet:
2 | target: locals.datasets.eval.qa.SEEDOptionDataset
3 | params:
4 | path: PATH/TO/SEED-Bench/SEED-Bench.json
5 | image_dir: PATH/TO/SEED-Bench/SEED-Bench-image
--------------------------------------------------------------------------------
/config/datasets/eval/VSR_TEST.yaml:
--------------------------------------------------------------------------------
1 | - mme:
2 | target: locals.datasets.eval.short_qa.VSRDataset
3 | params:
4 | path: PATH/TO/VSR/zeroshot/test.jsonl
5 | base_path: PATH/TO/COCO2017/
--------------------------------------------------------------------------------
/config/datasets/eval/VStar.yaml:
--------------------------------------------------------------------------------
1 | - mmvet:
2 | target: locals.datasets.eval.qa.VStarDataset
3 | params:
4 | path: PATH/TO/VStar/meta.json
5 | image_dir: PATH/TO/VStar/vstar_hub_dl
--------------------------------------------------------------------------------
/config/datasets/eval/Wino.yaml:
--------------------------------------------------------------------------------
1 | - mmvet:
2 | target: locals.datasets.eval.qa.WinoTextDataset
3 | params:
4 | option_in_context: False
5 |
--------------------------------------------------------------------------------
/config/datasets/eval/clevr_ref.yaml:
--------------------------------------------------------------------------------
1 | - refcoco:
2 | target: locals.datasets.multimodal_tasks.refcoco.ClevrRefDataset
3 | params:
4 | path: PATH/TO/clevr_ref/meta_1item_2k.json
5 | image_path: PATH/TO/clevr_ref+_1.0/images/val
--------------------------------------------------------------------------------
/config/datasets/eval/refcocog_test-u.yaml:
--------------------------------------------------------------------------------
1 | - refcoco:
2 | target: locals.datasets.multimodal_tasks.refcoco.RefCOCOEvalDataset
3 | params:
4 | path: PATH/TO/Refcoco
5 | dataset_name: refcocog
6 | split: test
7 | split_by: umd
8 | image_path: PATH/TO/COCO2015/images/
9 | task_mode: t2i
--------------------------------------------------------------------------------
/config/datasets/eval/refcocog_val-u.yaml:
--------------------------------------------------------------------------------
1 | - refcoco:
2 | target: locals.datasets.multimodal_tasks.refcoco.RefCOCOEvalDataset
3 | params:
4 | path: /mnt/bn/yangmin-priv/luoruipu/data/multimodal-datasets/Refcoco
5 | dataset_name: refcocog
6 | split: val
7 | split_by: umd
8 | image_path: PATH/TO/COCO2015/images/
9 | task_mode: t2i
--------------------------------------------------------------------------------
/config/datasets/stage1_alignment.yaml:
--------------------------------------------------------------------------------
1 | datasets:
2 | target: locals.datasets.DataModuleFromConfig
3 | params:
4 | batch_size: 64
5 | num_workers: 4
6 | wrap: True
7 | train:
8 | - llava_pretrain_i2t:
9 | target: locals.datasets.image_caption.cc3m.FilteredCC3MI2TDataset
10 | params:
11 | path: /PATH/TO/LLaVA-Pretrain/blip_laion_cc_sbu_558k.json
12 | image_folder: /PATH/TO/LLaVA-Pretrain/images
13 | raw_image: True
14 | output_mode: text
15 | shuffle: False
--------------------------------------------------------------------------------
/config/datasets/stage2_interleaved.yaml:
--------------------------------------------------------------------------------
1 | datasets:
2 | target: locals.datasets.DataModuleFromConfig
3 | params:
4 | batch_size: 64
5 | num_workers: 4
6 | wrap: True
7 | train:
8 | - allava_laion:
9 | target: locals.datasets.multimodal_tasks.llava_academic.ALLaVACaptionDataset
10 | params:
11 | path: PATH/TO/allava_laion/ALLaVA-Caption-LAION-4V.json
12 | image_folder: PATH/TO/ALLaVA/images
13 | raw_image: True
14 | output_mode: text
15 | shuffle: False
16 | expand2square: True
17 | - allava_vflan:
18 | target: locals.datasets.multimodal_tasks.llava_academic.ALLaVACaptionDataset
19 | params:
20 | path: PATH/TO/allava_vflan/ALLaVA-Caption-VFLAN-4V.json
21 | image_folder: PATH/TO/ALLaVA/images
22 | raw_image: True
23 | output_mode: text
24 | shuffle: False
25 | expand2square: True
26 | - mmc4:
27 | target: locals.datasets.image_caption.mmc4.FilteredMMC4Dataset
28 | params:
29 | path: PATH/TO/filter_mmc4_meta_with_img_abs_path_890k.jsonl
30 | avoid_image_gen: True
31 | expand2square: True
32 | - grit:
33 | target: locals.datasets.image_caption.grit.GriTDataset
34 | params:
35 | path: PATH/TO/clip_filtered_756K_grit.json
36 | image_folder: PATH/TO/GRIT/images/
37 | raw_image: True
38 | output_mode: text
39 | shuffle: True
40 | phrase_key: noun_chunks
41 | avoid_image_gen: True
42 | phrase_format: 'text'
43 | phrase_prec: 3
44 | object_format: 'representation'
45 | expand2square: True
46 | - flickr30kentities:
47 | target: locals.datasets.image_caption.flicker30k.FlickrDataset
48 | params:
49 | path: PATH/TO/CWB_flickr30k_train.jsonl
50 | image_folder: PATH/TO/Flicker30K/flickr30k-images/
51 | avoid_image_gen: True
52 | phrase_format: 'text'
53 | phrase_prec: 3
54 | object_format: 'representation'
55 | expand2square: True
--------------------------------------------------------------------------------
/config/datasets/stage3_instruct.yaml:
--------------------------------------------------------------------------------
1 | datasets:
2 | target: locals.datasets.DataModuleFromConfig
3 | params:
4 | batch_size: 64
5 | num_workers: 4
6 | wrap: True
7 | train:
8 | - llava_academic:
9 | target: locals.datasets.multimodal_tasks.llava_academic.LlavaAcademicDataset
10 | params:
11 | path: PATH/TO/llava_v1_5_mix665k_norefcoco.json # removing refcoco from llava 665K
12 | image_folder: PATH/TO/LLaVA_images
13 | raw_image: True
14 | output_mode: conversation
15 | avoid_image_gen: True
16 | min_size: 50
17 | phrase_format: 'text'
18 | phrase_prec: 3
19 | expand2square: True
20 | object_format: 'representation'
21 | - refcoco:
22 | target: locals.datasets.multimodal_tasks.refcoco.RefCOCODataset
23 | params:
24 | path: PATH/TO/Refcoco
25 | dataset_name: refcoco
26 | split: train
27 | image_path: PATH/TO/COCO2015/images/
28 | task_mode: both
29 | avoid_image_gen: True
30 | phrase_format: 'text'
31 | phrase_prec: 3
32 | expand2square: True
33 | object_format: 'representation'
34 | - refcoco+:
35 | target: locals.datasets.multimodal_tasks.refcoco.RefCOCODataset
36 | params:
37 | path: PATH/TO/Refcoco
38 | dataset_name: refcoco+
39 | split: train
40 | image_path: PATH/TO/COCO2015/images/
41 | task_mode: both
42 | avoid_image_gen: True
43 | phrase_format: 'text'
44 | phrase_prec: 3
45 | expand2square: True
46 | object_format: 'representation'
47 | - refcocog:
48 | target: locals.datasets.multimodal_tasks.refcoco.RefCOCODataset
49 | params:
50 | path: PATH/TO/Refcoco
51 | dataset_name: refcocog
52 | split: train
53 | split_by: umd
54 | image_path: PATH/TO/COCO2015/images/
55 | task_mode: both
56 | avoid_image_gen: True
57 | phrase_format: 'text'
58 | phrase_prec: 3
59 | expand2square: True
60 | object_format: 'representation'
61 | - grefcoco:
62 | target: locals.datasets.multimodal_tasks.refcoco.GRefCOCODataset
63 | params:
64 | path: PATH/TO/Refcoco
65 | split: train
66 | image_path: PATH/TO/COCO2015/images/
67 | task_mode: both
68 | avoid_image_gen: True
69 | phrase_format: 'text'
70 | phrase_prec: 3
71 | expand2square: True
72 | object_format: 'representation'
73 | - shikra_cot_gen:
74 | target: locals.datasets.multimodal_tasks.cot_qa.CoTQADataset
75 | params:
76 | path: PATH/TO/Shikra/GPT4GEN_BoxCoT_train.jsonl # See Shikra data
77 | image_path: PATH/TO/Flicker30K/flickr30k-images/
78 | avoid_image_gen: True
79 | phrase_format: 'text'
80 | phrase_prec: 3
81 | expand2square: True
82 | object_format: 'representation'
83 | further_instruct: True
84 | - shikra_rd:
85 | target: locals.datasets.multimodal_tasks.cot_qa.CoTQADataset
86 | params:
87 | path: PATH/TO/Shikra/GPT4GEN_RD_BoxCoT_train.jsonl # See Shikra data
88 | image_path: PATH/TO/Flicker30K/flickr30k-images/
89 | avoid_image_gen: True
90 | phrase_format: 'text'
91 | phrase_prec: 3
92 | expand2square: True
93 | object_format: 'representation'
94 | further_instruct: False
95 | - cot_gqa:
96 | target: locals.datasets.multimodal_tasks.cot_qa.GQACoTDataset
97 | params:
98 | path: PATH/TO/raw_data/type1_gqa_raw.jsonl
99 | image_path: PATH/TO/GQA/images
100 | avoid_image_gen: True
101 | phrase_format: 'text'
102 | phrase_prec: 3
103 | expand2square: True
104 | object_format: 'representation'
105 | further_instruct: True
106 | sample_weight: 0.75
107 | - llava_QA2T:
108 | target: locals.datasets.multimodal_tasks.llava_academic.LlavaQA2TDataset
109 | params:
110 | path: PATH/TO/raw_data/type2_iqa2t_raw.json
111 | raw_image: True
112 | output_mode: conversation
113 | avoid_image_gen: True
114 | min_size: 50
115 | phrase_format: 'text'
116 | phrase_prec: 3
117 | expand2square: True
118 | object_format: 'representation'
119 | - lvis_I2QTA:
120 | target: locals.datasets.multimodal_tasks.llava_academic.LlavaI2QTADataset
121 | params:
122 | path: PATH/TO/raw_data/type3_i2qta_raw.json
123 | raw_image: True
124 | output_mode: conversation
125 | avoid_image_gen: True
126 | min_size: 50
127 | phrase_format: 'text'
128 | phrase_prec: 3
129 | expand2square: True
130 | object_format: 'representation'
131 | block_invalid: True
--------------------------------------------------------------------------------
/config/deepspeed/config_zero2.json:
--------------------------------------------------------------------------------
1 | {
2 | "fp16": {
3 | "enabled": "auto",
4 | "loss_scale": 0,
5 | "loss_scale_window": 1000,
6 | "initial_scale_power": 16,
7 | "hysteresis": 2,
8 | "min_loss_scale": 1
9 | },
10 | "bf16": {
11 | "enabled": "auto"
12 | },
13 |
14 | "zero_optimization": {
15 | "stage": 2,
16 | "allgather_partitions": true,
17 | "allgather_bucket_size": 5e8,
18 | "overlap_comm": true,
19 | "reduce_scatter": true,
20 | "reduce_bucket_size": 5e8,
21 | "contiguous_gradients": true
22 | },
23 |
24 | "gradient_accumulation_steps": "auto",
25 | "gradient_clipping": "auto",
26 | "steps_per_print": 2000,
27 | "train_batch_size": "auto",
28 | "train_micro_batch_size_per_gpu": "auto",
29 | "wall_clock_breakdown": false
30 | }
--------------------------------------------------------------------------------
/config/deepspeed/config_zero2_offload.json:
--------------------------------------------------------------------------------
1 | {
2 | "fp16": {
3 | "enabled": "auto",
4 | "loss_scale": 0,
5 | "loss_scale_window": 1000,
6 | "initial_scale_power": 16,
7 | "hysteresis": 2,
8 | "min_loss_scale": 1
9 | },
10 | "bf16": {
11 | "enabled": "auto"
12 | },
13 |
14 | "zero_optimization": {
15 | "stage": 2,
16 | "offload_optimizer": {
17 | "device": "cpu",
18 | "pin_memory": true
19 | },
20 | "offload_param": {
21 | "device": "cpu",
22 | "pin_memory": true
23 | },
24 | "allgather_partitions": true,
25 | "allgather_bucket_size": 5e8,
26 | "overlap_comm": true,
27 | "reduce_scatter": true,
28 | "reduce_bucket_size": 5e8,
29 | "contiguous_gradients": true
30 | },
31 |
32 | "gradient_accumulation_steps": "auto",
33 | "gradient_clipping": "auto",
34 | "steps_per_print": 2000,
35 | "train_batch_size": "auto",
36 | "train_micro_batch_size_per_gpu": "auto",
37 | "wall_clock_breakdown": false
38 | }
--------------------------------------------------------------------------------
/config/deepspeed/config_zero3.json:
--------------------------------------------------------------------------------
1 | {
2 | "fp16": {
3 | "enabled": "auto",
4 | "loss_scale": 0,
5 | "loss_scale_window": 1000,
6 | "initial_scale_power": 16,
7 | "hysteresis": 2,
8 | "min_loss_scale": 1
9 | },
10 | "bf16": {
11 | "enabled": "auto"
12 | },
13 | "train_micro_batch_size_per_gpu": "auto",
14 | "train_batch_size": "auto",
15 | "gradient_accumulation_steps": "auto",
16 | "zero_optimization": {
17 | "stage": 3,
18 | "overlap_comm": true,
19 | "contiguous_gradients": true,
20 | "sub_group_size": 1e9,
21 | "reduce_bucket_size": "auto",
22 | "stage3_prefetch_bucket_size": "auto",
23 | "stage3_param_persistence_threshold": "auto",
24 | "stage3_max_live_parameters": 1e9,
25 | "stage3_max_reuse_distance": 1e9,
26 | "stage3_gather_16bit_weights_on_model_save": true
27 | }
28 | }
--------------------------------------------------------------------------------
/config/deepspeed/config_zero3_offload.json:
--------------------------------------------------------------------------------
1 | {
2 | "fp16": {
3 | "enabled": "auto",
4 | "loss_scale": 0,
5 | "loss_scale_window": 1000,
6 | "initial_scale_power": 16,
7 | "hysteresis": 2,
8 | "min_loss_scale": 1
9 | },
10 | "bf16": {
11 | "enabled": "auto"
12 | },
13 | "optimizer": {
14 | "type": "AdamW",
15 | "params": {
16 | "lr": "auto",
17 | "betas": "auto",
18 | "eps": "auto",
19 | "weight_decay": "auto"
20 | }
21 | },
22 | "scheduler": {
23 | "type": "WarmupLR",
24 | "params": {
25 | "warmup_min_lr": "auto",
26 | "warmup_max_lr": "auto",
27 | "warmup_num_steps": "auto"
28 | }
29 | },
30 | "zero_optimization": {
31 | "stage": 3,
32 | "offload_optimizer": {
33 | "device": "cpu",
34 | "pin_memory": true
35 | },
36 | "offload_param": {
37 | "device": "cpu",
38 | "pin_memory": true
39 | },
40 | "overlap_comm": true,
41 | "contiguous_gradients": true,
42 | "sub_group_size": 1e9,
43 | "reduce_bucket_size": "auto",
44 | "stage3_prefetch_bucket_size": "auto",
45 | "stage3_param_persistence_threshold": "auto",
46 | "stage3_max_live_parameters": 1e9,
47 | "stage3_max_reuse_distance": 1e9,
48 | "gather_16bit_weights_on_model_save": true
49 | },
50 | "gradient_accumulation_steps": "auto",
51 | "gradient_clipping": "auto",
52 | "train_batch_size": "auto",
53 | "train_micro_batch_size_per_gpu": "auto",
54 | "steps_per_print": 1e5,
55 | "wall_clock_breakdown": false
56 | }
--------------------------------------------------------------------------------
/config/experiments/stage1_alignment.yaml:
--------------------------------------------------------------------------------
1 | project_name: volcano
2 | run_name: volcano_stage2
3 | # Whether to mask only the system prompt in the labels (everything else is left unmasked)
4 | only_mask_system: False
5 | # system prompt style
6 | conv_mode: v1
7 | # whether to use LoRA
8 | lora_enable: False
9 | # whether the model is multimodal
10 | is_multimodal: True
11 |
12 | freeze_backbone: True
13 |
14 | # weight path
15 | model_path: mistralai/Mistral-7B-Instruct-v0.2
16 | vision_encoder: openai/clip-vit-large-patch14-336
17 | vision_encoder_path: openai/clip-vit-large-patch14-336
18 | skip_vision_encoder_load: False
19 | front_projector_type: mlp2x_gelu
20 | num_query_token: 32
21 | avoid_generator: True # do not use generator in this stage
22 | output_dir: /PATH/TO/STAGE1_CKPT
23 | behind_projector: linear
24 | flash_attn: True
25 | # dataset config
26 | data_config_path: config/datasets/stage1_alignment.yaml
27 | expand_to_square: True
28 | remove_unused_columns: False
29 | regression_weight: 1.0
30 | tokenizer_model_max_length: 2048
31 | model_max_length: 2048
32 | extend_loc_vocabulary: False
33 | use_mistral: True
34 |
35 | num_train_epochs: 1
36 | per_device_train_batch_size: 16
37 | save_strategy: 'steps'
38 | lora_save_strategy: steps # if LoRA training is enabled, use this to save only the LoRA weights; supports ['steps','epochs','no']
39 | save_steps: 6000
40 | learning_rate: 1e-5
41 | gradient_checkpointing: True
42 | # whether to do a fast epoch
43 | fast_epoch: False
44 |
45 | # whether to compute diffusion loss
46 | compute_diffusion_loss: False
47 |
48 | bf16: True
49 | fp16: False
50 | tf32: False
51 | per_device_eval_batch_size: 1
52 | gradient_accumulation_steps: 1
53 | evaluation_strategy: "no"
54 | save_total_limit: 3
55 | weight_decay: 0.
56 | warmup_ratio: 0.0
57 | lr_scheduler_type: cosine
58 | logging_steps: 1
59 | model_max_length: 2048
60 | adam_beta1: 0.9
61 | adam_beta2: 0.95
62 | deepspeed: config/deepspeed/config_zero2.json
63 | dataloader_num_workers: 4
64 | # report_to: wandb
65 | is_training: True
--------------------------------------------------------------------------------
/config/experiments/stage2_interleaved.yaml:
--------------------------------------------------------------------------------
1 | project_name: volcano
2 | run_name: volcano_stage2
3 | # Whether to mask only the system prompt in the labels (everything else is left unmasked)
4 | only_mask_system: False
5 | # system prompt style
6 | conv_mode: v1
7 | # whether to use LoRA
8 | lora_enable: False
9 | # whether the model is multimodal
10 | is_multimodal: True
11 |
12 | freeze_backbone: False
13 |
14 | # weight path
15 | model_path: mistralai/Mistral-7B-Instruct-v0.2
16 | vision_encoder: openai/clip-vit-large-patch14-336
17 | vision_encoder_path: openai/clip-vit-large-patch14-336
18 | skip_vision_encoder_load: False
19 | front_projector_type: mlp2x_gelu
20 | num_query_token: 32
21 | avoid_generator: True # do not use generator in this stage
22 | output_dir: PATH/TO/STAGE2_CKPT
23 | behind_projector: linear
24 | flash_attn: True
25 | # dataset config
26 | data_config_path: config/datasets/stage2_interleaved.yaml
27 | expand_to_square: True
28 | remove_unused_columns: False
29 | regression_weight: 1.0
30 | tokenizer_model_max_length: 2048
31 | model_max_length: 2048
32 | extend_loc_vocabulary: False
33 | use_mistral: True
34 | stage1_ckpt: PATH/TO/STAGE1_CKPT_PROJECTION # please only contain the projection layer
35 |
36 | num_train_epochs: 1
37 | per_device_train_batch_size: 16
38 | save_strategy: 'steps'
39 | lora_save_strategy: steps # if LoRA training is enabled, use this to save only the LoRA weights; supports ['steps','epochs','no']
40 | save_steps: 6000
41 | learning_rate: 1e-5
42 | gradient_checkpointing: True
43 | # whether to do a fast epoch
44 | fast_epoch: False
45 |
46 | # whether to compute diffusion loss
47 | compute_diffusion_loss: False
48 |
49 | bf16: True
50 | fp16: False
51 | tf32: False
52 | per_device_eval_batch_size: 1
53 | gradient_accumulation_steps: 1
54 | evaluation_strategy: "no"
55 | save_total_limit: 3
56 | weight_decay: 0.
57 | warmup_ratio: 0.0
58 | lr_scheduler_type: cosine
59 | logging_steps: 1
60 | model_max_length: 2048
61 | adam_beta1: 0.9
62 | adam_beta2: 0.95
63 | deepspeed: config/deepspeed/config_zero2.json
64 | dataloader_num_workers: 4
65 | # report_to: wandb
66 | is_training: True
--------------------------------------------------------------------------------
/config/experiments/stage3_instruct.yaml:
--------------------------------------------------------------------------------
1 | project_name: volcano
2 | run_name: volcano_stage3
3 | # Whether to mask only the system prompt in the labels (everything else is left unmasked)
4 | only_mask_system: False
5 | # system prompt style
6 | conv_mode: v1
7 | # whether to use LoRA
8 | lora_enable: False # no LoRA in this stage
9 | # whether the model is multimodal
10 | is_multimodal: True
11 |
12 | freeze_backbone: False
13 |
14 | # weight path
15 | model_path: PATH/TO/STAGE2_CKPT
16 | vision_encoder: openai/clip-vit-large-patch14-336
17 | vision_encoder_path: openai/clip-vit-large-patch14-336
18 | skip_vision_encoder_load: False
19 | front_projector_type: mlp2x_gelu
20 | num_query_token: 32
21 | avoid_generator: True # do not use generator in this stage
22 | output_dir: PATH/TO/STAGE3_CKPT
23 | behind_projector: linear
24 | flash_attn: True
25 | # dataset config
26 | data_config_path: config/datasets/stage3_instruct.yaml
27 | expand_to_square: True
28 | remove_unused_columns: False
29 | regression_weight: 1.0
30 | tokenizer_model_max_length: 3072
31 | model_max_length: 3072
32 | extend_loc_vocabulary: False
33 | use_mistral: True
34 |
35 | num_train_epochs: 1
36 | per_device_train_batch_size: 16
37 | save_strategy: 'steps'
38 | lora_save_strategy: steps # if LoRA training is enabled, use this to save only the LoRA weights; supports ['steps','epochs','no']
39 | save_steps: 3000
40 | learning_rate: 1e-5
41 | gradient_checkpointing: True
42 | # whether to do a fast epoch
43 | fast_epoch: False
44 |
45 | # whether to compute diffusion loss
46 | compute_diffusion_loss: False
47 |
48 | bf16: True
49 | fp16: False
50 | tf32: False
51 | per_device_eval_batch_size: 1
52 | gradient_accumulation_steps: 1
53 | evaluation_strategy: "no"
54 | save_total_limit: 3
55 | weight_decay: 0.
56 | warmup_ratio: 0.0
57 | lr_scheduler_type: cosine
58 | logging_steps: 1
59 | model_max_length: 3072
60 | adam_beta1: 0.9
61 | adam_beta2: 0.95
62 | deepspeed: config/deepspeed/config_zero2.json
63 | dataloader_num_workers: 4
64 | # report_to: wandb
65 | is_training: True
--------------------------------------------------------------------------------
/constants.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import os
3 |
4 |
5 | IMG_TOKEN_NUM = 8
6 | ALL_IMG_TOKENS = [f"[IMG{i}]" for i in range(IMG_TOKEN_NUM)]
7 | ALL_IMG_TOKENS_STR = ''  # "".join(ALL_IMG_TOKENS)
8 |
9 | # Location related tokens
10 | LOC_TOKEN_NUM = 256
11 | ALL_LOC_TOKENS = ["[LOC{}]".format(i+1) for i in range(LOC_TOKEN_NUM)]
12 |
13 | USE_PREFIX_TUNING = False
14 | USE_LORA = False
15 | USE_CFG = True
16 | IGNORE_TOKEN_ID = -100
17 | IMAGE_TOKEN_INDEX = -200
18 |
19 |
20 | PRECISION = torch.bfloat16
21 | TRAINABLE_PRECISION = torch.float32
22 |
23 | IGNORE_INDEX = -100
24 | DEFAULT_GRD_TOKEN = ""
25 | DEFAULT_BOP_TOKEN = "" # begin of phrase modified from 3/11
26 | DEFAULT_EOP_TOKEN = "" # end of phrase modified from 3/11
27 | DEFAULT_BOC_TOKEN = "" # begin of coordinates
28 | DEFAULT_EOC_TOKEN = "" # end of coordinates
29 | DEFAULT_SEP_TOKEN = " and"# "" modified from 3/11
30 | DEFAULT_PAD_TOKEN = "[PAD]"
31 | DEFAULT_EOS_TOKEN = ""
32 | DEFAULT_BOS_TOKEN = ""
33 | DEFAULT_UNK_TOKEN = ""
34 |
35 | # begin of image
36 | DEFAULT_BOI_TOKEN = ""
37 | # end of image
38 | DEFAULT_EOI_TOKEN = ''
39 | # default image token
40 | DEFAULT_IMG_TOKEN = ''
41 | COT_ACTIVATION = 'Answer the question and include the reasoning process. Locate key objects and provide bounding boxes in your thoughts.'
42 | COT_ACTIVATION_TXT = 'Answer the question and include the reasoning process.'
--------------------------------------------------------------------------------
/eval/Evaluation.md:
--------------------------------------------------------------------------------
1 | # Evaluation of VolCano
2 |
3 | ## Data Prepare
4 | For evaluation, we also utilize the yaml file to manage the evaluation datasets.
5 |
6 | For CLEVR and CLEVR-ref, we provide the constructed evaluation meta data in [here](https://huggingface.co/datasets/luoruipu1/VoCoT/tree/main/eval). For other dataets, please refer to the original website for the data.
7 |
8 | With the prepared datasets, please set the correct paths in those [config files](../config/datasets/eval/).
9 |
10 | ## Run Evaluation
11 |
12 | We provide the evaluation scripts in [test_all_benchmark.sh](./commands/test_all_benchmark.sh); you can modify the model path and then run the entire script or only a part of it.
13 |
14 | ## Metric Computation
15 |
16 | Similar to the evaluation, all the metrics can be computed offline with [run_metric.sh](./commands/run_metric.sh). For GQA and AMBER, the output is only converted into the appropriate format; you still need to compute the metrics yourself. Please refer to [LLaVA_for_GQA](https://github.com/haotian-liu/LLaVA/blob/main/docs/Evaluation.md#gqa) and [AMBER](https://github.com/junyangwang0410/AMBER) for further instructions.
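For reference, a typical offline run looks like this; set `store_model_name` inside the script to match the directory under `output/<dataset_name>/` that holds your prediction files:

```bash
# Run from the repository root after editing the paths inside the script.
bash eval/commands/run_metric.sh
```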
17 |
--------------------------------------------------------------------------------
/eval/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/RupertLuo/VoCoT/d8177ca2e8e303f7d35609a000e5eff7bec319cb/eval/__init__.py
--------------------------------------------------------------------------------
/eval/commands/run_metric.sh:
--------------------------------------------------------------------------------
1 | ##################### total setting=2,4,6,7
2 | function run_all(){
3 | model_name=/mnt/bn/yangmin-priv/luoruipu/checkpoints/LLaVA-clip336px-obj-represent-Mistral-1e-5-3072-instruct_llava+shikraCoT+GPTQTA-Block+lvis-cot/
4 |
5 |
6 | store_model_name=llava_mistral_instruct_simplified_image_llava_shikraCoT75+GPTQTA-qa2t
7 |
8 | ## dataset setting
9 | function mmbench(){
10 | ##### mci
11 | dataset_name=mmbench
12 | dataset_config=config/datasets/eval/MMBench_DEV_opt.yaml
13 | output_dir=output/${dataset_name}/${store_model_name}/cot_sum/
14 | echo "========MMBench Result=========="
15 | python3 eval/eval_tools/mmbench.py --config ${dataset_config} --result ${output_dir}/MMBench_DEV_opt.json
16 | }
17 |
18 | function seed(){
19 | ##### mci
20 | dataset_name=seed
21 | dataset_config=config/datasets/eval/SEED_opt.yaml
22 | output_dir=output/${dataset_name}/${store_model_name}/cot_sum/
23 | echo "========SEED Result=========="
24 | python3 eval/eval_tools/seed.py --config ${dataset_config} --result ${output_dir}/SEED_opt.json
25 | }
26 |
27 | function clevr(){
28 | ##### mci
29 | dataset_name=clevr
30 | dataset_config=config/datasets/eval/CLEVR_val1k_type_opt.yaml
31 | output_dir=output/${dataset_name}/${store_model_name}/cot_sum/
32 | echo "========CLEVR Result=========="
33 | python3 eval/eval_tools/clevr.py --config config/datasets/eval/CLEVR_val1k_type_opt.yaml --result ${output_dir}/CLEVR_val1k_type_opt.json
34 | }
35 |
36 |
37 | function embspatial(){
38 | ##### mci
39 | dataset_name=emb_spa
40 | dataset_config=config/datasets/eval/EmbSpa_test_opt.yaml
41 | output_dir=output/${dataset_name}/${store_model_name}/cot_sum/
42 | echo "========EmbSpatial Result==========="
43 | python3 eval/eval_tools/mp3d.py --config config/datasets/eval/EmbSpa_test_opt.yaml --result ${output_dir}/EmbSpa_test_opt.json
44 | }
45 |
46 | function vsr(){
47 | ##### mci
48 | dataset_name=vsr
49 | dataset_config=config/datasets/eval/VSR_TEST.yaml
50 | output_dir=output/${dataset_name}/${store_model_name}/cot/
51 | echo "========VSR Result==========="
52 | python3 eval/eval_tools/vsr.py --config config/datasets/eval/VSR_TEST.yaml --data ${output_dir}/VSR_TEST.json
53 | }
54 |
55 | function pope(){
56 | ##### mci
57 | dataset_name=pope
58 | dataset_config=config/datasets/eval/POPE_adversarial.yaml
59 | output_dir=output/${dataset_name}/${store_model_name}/cot/
60 | echo "========POPE Result==========="
61 | python3 eval/eval_tools/pope.py --data ${output_dir}/POPE_adversarial.json
62 | }
63 |
64 | function vstar(){
65 | ##### mci
66 | dataset_name=vstar
67 | dataset_config=config/datasets/eval/VStar.yaml
68 | output_dir=output/${dataset_name}/${store_model_name}/cot/
69 | echo "========VSTAR Result==========="
70 | python3 eval/eval_tools/vstar.py --result ${output_dir}/VStar.json
71 | }
72 |
73 | function wino(){
74 | ##### mci
75 | dataset_name=wino
76 | dataset_config=config/datasets/eval/Wino.yaml
77 | output_dir=output/${dataset_name}/${store_model_name}/cot_no_instruct/
78 | echo "========WINO-txt Result==========="
79 | python3 eval/eval_tools/vstar.py --result ${output_dir}/Wino.json
80 | }
81 |
82 | function amber(){
83 | dataset_name=amber
84 | dataset_config=config/datasets/eval/AMBER_desc.yaml
85 | output_dir=output/${dataset_name}/${store_model_name}/cot_sum/
86 | echo "========AMBER Result Need Further Evaluation==========="
87 | python3 eval/eval_tools/convert_res_to_amber.py \
88 | --src ${output_dir}/AMBER_desc.json \
89 | --tgt ${output_dir}/AMBER_desc_eval.json \
90 | --desc
91 | }
92 |
93 | function gqa(){
94 | dataset_name=gqa
95 | dataset_config=config/datasets/eval/GQA.yaml
96 | output_dir=output/${dataset_name}/${store_model_name}/cot/
97 | echo "========GQA Result Need Further Evaluation==========="
98 | python3 eval/eval_tools/convert_res_to_gqa.py \
99 | --src ${output_dir}/GQA.json \
100 | --dst ${output_dir}/testdev_balanced_predictions.json
101 | }
102 |
103 | function refcocog_test(){
104 | ##### mci
105 | dataset_name=refcoco
106 | dataset_config=config/datasets/eval/refcocog_test-u.yaml
107 | output_dir=output/${dataset_name}/${store_model_name}/llava_prompt/
108 | echo "========RefCOCOg test Result==========="
109 | python3 eval/eval_tools/refcoco.py --mistral --path ${output_dir}/refcocog_test-u.json
110 | }
111 | function refcocog_val(){
112 | ##### mci
113 | dataset_name=refcoco
114 | dataset_config=config/datasets/eval/refcocog_val-u.yaml
115 | output_dir=output/${dataset_name}/${store_model_name}/llava_prompt/
116 | echo "========RefCOCOg val Result==========="
117 | python3 eval/eval_tools/refcoco.py --mistral --path ${output_dir}/refcocog_val-u.json
118 | }
119 | function clevr_ref(){
120 | dataset_name=refcoco
121 | dataset_config=config/datasets/eval/clevr_ref.yaml
122 | output_dir=output/${dataset_name}/${store_model_name}/llava_prompt/
123 | echo "========CLEVR REF Result==========="
124 | python3 eval/eval_tools/refcoco.py --mistral --path ${output_dir}/clevr_ref.json
125 | }
126 |
127 | mmbench
128 | seed
129 | embspatial
130 | clevr
131 | wino
132 | vsr
133 | pope
134 | vstar
135 | refcocog_val
136 | refcocog_test
137 | clevr_ref
138 | gqa
139 | amber
140 | }
141 | run_all
--------------------------------------------------------------------------------
/eval/eval_tools/clevr.py:
--------------------------------------------------------------------------------
1 | from omegaconf import OmegaConf
2 | from utils.util import instantiate_from_config
3 | import argparse, json
4 | from collections import defaultdict
5 |
6 | label2index = {'A':0, 'B': 1, 'C': 2, 'D': 3}
7 |
8 | def main():
9 | parser = argparse.ArgumentParser()
10 | parser.add_argument('--result', type=str, default=None)
11 | parser.add_argument('--config', type=str, default='/mnt/bn/yangmin-priv/luoruipu/code/Edit-GPT4/config/datasets/eval/CLEVR_val1k_opt_context.yaml')
12 | args = parser.parse_args()
13 |
14 | cfg = OmegaConf.load(args.config)
15 | ds = instantiate_from_config(cfg[0])
16 | res = json.load(open(args.result))
17 | class2res = defaultdict(list)
18 | if isinstance(res[0]['predict'], int):
19 | print('evaluating options')
20 | acc = 0
21 | for item in res:
22 | index = int(item['item_id'].split('_')[-1])
23 | q_type = ds.question_type(index)
24 | if item['predict'] == ds.getlabel(index):
25 | c = 1
26 | acc += 1
27 | else:
28 | c = 0
29 | class2res[q_type].append(c)
30 | print('General accuracy: {:.4f} from {} samples'.format(acc/len(res), len(res)))
31 | for k,v in class2res.items():
32 | print('{} Accuracy: {:.4f} from {} samples'.format(k, sum(v)/len(v), len(v)))
33 | return
34 |
35 | cfg = OmegaConf.load(args.config)
36 | ds = instantiate_from_config(cfg[0])
37 |
38 |
39 | item2logit = defaultdict(list)
40 | item2answer = {}
41 |
42 | for item in res:
43 | item_id = int(item['item_id'].split('_')[-1])
44 | question_id, option = ds.get_index(item_id)
45 | label = ds.getlabel(item_id)
46 | if question_id in item2answer:
47 | assert label == item2answer[question_id]
48 | else:
49 | item2answer[question_id] = label
50 | item2logit[question_id].append([item['logit'], option])
51 |
52 | acc = 0
53 | for k in item2logit:
54 | preds = sorted(item2logit[k], key=lambda x:x[0])[0][1]
55 | print(preds, item2answer[k])
56 | if preds == item2answer[k]:
57 | acc += 1
58 | print('accuracy: {}'.format(acc/len(item2logit)))
59 |
60 | if __name__=='__main__':
61 | main()
--------------------------------------------------------------------------------
/eval/eval_tools/convert_res_to_amber.py:
--------------------------------------------------------------------------------
1 | import json,argparse
2 | from utils.eval_util import extract_box, remove_all_box_str
3 |
4 | parser = argparse.ArgumentParser()
5 | parser.add_argument('--src', type=str)
6 | parser.add_argument('--tgt', type=str)
7 | parser.add_argument('--desc', action='store_true')
8 | args = parser.parse_args()
9 |
10 | res = json.load(open(args.src))
11 | new_res = []
12 | if args.desc:
13 | for item in res:
14 | key = 'prediction' if 'prediction' in item else 'predict'
15 | pred = remove_all_box_str(item[key], mistral=True).replace(' ', ' ').replace('', '').strip().replace(' .', '.')
16 | new_res.append({
17 | 'id': item['label'],
18 | 'response': pred
19 | })
20 | else:
21 | for item in res:
22 | tmp = item['prediction'].replace('', '').strip().lower()
23 | if 'yes' in tmp:
24 | pred = 'Yes'
25 | else:
26 | pred = 'No'
27 | new_res.append({
28 | 'id': item['label'],
29 | 'response': pred
30 | })
31 | with open(args.tgt, 'w') as wf:
32 | json.dump(new_res, wf)
--------------------------------------------------------------------------------
/eval/eval_tools/convert_res_to_gqa.py:
--------------------------------------------------------------------------------
1 | import os
2 | import json
3 | import argparse
4 | from omegaconf import OmegaConf
5 | from utils.util import instantiate_from_config
6 |
7 | parser = argparse.ArgumentParser()
8 | parser.add_argument("--src", type=str)
9 | parser.add_argument("--dst", type=str)
10 | args = parser.parse_args()
11 | cfg = OmegaConf.load('config/datasets/eval/GQA.yaml')
12 | ds = instantiate_from_config(cfg[0])
13 |
14 | all_answers = []
15 | res = json.load(open(args.src))
16 | for line in res:
17 | index = int(line['item_id'].split('_')[-1])
18 | question_id = ds.keys[index]
19 | text = line['prediction'].replace('', '').rstrip('.').strip().lower()
20 | all_answers.append({"questionId": question_id, "prediction": text})
21 |
22 | with open(args.dst, 'w') as f:
23 | json.dump(all_answers, f)
--------------------------------------------------------------------------------
/eval/eval_tools/convert_res_to_gqa_llava.py:
--------------------------------------------------------------------------------
1 | import os
2 | import json
3 | import argparse
4 |
5 | parser = argparse.ArgumentParser()
6 | parser.add_argument("--src", type=str)
7 | parser.add_argument("--dst", type=str)
8 | args = parser.parse_args()
9 |
10 | all_answers = []
11 | for line_idx, line in enumerate(open(args.src)):
12 | res = json.loads(line)
13 | question_id = res['question_id']
14 | text = res['text'].rstrip('.').lower()
15 | all_answers.append({"questionId": question_id, "prediction": text})
16 |
17 | with open(args.dst, 'w') as f:
18 | json.dump(all_answers, f)
--------------------------------------------------------------------------------
/eval/eval_tools/convert_res_to_mmbench.py:
--------------------------------------------------------------------------------
1 | import os
2 | import json
3 | import argparse
4 | import pandas as pd
5 | from omegaconf import OmegaConf
6 |
7 | def get_args():
8 | parser = argparse.ArgumentParser()
9 | parser.add_argument("--annotation-file", type=str, required=True)
10 | parser.add_argument("--src", type=str, required=True)
11 | parser.add_argument("--tgt", type=str, required=True)
12 |
13 | return parser.parse_args()
14 |
15 | if __name__ == "__main__":
16 | args = get_args()
17 |
18 | df = pd.read_table(args.annotation_file)
19 |
20 | all_answers = []
21 | res = json.load(open(args.src))
22 | cur_df = df.copy()
23 | cur_df = cur_df.drop(columns=['hint', 'category', 'source', 'image', 'comment', 'l2-category'])
24 | cur_df.insert(6, 'prediction', None)
25 | for pred in res:
26 | cur_df.loc[df['index'] == pred['dataset_id'], 'prediction'] = pred['prediction'].replace('', '')
27 |
28 | cur_df.to_excel(args.tgt, index=False, engine='openpyxl')
--------------------------------------------------------------------------------
/eval/eval_tools/mmbench.py:
--------------------------------------------------------------------------------
1 | from omegaconf import OmegaConf
2 | from utils.util import instantiate_from_config
3 | import argparse, json
4 | from collections import defaultdict
5 | import pandas as pd
6 |
7 | label2index = {'A':0, 'B': 1, 'C': 2, 'D': 3}
8 |
9 | def main():
10 | parser = argparse.ArgumentParser()
11 | parser.add_argument('--result', type=str, default=None)
12 | parser.add_argument('--config', type=str, default='/mnt/bn/yangmin-priv/luoruipu/code/Edit-GPT4/config/datasets/eval/MMBench_DEV_opt.yaml')
13 | args = parser.parse_args()
14 |
15 | cfg = OmegaConf.load(args.config)
16 | mmbench_meta = pd.read_table(cfg[0]['params']['path'])
17 | res = json.load(open(args.result))
18 | for item in res:
19 | if 'predict' not in item:
20 | item['predict'] = item['prediction']
21 | class2res = defaultdict(list)
22 | if isinstance(res[0]['predict'], int):
23 | acc = 0
24 | for item in res:
25 | index = int(item['item_id'].split('_')[-1])
26 | q_type = mmbench_meta.iloc[index]['category']
27 | if item['predict'] == '':
28 | item['predict'] = 0
29 | if item['predict'] == label2index[item['label']]:
30 | c = 1
31 | acc += 1
32 | else:
33 | c = 0
34 | class2res[q_type].append(c)
35 | print('General accuracy: {}'.format(acc/len(res)))
36 | for k,v in class2res.items():
37 | print('{} Accuracy: {}'.format(k, sum(v)/len(v)))
38 | return
39 |
40 | # cfg = OmegaConf.load(args.config)
41 | # ds = instantiate_from_config(cfg[0])
42 |
43 |
44 | # item2logit = defaultdict(list)
45 | # item2answer = {}
46 |
47 | # for item in res:
48 | # item_id = int(item['item_id'].split('_')[-1])
49 | # question_id, option = ds.get_index(item_id)
50 | # label = ds.getlabel(item_id)
51 | # if question_id in item2answer:
52 | # assert label == item2answer[question_id]
53 | # else:
54 | # item2answer[question_id] = label
55 | # item2logit[question_id].append([item['logit'], option])
56 |
57 | # acc = 0
58 | # for k in item2logit:
59 | # preds = sorted(item2logit[k], key=lambda x:x[0])[0][1]
60 | # print(preds, item2answer[k])
61 | # if preds == item2answer[k]:
62 | # acc += 1
63 | # print('accuracy: {}'.format(acc/len(item2logit)))
64 |
65 | if __name__=='__main__':
66 | main()
--------------------------------------------------------------------------------
/eval/eval_tools/mp3d.py:
--------------------------------------------------------------------------------
1 | from omegaconf import OmegaConf
2 | from utils.util import instantiate_from_config
3 | import argparse, json
4 | from collections import defaultdict
5 |
6 | label2index = {'A':0, 'B': 1, 'C': 2, 'D': 3}
7 |
8 | def main():
9 | parser = argparse.ArgumentParser()
10 | parser.add_argument('--result', type=str, default=None)
11 | parser.add_argument('--config', type=str, default='/mnt/bn/yangmin-priv/luoruipu/code/Edit-GPT4/config/datasets/eval/EmbSpa_dev_opt.yaml')
12 | args = parser.parse_args()
13 |
14 | cfg = OmegaConf.load(args.config)
15 | meta = json.load(open(cfg[0]['params']['path']))
16 | # seed_question = [item for item in seed_meta['questions'] if item['data_type']=='image']
17 | # question_types = {v:k for k,v in seed_meta['question_type'].items()}
18 | res = json.load(open(args.result))
19 | class2res = defaultdict(list)
20 | if isinstance(res[0]['predict'], int):
21 | print('evaluating options')
22 | acc = 0
23 | for item in res:
24 | index = int(item['item_id'].split('_')[-1])
25 | item['label'] = meta[index]['answer']
26 | # q_type = question_types[seed_question[index]['question_type_id']]
27 | if item['predict'] == item['label']:
28 | c = 1
29 | acc += 1
30 | else:
31 | c = 0
32 | # class2res[q_type].append(c)
33 | print('General accuracy: {:.4f}'.format(acc/len(res)))
34 | # for k,v in class2res.items():
35 | # print('{} Accuracy: {:.4f}'.format(k, sum(v)/len(v)))
36 | return
37 |
38 | cfg = OmegaConf.load(args.config)
39 | ds = instantiate_from_config(cfg[0])
40 |
41 |
42 | item2logit = defaultdict(list)
43 | item2answer = {}
44 |
45 | for item in res:
46 | item_id = int(item['item_id'].split('_')[-1])
47 | question_id, option = ds.get_index(item_id)
48 | label = ds.getlabel(item_id)
49 | if question_id in item2answer:
50 | assert label == item2answer[question_id]
51 | else:
52 | item2answer[question_id] = label
53 | item2logit[question_id].append([item['logit'], option])
54 |
55 | acc = 0
56 | for k in item2logit:
57 | preds = sorted(item2logit[k], key=lambda x:x[0])[0][1]
58 | print(preds, item2answer[k])
59 | if preds == item2answer[k]:
60 | acc += 1
61 | print('accuracy: {}'.format(acc/len(item2logit)))
62 |
63 | if __name__=='__main__':
64 | main()
--------------------------------------------------------------------------------
/eval/eval_tools/pope.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import json
3 |
4 | parser = argparse.ArgumentParser()
5 | parser.add_argument('--data')
6 | args = parser.parse_args()
7 | res = json.load(open(args.data))
8 |
9 | invalid = correct = 0
10 | for item in res:
11 | key = 'predict' if 'predict' in item else 'prediction'
12 | pred = item[key].replace('', '').strip().lower()
13 | if pred not in ['yes', 'no']:
14 | if pred.startswith('yes'):
15 | p = 'yes'
16 | elif pred.startswith('no'):
17 | p = 'no'
18 | else:
19 | invalid += 1
20 | p = 'no'
21 | # print(pred)
22 | # invalid += 1
23 | # correct += 1 if item['label'] == 'no' else 0
24 | else:
25 | p = pred
26 | correct += 1 if p==item['label'] else 0
27 |
28 | print('accuracy: {}, invalid rate: {}'.format(correct / len(res), invalid/len(res)))
--------------------------------------------------------------------------------
/eval/eval_tools/refcoco.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import json
3 | from utils.eval_util import *
4 | import tqdm
5 | from functools import partial
6 |
7 | def main():
8 | parser = argparse.ArgumentParser()
9 | parser.add_argument('--path', type=str, default=None)
10 | parser.add_argument('--method', type=str, default='str')
11 | parser.add_argument('--mistral', action='store_true')
12 | args = parser.parse_args()
13 |
14 | result = json.load(open(args.path))
15 | if args.method == 'str':
16 | extract_fn = partial(extract_box_str, mistral=args.mistral)
17 | elif args.method == 'llava':
18 | extract_fn = extract_box_str_llava
19 | elif args.method == 'llava16':
20 | extract_fn = partial(extract_box_str_llava16, mistral=args.mistral)
21 | elif args.method == 'space':
22 | extract_fn = extract_box_str_space
23 | elif args.method == 'special_tokens':
24 | extract_fn = extract_box_str
25 | elif args.method == 'qwenvl':
26 | extract_fn = extract_box_str_qwenvl
27 | elif args.method == 'minigptv2':
28 | extract_fn = extract_box_str_minigptv2
29 | else:
30 | raise NotImplementedError
31 | samples = {
32 | 'fail': 0,
33 | 'wrong': 0,
34 | 'correct': 0
35 | }
36 | key = 'prediction' if 'prediction' in result[0] else 'predict'
37 | for item in result:
38 | # print(item)
39 | box = extract_fn(item[key][0] if isinstance(item[key], list) else item[key])
40 | if box is None:
41 | print(item[key])
42 | samples['fail'] += 1
43 | else:
44 | iou = cal_iou(box, item['label'])
45 | if iou >= 0.5:
46 | samples['correct'] += 1
47 | else:
48 | samples['wrong'] += 1
49 | print(json.dumps(samples))
50 | print('accuracy: {}'.format(samples['correct'] / len(result)))
51 |
52 | if __name__=='__main__':
53 | main()
--------------------------------------------------------------------------------
/eval/eval_tools/seed.py:
--------------------------------------------------------------------------------
1 | from omegaconf import OmegaConf
2 | from utils.util import instantiate_from_config
3 | import argparse, json
4 | from collections import defaultdict
5 |
6 | label2index = {'A':0, 'B': 1, 'C': 2, 'D': 3}
7 |
8 | def main():
9 | parser = argparse.ArgumentParser()
10 | parser.add_argument('--result', type=str, default=None)
11 | parser.add_argument('--config', type=str, default='/mnt/bn/yangmin-priv/luoruipu/code/Edit-GPT4/config/datasets/eval/SEED_opt.yaml')
12 | parser.add_argument('--no_counting', action='store_true')
13 | args = parser.parse_args()
14 |
15 | cfg = OmegaConf.load(args.config)
16 | seed_meta = json.load(open(cfg[0]['params']['path']))
17 | seed_question = [item for item in seed_meta['questions'] if item['data_type']=='image']
18 | question_types = {v:k for k,v in seed_meta['question_type'].items()}
19 | res = json.load(open(args.result))
20 | class2res = defaultdict(list)
21 | if isinstance(res[0]['predict'], int):
22 | print('evaluating options')
23 | acc = 0
24 | for item in res:
25 | index = int(item['item_id'].split('_')[-1])
26 | q_type = question_types[seed_question[index]['question_type_id']]
27 | if item['predict'] == '':
28 | item['predict'] = 0
29 | if item['predict'] == label2index[item['label']]:
30 | c = 1
31 | acc += 1
32 | else:
33 | c = 0
34 | class2res[q_type].append(c)
35 | if args.no_counting:
36 | all_samples = 0
37 | corr_samples = 0
38 | for k,v in class2res.items():
39 | if k!= 'Instances Counting':
40 | all_samples += len(v)
41 | corr_samples += sum(v)
42 | print('General accuracy w/o counting: {:.4f}'.format(corr_samples/all_samples))
43 | print('General accuracy: {:.4f}'.format(acc/len(res)))
44 | else:
45 | print('General accuracy: {:.4f}'.format(acc/len(res)))
46 | for k,v in class2res.items():
47 | print('{} Accuracy: {:.4f}'.format(k, sum(v)/len(v)))
48 | return
49 |
50 | cfg = OmegaConf.load(args.config)
51 | ds = instantiate_from_config(cfg[0])
52 |
53 |
54 | item2logit = defaultdict(list)
55 | item2answer = {}
56 |
57 | for item in res:
58 | item_id = int(item['item_id'].split('_')[-1])
59 | question_id, option = ds.get_index(item_id)
60 | label = ds.getlabel(item_id)
61 | if question_id in item2answer:
62 | assert label == item2answer[question_id]
63 | else:
64 | item2answer[question_id] = label
65 | item2logit[question_id].append([item['logit'], option])
66 |
67 | acc = 0
68 | for k in item2logit:
69 | preds = sorted(item2logit[k], key=lambda x:x[0])[0][1]
70 | print(preds, item2answer[k])
71 | if preds == item2answer[k]:
72 | acc += 1
73 | print('accuracy: {}'.format(acc/len(item2logit)))
74 |
75 | if __name__=='__main__':
76 | main()
--------------------------------------------------------------------------------
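Usage sketch for the SEED evaluator above: it accepts two result formats, predicted option indices (scored directly against label2index) or per-option entries with a 'logit' field, in which case the option with the smallest stored value is taken as the prediction. The result path below is a placeholder; passing --config explicitly avoids the machine-specific default, and --no_counting additionally reports accuracy without the Instances Counting category:

    # python eval/eval_tools/seed.py --result outputs/seed_predictions.json --config config/datasets/eval/SEED_opt.yaml --no_counting
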
/eval/eval_tools/vsr.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | from collections import defaultdict
3 | import json
4 | from omegaconf import OmegaConf
5 | from utils.util import instantiate_from_config
6 |
7 | parser = argparse.ArgumentParser()
8 | parser.add_argument('--data')
9 | parser.add_argument('--config')
10 |
11 | args = parser.parse_args()
12 | res = json.load(open(args.data))
13 | cfg = OmegaConf.load(args.config)
14 | ds = instantiate_from_config(cfg[0], reload=True)
15 |
16 | invalid = correct = 0
17 | relation2res = defaultdict(list)
18 | if 'predict' in res[0] and isinstance(res[0]['predict'], int):
19 | for item in res:
20 | item_id = int(item['item_id'].split('_')[-1])
21 | relation = ds.meta[item_id]['relation']
22 | pred = item['predict']
23 | correct += 1 if pred==item['label'] else 0
24 | relation2res[relation].append(1 if pred==item['label'] else 0)
25 | else:
26 | for item in res:
27 | item_id = int(item['item_id'].split('_')[-1])
28 | relation = ds.meta[item_id]['relation']
29 | key = 'predict' if 'predict' in item else 'prediction'
30 | pred = item[key].replace('', '').strip().lower()
31 | if pred not in ['yes', 'no']:
32 | # p = -1
33 | if pred.startswith('yes'):
34 | p = 1
35 | elif pred.startswith('no'):
36 | p = 0
37 | else:
38 | invalid += 1
39 | p = 0
40 | else:
41 | if pred == 'yes':
42 | p = 1
43 | else:
44 | p = 0
45 | correct += 1 if p==item['label'] else 0
46 | relation2res[relation].append(1 if p==item['label'] else 0)
47 |
48 | print('accuracy: {}, invalid rate: {}'.format(correct / len(res), invalid/len(res)))
49 | print('====results in detail=====')
50 | for k in sorted(relation2res.keys()):
51 | v = relation2res[k]
52 | print('{}: {} in {} samples'.format(k, round(sum(v)/len(v),4), len(v)))
--------------------------------------------------------------------------------
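Usage sketch for the VSR evaluator above: it needs both the prediction file and the dataset config so each sample's spatial relation can be recovered from the dataset meta, and it prints overall accuracy, the invalid-prediction rate, and per-relation accuracy. Both paths below are placeholders:

    # python eval/eval_tools/vsr.py --data outputs/vsr_predictions.json --config path/to/vsr_eval_config.yaml
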
/eval/eval_tools/vstar.py:
--------------------------------------------------------------------------------
1 | from omegaconf import OmegaConf
2 | from utils.util import instantiate_from_config
3 | import argparse, json
4 | from collections import defaultdict
5 |
6 | label2index = {'A':0, 'B': 1, 'C': 2, 'D': 3}
7 |
8 | def main():
9 | parser = argparse.ArgumentParser()
10 | parser.add_argument('--result', type=str, default=None)
11 | parser.add_argument('--config', type=str, default='/mnt/bn/yangmin-priv/luoruipu/code/Edit-GPT4/config/datasets/eval/SEED_opt.yaml')
12 | args = parser.parse_args()
13 |
14 | # cfg = OmegaConf.load(args.config)
15 | # seed_meta = json.load(open(cfg[0]['params']['path']))
16 | # seed_question = [item for item in seed_meta['questions'] if item['data_type']=='image']
17 | # question_types = {v:k for k,v in seed_meta['question_type'].items()}
18 | res = json.load(open(args.result))
19 | class2res = defaultdict(list)
20 | if isinstance(res[0]['predict'], int) or res[0]['predict']=='':
21 | print('evaluating options')
22 | acc = 0
23 | for item in res:
24 | index = int(item['item_id'].split('_')[-1])
25 | # q_type = question_types[seed_question[index]['question_type_id']]
26 | if item['predict'] == item['label']:
27 | c = 1
28 | acc += 1
29 | else:
30 | c = 0
31 | # class2res[q_type].append(c)
32 | print('General accuracy: {:.4f}'.format(acc/len(res)))
33 | # for k,v in class2res.items():
34 | # print('{} Accuracy: {:.4f}'.format(k, sum(v)/len(v)))
35 | return
36 |
37 | cfg = OmegaConf.load(args.config)
38 | ds = instantiate_from_config(cfg[0])
39 |
40 |
41 | item2logit = defaultdict(list)
42 | item2answer = {}
43 |
44 | for item in res:
45 | item_id = int(item['item_id'].split('_')[-1])
46 | question_id, option = ds.get_index(item_id)
47 | label = ds.getlabel(item_id)
48 | if question_id in item2answer:
49 | assert label == item2answer[question_id]
50 | else:
51 | item2answer[question_id] = label
52 | item2logit[question_id].append([item['logit'], option])
53 |
54 | acc = 0
55 | for k in item2logit:
56 | preds = sorted(item2logit[k], key=lambda x:x[0])[0][1]
57 | print(preds, item2answer[k])
58 | if preds == item2answer[k]:
59 | acc += 1
60 | print('accuracy: {}'.format(acc/len(item2logit)))
61 |
62 | if __name__=='__main__':
63 | main()
--------------------------------------------------------------------------------
/eval/merge_benchmark.py:
--------------------------------------------------------------------------------
1 | from argparse import ArgumentParser
2 | import os
3 | from utils.logger import setup_logger
4 | import json
5 | import torch
6 |
7 | def rank0_print(args, res):
8 | if args.local_rank==0 or args.local_rank == -1:
9 | print(res)
10 |
11 | def get_output_name(args, mid_output=True):
12 | if mid_output:
13 | return os.path.join(args.output_dir,
14 | '{}_rank{}.json'.format(args.dataset_name, args.local_rank))
15 | else:
16 | return os.path.join(args.output_dir,
17 | '{}.json'.format(args.dataset_name))
18 |
19 | def get_all_output_names(args):
20 | return [os.path.join(args.output_dir,
21 | '{}_rank{}.json'.format(args.dataset_name, r)) for r in range(args.n_gpus)]
22 |
23 |
24 |
25 | def main():
26 | parser = ArgumentParser()
27 | parser.add_argument('--config_arg', type=str, default=None)
28 | old_args = parser.parse_args()
29 |
30 | args = torch.load(old_args.config_arg)
31 | print(args)
32 |
33 | base_config_name = os.path.basename(args.eval_data)
34 | args.dataset_name = base_config_name[:-5] if base_config_name.endswith('.yaml') else base_config_name
35 |
36 |
37 | full_res = []
38 | for fn in get_all_output_names(args):
39 | full_res.extend(json.load(open(fn, 'r')))
40 | os.remove(fn)
41 | with open(get_output_name(args, mid_output=False), 'w') as wf:
42 | json.dump(full_res, wf)
43 | # saving the arguments
44 | torch.save(args, get_output_name(args, mid_output=False)[:-4]+'args.bin')
45 |
46 |
47 | if __name__=='__main__':
48 | main()
--------------------------------------------------------------------------------
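Note on merge_benchmark.py above: it reloads the argument object saved by the benchmark runner via torch.load, concatenates the per-rank shards named '{dataset}_rank{r}.json' for r in range(n_gpus), removes the shards, and writes a single '{dataset}.json' together with a '{dataset}.args.bin' copy of the arguments. A hypothetical invocation, where the .bin path is a placeholder for wherever the runner saved its arguments:

    # python eval/merge_benchmark.py --config_arg outputs/run1/eval_args.bin
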
/figs/model_arch.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/RupertLuo/VoCoT/d8177ca2e8e303f7d35609a000e5eff7bec319cb/figs/model_arch.png
--------------------------------------------------------------------------------
/figs/sample_input.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/RupertLuo/VoCoT/d8177ca2e8e303f7d35609a000e5eff7bec319cb/figs/sample_input.jpg
--------------------------------------------------------------------------------
/locals/datasets/image_caption/cc3m.py:
--------------------------------------------------------------------------------
1 | from torch.utils.data import Dataset
2 | import json as js
3 | import math
4 | import random
5 | import os
6 | from PIL import Image
7 | from constants import *
8 |
9 | class FilteredCC3MI2TDataset(Dataset):
10 |
11 | def __init__(self,
12 | path: str,
13 | image_folder: str,
14 | meta_folder: str=None,
15 | instruct: bool = False,
16 | min_resize_res: int = 256,
17 | max_resize_res: int = 256,
18 | crop_res: int = 256,
19 | flip_prob: float = 0.5,
20 | sample_weight: float = 1.0,
21 | check: bool = False,
22 | output_mode: str = 'text',
23 | shuffle: bool = False,
24 | raw_image: bool = False,
25 | inference: bool = False,
26 | **kwargs):
27 | # load from json ../datasets/gqa-inpaint/meta_info.json
28 | self.path = path
29 | self.instruct = instruct
30 | self.inference = inference
31 | if path.endswith('txt'):
32 | self.meta = []
33 | with open(path) as rf:
34 | for line in rf:
35 | self.meta.append(line.strip())
36 | self.need_reload = True
37 | else:
38 | self.need_reload = False
39 | self.meta = js.load(open(path))
40 | if meta_folder is not None:  # prefer an explicitly provided meta folder, otherwise default to '<path dir>/meta'
41 | self.meta_folder = meta_folder
42 | else:
43 | self.meta_folder = os.path.join(os.path.dirname(path), 'meta')
44 | self.image_folder = image_folder
45 | self.min_resize_res = min_resize_res
46 | self.max_resize_res = max_resize_res
47 | self.crop_res = crop_res
48 | self.check = check
49 | self.raw_image = raw_image
50 | self.output_mode = output_mode
51 | self.shuffle = shuffle
52 |
53 | self.flip_prob = flip_prob
54 | self.sample_weight = sample_weight
55 | self.generation_prompts = [
56 | "generate image with caption:",
57 | "can you give me the image with caption:",
58 | "help me to generate this image:",
59 | "generate image with according to caption:",
60 | "according to caption, generate image:",
61 | "an image with caption:",
62 | "can you visualize this caption:",
63 | ]
64 | print(f"CC3MDataset has {len(self)} samples!!")
65 |
66 | def __len__(self):
67 | return int(len(self.meta) * self.sample_weight)
68 |
69 | def __getitem__(self, i):
70 |
71 | if self.sample_weight >= 1:
72 | i = i % len(self.meta)
73 | else:
74 | remainder = math.ceil(i / self.sample_weight - int(i / self.sample_weight))
75 | i = int(i / self.sample_weight) + random.randint(0, int(1 / self.sample_weight) - 1 + remainder)
76 |
77 | if self.need_reload:
78 | meta_fn = self.meta[i]
79 | item = js.load(open(os.path.join(self.meta_folder, meta_fn)))
80 | else:
81 | item = self.meta[i]
82 | # item = self.meta[i]
83 | tgt_img = Image.open(os.path.join(self.image_folder,item['image'])).convert('RGB')
84 | instruction = item['conversations'][1]['value']
85 | # return image_0, image_1, instruction
86 |
87 | if self.output_mode == 'conversation':
88 | sources = [{'from': 'human', 'value': random.choice(self.generation_prompts)+instruction},
89 | {'from': 'gpt', 'value': '{}.'.format(ALL_IMG_TOKENS_STR)}]
90 | return {'output_images': [tgt_img], 'conversation': sources, 'image_label_masks': [1]}
91 | elif self.output_mode == 'text':
92 | if self.shuffle:
93 | prob = random.random()
94 | if prob > 0.5:
95 | text = "{} {}".format(ALL_IMG_TOKENS_STR, instruction)
96 | image_label_masks = [0]
97 | else:
98 | text = "{} {}".format(instruction, ALL_IMG_TOKENS_STR)
99 | image_label_masks = [1]
100 | else:
101 | text = "{} {}".format(ALL_IMG_TOKENS_STR, instruction)
102 | image_label_masks = [0]
103 | if self.inference:
104 | text = ALL_IMG_TOKENS_STR
105 | return {'input_images': [tgt_img], 'text': text, 'image_label_masks': image_label_masks}
--------------------------------------------------------------------------------
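A minimal instantiation sketch for the CC3M caption dataset above; all paths are placeholders. With output_mode='text' the sample interleaves the image token with the caption, shuffle=True randomly places the token before or after the caption (which also flips the image generation mask), and inference=True keeps only the image token as input text:

    from locals.datasets.image_caption.cc3m import FilteredCC3MI2TDataset

    ds = FilteredCC3MI2TDataset(
        path='data/cc3m/meta.json',        # placeholder metadata list
        image_folder='data/cc3m/images',   # placeholder image root
        output_mode='text',
        shuffle=True,
    )
    sample = ds[0]                         # {'input_images': [...], 'text': ..., 'image_label_masks': [...]}
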
/locals/datasets/image_caption/coco.py:
--------------------------------------------------------------------------------
1 | from email.policy import default
2 | from torch.utils.data import Dataset
3 | import json as js
4 | import math
5 | import random
6 | import os
7 | from PIL import Image
8 | from collections import defaultdict
9 | from constants import *
10 |
11 | class COCOI2TDataset(Dataset):
12 |
13 | def __init__(self,
14 | path: str,
15 | image_folder: str,
16 | instruct: bool = False,
17 | min_resize_res: int = 256,
18 | max_resize_res: int = 256,
19 | crop_res: int = 256,
20 | flip_prob: float = 0.5,
21 | sample_weight: float = 1.0,
22 | check: bool = False,
23 | output_mode: str = 'text',
24 | raw_image: bool = False,
25 | sample_mode: str='text',
26 | shuffle: bool=False,
27 | inference: bool=False,
28 | **kwargs):
29 | # load from json ../datasets/gqa-inpaint/meta_info.json
30 | self.path = path
31 | self.instruct = instruct
32 | self.inference = inference
33 | self.meta = js.load(open(path))
34 | if isinstance(self.meta, dict):
35 | self.meta = self.meta['annotations']
36 | self.sample_mode = sample_mode
37 | if self.sample_mode == 'image':
38 | tmp = defaultdict(list)
39 | for item in self.meta:
40 | tmp[item['image_id']].append(item['caption'])
41 | self.meta = [{'image_id': k, 'captions': v} for k,v in tmp.items()]
42 | self.image_folder = image_folder
43 | self.min_resize_res = min_resize_res
44 | self.max_resize_res = max_resize_res
45 | self.crop_res = crop_res
46 | self.check = check
47 | self.raw_image = raw_image
48 | self.output_mode = output_mode
49 | self.shuffle = shuffle # shuffle is for interleaved image-text data
50 |
51 | self.flip_prob = flip_prob
52 | self.sample_weight = sample_weight
53 | self.generation_prompts = [
54 | "generate image with caption:",
55 | "can you give me the image with caption:",
56 | "help me to generate this image:",
57 | "generate image with according to caption:",
58 | "according to caption, generate image:",
59 | "an image with caption:",
60 | "can you visualize this caption:",
61 | ]
62 | print(f"COCO has {len(self)} samples!!")
63 |
64 | def __len__(self):
65 | return int(len(self.meta) * self.sample_weight)
66 |
67 | def get_image_fn(self, image_id):
68 | return os.path.join(self.image_folder, 'train2014', 'COCO_train2014_{:012d}.jpg'.format(image_id))
69 |
70 | def __getitem__(self, i):
71 |
72 | if self.sample_weight >= 1:
73 | i = i % len(self.meta)
74 | else:
75 | remainder = math.ceil(i / self.sample_weight - int(i / self.sample_weight))
76 | i = int(i / self.sample_weight) + random.randint(0, int(1 / self.sample_weight) - 1 + remainder)
77 |
78 | item = self.meta[i]
79 | tgt_img = Image.open(self.get_image_fn(item['image_id'])).convert('RGB')
80 | if self.sample_mode == 'text':
81 | instruction = item['caption']
82 | elif self.sample_mode == 'image':
83 | instruction = random.choice(item['captions'])
84 | else:
85 | raise ValueError
86 | # return image_0, image_1, instruction
87 |
88 | if self.output_mode == 'conversation':
89 | sources = [{'from': 'human', 'value': random.choice(self.generation_prompts)+instruction},
90 | {'from': 'gpt', 'value': '{}.'.format(ALL_IMG_TOKENS_STR)}]
91 | return {'output_images': [tgt_img], 'conversation': sources, 'image_label_masks': [1]}
92 | elif self.output_mode == 'text':
93 | if self.shuffle:
94 | prob = random.random()
95 | if prob > 0.5:
96 | text = "{} {}".format(ALL_IMG_TOKENS_STR, instruction)
97 | image_label_masks = [0]
98 | else:
99 | text = "{} {}".format(instruction, ALL_IMG_TOKENS_STR)
100 | image_label_masks = [1]
101 | else:
102 | text = "{} {}".format(ALL_IMG_TOKENS_STR, instruction)
103 | image_label_masks = [0]
104 | if self.inference:
105 | text = ALL_IMG_TOKENS_STR
106 | return {'input_images': [tgt_img], 'text': text, 'image_label_masks': image_label_masks}
--------------------------------------------------------------------------------
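Worked example of the COCO image-path convention in get_image_fn above: image ids are zero-padded to 12 digits under the train2014 split. The folder root is a placeholder:

    import os
    image_folder = 'data/coco'   # placeholder root
    path = os.path.join(image_folder, 'train2014', 'COCO_train2014_{:012d}.jpg'.format(581857))
    # path == 'data/coco/train2014/COCO_train2014_000000581857.jpg'
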
/locals/datasets/image_caption/flicker30k.py:
--------------------------------------------------------------------------------
1 | from email.policy import default
2 | from torch.utils.data import Dataset
3 | import json as js
4 | import math
5 | import random
6 | import os
7 | from PIL import Image
8 | from collections import defaultdict
9 | from constants import *
10 | from ..utils.box_utils import box2str, reshape_box
11 |
12 | class FlickrDataset(Dataset):
13 |
14 | def __init__(self,
15 | path: str,
16 | image_folder: str,
17 | instruct: bool = False,
18 | sample_weight: float = 1.0,
19 | check: bool = False,
20 | output_mode: str = 'text',
21 | raw_image: bool = False,
22 | shuffle: bool=False,
23 | avoid_image_gen: bool=False,
24 | phrase_format: str='special_tokens',
25 | phrase_prec: int=2,
26 | expand2square: bool=False,
27 | object_format: str='image',
28 | phrase_space: bool=False,
29 | **kwargs):
30 | # load from json ../datasets/gqa-inpaint/meta_info.json
31 | self.path = path
32 | self.instruct = instruct
33 | self.meta = [js.loads(line) for line in open(self.path)]
34 | self.image_folder = image_folder
35 | self.raw_image = raw_image
36 | self.object_format = object_format
37 | self.output_mode = output_mode
38 | self.avoid_image_gen = avoid_image_gen
39 | self.phrase_format = phrase_format
40 | self.phrase_prec = phrase_prec
41 | self.shuffle = shuffle # shuffle is for interleaved image-text data
42 | self.expand2square = expand2square
43 | self.phrase_space = phrase_space
44 |
45 | self.sample_weight = sample_weight
46 | print(f"Flickr30k entities has {len(self)} samples!!")
47 |
48 | def __len__(self):
49 | return int(len(self.meta) * self.sample_weight)
50 |
51 | def proc_box(self, box, image):
52 | w, h = image.size
53 | x_min, y_min, x_max, y_max = box
54 | sub_image = image.crop((x_min, y_min, x_max, y_max))
55 | new_box = [c / w if (i%2==0) else c / h for i,c in enumerate(box)]
56 | if self.expand2square:
57 | new_box = reshape_box(image, new_box)
58 | return new_box, sub_image
59 |
60 | def __len__(self):  # NOTE: this second definition overrides the sample_weight-scaled __len__ above
61 | return len(self.meta)
62 |
63 |
64 | def __getitem__(self, i):
65 |
66 | if self.sample_weight >= 1:
67 | i = i % len(self.meta)
68 | else:
69 | remainder = math.ceil(i / self.sample_weight - int(i / self.sample_weight))
70 | i = int(i / self.sample_weight) + random.randint(0, int(1 / self.sample_weight) - 1 + remainder)
71 |
72 | item = self.meta[i]
73 | image_id = item['image_id']
74 | image = Image.open(os.path.join(self.image_folder, '{}.jpg'.format(image_id))).convert('RGB')
75 | caption = item['sentence']
76 |
77 | box_infos = []
78 | sub_image_infos = []
79 | for b in item['boxes']:
80 | current_reshape_infos = self.proc_box(b, image)
81 | box_infos.append(current_reshape_infos[0])
82 | sub_image_infos.append(current_reshape_infos[1])
83 | all_box_input = [[0.0, 0.0, 1.0, 1.0]]
84 | image_label_masks = [0]
85 | sub_image = []
86 | input_images = [image]
87 |
88 | # start from
89 | input_query = item['sentence'].replace('', '').split('')
90 | refined_query = ALL_IMG_TOKENS_STR + DEFAULT_GRD_TOKEN + ' '
91 | for box_ind, current_box in enumerate(item['boxes_seq']):
92 | sub_image.extend(sub_image_infos[c] for c in current_box)
93 | current_box = [box_infos[c] for c in current_box]
94 | all_box_input.extend(current_box)
95 | image_label_masks.extend([0]*len(current_box))
96 | if self.object_format == 'coordinate':
97 | box_in_str = DEFAULT_SEP_TOKEN.join([DEFAULT_BOC_TOKEN + box2str(c, mode=self.phrase_format, prec=self.phrase_prec, space=self.phrase_space) + DEFAULT_EOC_TOKEN for c in current_box])
98 | else:
99 | box_in_str = DEFAULT_SEP_TOKEN.join([DEFAULT_BOC_TOKEN + box2str(c, mode=self.phrase_format, prec=self.phrase_prec, space=self.phrase_space) + DEFAULT_EOC_TOKEN + ALL_IMG_TOKENS_STR for c in current_box])
100 | refined_query = refined_query + input_query[box_ind] + box_in_str
101 | refined_query = refined_query + input_query[-1]
102 |
103 | if self.output_mode == 'conversation':
104 | query = ALL_IMG_TOKENS_STR + DEFAULT_GRD_TOKEN + '\n' + 'Briefly describe this image. Locate objects and provide bounding boxes in your response.'
105 | response = refined_query.replace(ALL_IMG_TOKENS_STR + DEFAULT_GRD_TOKEN + ' ', '')
106 | conversation = [{'from': 'human', 'value': query}, {'from': 'gpt', 'value': response}]
107 | return {'input_images': input_images, 'conversation': conversation, 'image_label_masks': image_label_masks, 'box': all_box_input}
108 | # raise ValueError  # unreachable: the conversation branch returns above
109 | elif self.output_mode == 'text':
110 | # print(instruction)
111 | if self.object_format == 'image':
112 | return {'input_images': input_images + sub_image, 'text': refined_query, 'image_label_masks': image_label_masks}
113 | elif self.object_format == 'representation':
114 | return {'input_images': input_images,
115 | 'text': refined_query,
116 | 'image_label_masks': image_label_masks,
117 | 'box': all_box_input}
118 | elif self.object_format == 'coordinate':
119 | return {'input_images': input_images,
120 | 'text': refined_query,
121 | 'image_label_masks': image_label_masks,
122 | 'box': all_box_input[:1]}
123 | else:
124 | raise ValueError
--------------------------------------------------------------------------------
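Worked example of proc_box above, which normalizes absolute pixel coordinates by the image size (even indices divided by width, odd indices by height) and also crops the referenced region; with the default expand2square=False, a box [100, 50, 300, 250] on a 500x400 image becomes [0.2, 0.125, 0.6, 0.625] and a 200x200 crop. A standalone sketch with a synthetic image:

    from PIL import Image
    img = Image.new('RGB', (500, 400))
    box = [100, 50, 300, 250]              # x_min, y_min, x_max, y_max in pixels
    w, h = img.size
    norm = [c / w if i % 2 == 0 else c / h for i, c in enumerate(box)]
    # norm == [0.2, 0.125, 0.6, 0.625]; img.crop(tuple(box)) yields the 200x200 sub-image
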
/locals/datasets/image_caption/fusecap.py:
--------------------------------------------------------------------------------
1 | from torch.utils.data import Dataset
2 | import json as js
3 | import math
4 | import random
5 | import os
6 | from PIL import Image
7 | from constants import *
8 | import tqdm
9 |
10 | class FuseCapCC3MI2TDataset(Dataset):
11 |
12 | def __init__(self,
13 | path: str,
14 | image_folder: str,
15 | instruct: bool = False,
16 | min_resize_res: int = 256,
17 | max_resize_res: int = 256,
18 | crop_res: int = 256,
19 | flip_prob: float = 0.5,
20 | sample_weight: float = 1.0,
21 | check: bool = False,
22 | output_mode: str = 'text',
23 | raw_image: bool = False,
24 | **kwargs):
25 | # load from json ../datasets/gqa-inpaint/meta_info.json
26 | self.path = path
27 | self.instruct = instruct
28 | self.meta = js.load(open(path))
29 | self.image_folder = image_folder
30 | self.min_resize_res = min_resize_res
31 | self.max_resize_res = max_resize_res
32 | self.crop_res = crop_res
33 | self.check = check
34 | self.raw_image = raw_image
35 | self.output_mode = output_mode
36 |
37 | # check the image existence
38 | valid_items = []
39 | for item in tqdm.tqdm(self.meta):
40 | if os.path.exists(self.get_image_name(item['image_id'])):
41 | valid_items.append(item)
42 | self.meta = valid_items
43 |
44 | self.flip_prob = flip_prob
45 | self.sample_weight = sample_weight
46 | self.generation_prompts = [
47 | "generate image with caption:",
48 | "can you give me the image with caption:",
49 | "help me to generate this image:",
50 | "generate image with according to caption:",
51 | "according to caption, generate image:",
52 | "an image with caption:",
53 | "can you visualize this caption:",
54 | ]
55 | print(f"CC3MDataset has {len(self)} samples!!")
56 |
57 | def __len__(self):
58 | return int(len(self.meta) * self.sample_weight)
59 |
60 | def get_image_name(self, image_id):
61 | return os.path.join(self.image_folder, 'GCC_train_{:09d}.jpg'.format(image_id))
62 |
63 | def __getitem__(self, i):
64 |
65 | if self.sample_weight >= 1:
66 | i = i % len(self.meta)
67 | else:
68 | remainder = math.ceil(i / self.sample_weight - int(i / self.sample_weight))
69 | i = int(i / self.sample_weight) + random.randint(0, int(1 / self.sample_weight) - 1 + remainder)
70 |
71 | item = self.meta[i]
72 | tgt_img = Image.open(os.path.join(self.image_folder,item['image'])).convert('RGB')
73 | instruction = item['conversations'][1]['value']
74 | # return image_0, image_1, instruction
75 |
76 | if self.output_mode == 'conversation':
77 | sources = [{'from': 'human', 'value': random.choice(self.generation_prompts)+instruction},
78 | {'from': 'gpt', 'value': '{}.'.format(ALL_IMG_TOKENS_STR)}]
79 | return {'output_images': [tgt_img], 'conversation': sources, 'image_label_masks': [1]}
80 | elif self.output_mode == 'text':
81 | return {'input_images': [tgt_img], 'text': "{} {}".format(ALL_IMG_TOKENS_STR, instruction), 'image_label_masks': [0]}
82 |
83 | class FuseCapCC12MI2TDataset(FuseCapCC3MI2TDataset):
84 | def get_image_name(self, image_id):
85 | image_fn = '{:08d}.jpg'.format(image_id)
86 | return os.path.join(self.image_folder, '{}'.format(int(image_fn[:2])), image_fn)
--------------------------------------------------------------------------------
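Worked example of the CC12M path sharding in FuseCapCC12MI2TDataset.get_image_name above: the image id is zero-padded to eight digits and the first two digits (read as an int) select the sub-folder. The folder root is a placeholder:

    import os
    image_folder = 'data/cc12m'                      # placeholder root
    image_fn = '{:08d}.jpg'.format(987)              # '00000987.jpg'
    path = os.path.join(image_folder, '{}'.format(int(image_fn[:2])), image_fn)
    # path == 'data/cc12m/0/00000987.jpg'; id 12345678 would resolve to 'data/cc12m/12/12345678.jpg'
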
/locals/datasets/image_caption/lvis_gpt4v.py:
--------------------------------------------------------------------------------
1 | from email.policy import default
2 | from torch.utils.data import Dataset
3 | import json as js
4 | import math
5 | import random
6 | import os
7 | from PIL import Image
8 | from collections import defaultdict
9 | from constants import *
10 |
11 | def load_prompt(fn):
12 | prompts = []
13 | with open(fn) as f:
14 | line=f.readline()
15 | while line:
16 | line=line.strip('\n')
17 | prompts.append(line)
18 | line=f.readline()
19 | return prompts
20 |
21 | class LvisCapDataset(Dataset):
22 | def __init__(self,
23 | path: str,
24 | image_folder: str,
25 | instruct: bool = False,
26 | min_resize_res: int = 256,
27 | max_resize_res: int = 256,
28 | crop_res: int = 256,
29 | flip_prob: float = 0.5,
30 | sample_weight: float = 1.0,
31 | check: bool = False,
32 | output_mode: str = 'conversation',
33 | raw_image: bool = False,
34 | shuffle: bool=False,
35 | shuffle_prob: float=0.5,
36 | inference: bool=False,
37 | caption_type: str='long',
38 | task_type: str='i2t',
39 | **kwargs):
40 | # load from json ../datasets/gqa-inpaint/meta_info.json
41 | self.path = path
42 | self.instruct = instruct
43 | self.inference = inference
44 | self.meta = js.load(open(path))
45 | self.image_folder = image_folder
46 | self.min_resize_res = min_resize_res
47 | self.max_resize_res = max_resize_res
48 | self.crop_res = crop_res
49 | self.check = check
50 | self.raw_image = raw_image
51 | self.output_mode = output_mode
52 | self.shuffle = shuffle # shuffle is for interleaved image-text data
53 | self.shuffle_prob = shuffle_prob # the probability for shuffle
54 |
55 | self.flip_prob = flip_prob
56 | self.task_type = task_type
57 | self.sample_weight = sample_weight
58 | self.caption_type = caption_type
59 | self.i2t_long_prompts = load_prompt('locals/datasets/prompts/prompt_captioning_long.txt')
60 | self.i2t_short_prompts = load_prompt('locals/datasets/prompts/prompt_captioning_short.txt')
61 | self.t2i_prompts = load_prompt('locals/datasets/prompts/prompt_txt2img.txt')
62 | print(f"LVIS-GPT4V-Captions has {len(self)} samples!!")
63 |
64 | def __len__(self):
65 | return int(len(self.meta) * self.sample_weight)
66 |
67 | def get_image_fn(self, image_id):
68 | return os.path.join(self.image_folder, image_id)
69 |
70 | def __getitem__(self, i):
71 |
72 | if self.sample_weight >= 1:
73 | i = i % len(self.meta)
74 | else:
75 | remainder = math.ceil(i / self.sample_weight - int(i / self.sample_weight))
76 | i = int(i / self.sample_weight) + random.randint(0, int(1 / self.sample_weight) - 1 + remainder)
77 |
78 | item = self.meta[i]
79 | tgt_img = Image.open(self.get_image_fn(item['image'])).convert('RGB')
80 |
81 | # find the caption
82 | if self.caption_type == 'random':
83 | prob = random.random()
84 | if prob > 0.5:
85 | caption_type = 'long'
86 | else:
87 | caption_type = 'short'
88 | else:
89 | caption_type = self.caption_type
90 | if caption_type == 'long':
91 | caption = item['caption']
92 | prompt = random.choice(self.i2t_long_prompts)
93 | elif caption_type == 'short':
94 | caption = item['short_caption']
95 | prompt = random.choice(self.i2t_short_prompts)
96 | else:
97 | raise ValueError
98 |
99 | if self.output_mode == 'conversation':
100 | if self.shuffle:
101 | prob = random.random()
102 | if prob > self.shuffle_prob:
103 | task_type = 'i2t'
104 | else:
105 | task_type = 't2i'
106 | else:
107 | task_type = self.task_type
108 | if task_type == 'i2t':
109 | instruction = '{} {}'.format(ALL_IMG_TOKENS_STR, prompt)
110 | response = caption
111 | label_masks = [0]
112 | elif task_type == 't2i':
113 | prompt = random.choice(self.t2i_prompts)
114 | instruction = '{} {}'.format(prompt, caption)
115 | response = ALL_IMG_TOKENS_STR
116 | label_masks = [1]
117 | else:
118 | raise ValueError
119 | sources = [{'from': 'human', 'value': instruction},
120 | {'from': 'gpt', 'value': response}]
121 | return {'input_images': [tgt_img], 'conversation': sources, 'image_label_masks': label_masks}
122 | elif self.output_mode == 'text':
123 | # raise ValueError
124 | if self.shuffle:
125 | prob = random.random()
126 | if prob > self.shuffle_prob:
127 | task_type = 'i2t'
128 | else:
129 | task_type = 't2i'
130 | else:
131 | task_type = self.task_type
132 | if task_type == 'i2t':
133 | text = "{} {}".format(ALL_IMG_TOKENS_STR, caption)
134 | image_label_masks = [0]
135 | elif task_type == 't2i':
136 | text = "{} {}".format(caption, ALL_IMG_TOKENS_STR)
137 | image_label_masks = [1]
138 | else:
139 | raise ValueError
140 | if self.inference:
141 | text = ALL_IMG_TOKENS_STR
142 | return {'input_images': [tgt_img], 'text': text, 'image_label_masks': image_label_masks}
--------------------------------------------------------------------------------
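A brief instantiation sketch for the LVIS-GPT4V caption dataset above; paths are placeholders and the prompt files under locals/datasets/prompts are assumed to be present. caption_type chooses between the long and short captions ('random' mixes them per sample), and task_type switches between image-to-text and text-to-image formatting:

    from locals.datasets.image_caption.lvis_gpt4v import LvisCapDataset

    ds = LvisCapDataset(
        path='data/lvis_gpt4v/meta.json',   # placeholder
        image_folder='data/coco/images',    # placeholder
        output_mode='conversation',
        caption_type='random',
        task_type='i2t',
    )
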
/locals/datasets/image_caption/mmc4.py:
--------------------------------------------------------------------------------
1 | from torch.utils.data import Dataset
2 | import json as js
3 | import math
4 | import random
5 | import os
6 | from PIL import Image
7 | from constants import *
8 | import time
9 | import os
10 | from pathlib import Path
11 | class FilteredMMC4Dataset(Dataset):
12 |
13 | def __init__(self,
14 | path: str,
15 | image_folder: str = '',
16 | instruct: bool = False,
17 | min_resize_res: int = 256,
18 | max_resize_res: int = 256,
19 | crop_res: int = 256,
20 | flip_prob: float = 0.5,
21 | sample_weight: float = 1.0,
22 | check: bool = False,
23 | output_mode: str = 'text',
24 | raw_image: bool = False,
25 | meta_folder: str = None,
26 | expand2square: bool = False,
27 | avoid_image_gen: bool=False,
28 | **kwargs):
29 |
30 | time_start = time.time()
31 | self.path = path
32 | self.instruct = instruct
33 | self.meta_folder = meta_folder
34 | self.expand2square = expand2square
35 | self.avoid_image_gen = avoid_image_gen
36 |
37 | if os.path.isdir(path):
38 | self.meta = list(Path(path).rglob('*.json'))
39 | self.need_reload = False
40 | elif os.path.isfile(path) and path.endswith('jsonl'):
41 | self.meta_file = open(path)
42 | self.meta = []
43 | _ = 0
44 | for line in self.meta_file:
45 | self.meta.append(js.loads(line))
46 | _ +=1
47 | self.need_reload = False
48 | elif os.path.isfile(path) and path.endswith('txt'):
49 | self.meta = []
50 | with open(path) as rf:
51 | for line in rf:
52 | self.meta.append(line.strip())
53 | self.need_reload = True
54 | assert self.meta_folder is not None
55 | else:
56 | raise ValueError('Invalid data path')
57 | # self.meta_file = open(path)
58 | # self.meta = []
59 | # _ = 0
60 | # for line in self.meta_file:
61 | # self.meta.append(js.loads(line))
62 | # _ +=1
63 |
64 | #### <<<<< Delete the following code for real training >>>> ####
65 |
66 | # if _ > 5000:
67 | # break
68 |
69 | #########################################################
70 |
71 | self.image_folder = image_folder
72 | self.min_resize_res = min_resize_res
73 | self.max_resize_res = max_resize_res
74 | self.crop_res = crop_res
75 | self.check = check
76 | self.raw_image = raw_image
77 | self.output_mode = output_mode
78 |
79 | self.flip_prob = flip_prob
80 | self.sample_weight = sample_weight
81 | time_end = time.time()
82 | print(f"MMC4 Dataset has {len(self)} samples!!, initialized with {time_end-time_start}")
83 |
84 | def __len__(self):
85 | return int(len(self.meta) * self.sample_weight)
86 |
87 | def __getitem__(self, i):
88 |
89 | if self.sample_weight >= 1:
90 | i = i % len(self.meta)
91 | else:
92 | remainder = math.ceil(i / self.sample_weight - int(i / self.sample_weight))
93 | i = int(i / self.sample_weight) + random.randint(0, int(1 / self.sample_weight) - 1 + remainder)
94 |
95 | item = self.meta[i]
96 | # if item is a path
97 | if isinstance(item,Path):
98 | item = js.load(open(item))
99 | elif self.need_reload:
100 | item = js.load(open(os.path.join(self.meta_folder, item)))
101 |
102 | tgt_img_list = []
103 | img_txt_match_index = []
104 | for img_info in item['image_info']:
105 | tgt_img_list.append(Image.open(os.path.join(self.image_folder,img_info['img_path'])).convert('RGB'))
106 | img_txt_match_index.append(img_info['matched_text_index'])
107 | text_list = [txt for txt in item['text_list']]
108 | img_label_masks = []
109 | for i,match_index in enumerate(img_txt_match_index):
110 | thr = random.random()
111 | if i == 0 and thr <= 0.5:
112 | img_label_masks.append(0)
113 | else:
114 | if self.avoid_image_gen:
115 | img_label_masks.append(0)
116 | else:
117 | img_label_masks.append(1)
118 | text_list[match_index] = "{} {}".format(text_list[match_index], ALL_IMG_TOKENS_STR) if thr > 0.5 else "{} {}".format(ALL_IMG_TOKENS_STR, text_list[match_index])
119 | # return image_0, image_1, instruction
120 |
121 | assert self.output_mode == 'text'
122 | return {'input_images': tgt_img_list, 'text': ' '.join(text_list) , 'image_label_masks': img_label_masks}
123 |
124 |
125 | if __name__ == "__main__":
126 | dataset = FilteredMMC4Dataset('/mnt/bn/luoruipu-disk/meta_data/mmc4/filter_mmc4_meta_with_img_abs_path.jsonl')
127 | print(repr(dataset[1000]))
--------------------------------------------------------------------------------
/locals/datasets/image_edit/low_level/clwd.py:
--------------------------------------------------------------------------------
1 | # --------------------------------------------------------
2 | # InstructDiffusion
3 | # Based on instruct-pix2pix (https://github.com/timothybrooks/instruct-pix2pix)
4 | # Modified by Chen Li (edward82@stu.xjtu.edu.cn)
5 | # --------------------------------------------------------
6 |
7 | import os
8 | import numpy as np
9 | from torch.utils.data import Dataset
10 | import torch
11 | from PIL import Image
12 | import torchvision.transforms.functional as TF
13 | from pdb import set_trace as stx
14 | import random
15 | import cv2
16 | from PIL import Image
17 | import torchvision
18 | from constants import ALL_IMG_TOKENS_STR
19 |
20 |
21 | def is_image_file(filename):
22 | return any(filename.endswith(extension) for extension in ['jpeg', 'JPEG', 'jpg', 'png', 'JPG', 'PNG', 'gif'])
23 |
24 |
25 | class CLWD(Dataset):
26 | def __init__(self, path, split="train", size=256, interpolation="pil_lanczos", check=False, raw_image=False,
27 | flip_prob=0.5, sample_weight=1.0, instruct=False, prompt_path='datasets/image_edit/prompt/prompt_dewatermark.txt'):
28 | super(CLWD, self).__init__()
29 |
30 | inp_files = sorted(os.listdir(os.path.join(path, split, 'Watermarked_image')))
31 | tar_files = sorted(os.listdir(os.path.join(path, split, 'Watermark_free_image')))
32 |
33 | self.inp_filenames = [os.path.join(path, split, 'Watermarked_image', x) for x in inp_files if is_image_file(x)]
34 | self.tar_filenames = [os.path.join(path, split, 'Watermark_free_image', x) for x in tar_files if is_image_file(x)]
35 |
36 | self.size = size
37 | self.flip_prob = flip_prob
38 | self.sample_weight = sample_weight
39 | self.instruct = instruct
40 | self.check = check
41 | self.raw_image = raw_image
42 | self.sizex = len(self.tar_filenames) # get the size of target
43 |
44 | self.interpolation = {
45 | "cv_nearest": cv2.INTER_NEAREST,
46 | "cv_bilinear": cv2.INTER_LINEAR,
47 | "cv_bicubic": cv2.INTER_CUBIC,
48 | "cv_area": cv2.INTER_AREA,
49 | "cv_lanczos": cv2.INTER_LANCZOS4,
50 | "pil_nearest": Image.NEAREST,
51 | "pil_bilinear": Image.BILINEAR,
52 | "pil_bicubic": Image.BICUBIC,
53 | "pil_box": Image.BOX,
54 | "pil_hamming": Image.HAMMING,
55 | "pil_lanczos": Image.LANCZOS,
56 | }[interpolation]
57 |
58 | # prompt_path='dataset/prompt/prompt_dewatermark.txt'
59 | self.prompt_list=[]
60 | with open(prompt_path) as f:
61 | line=f.readline()
62 | while line:
63 | line=line.strip('\n')
64 | self.prompt_list.append(line)
65 | line=f.readline()
66 |
67 | print(f"CLWD has {len(self)} samples!!")
68 |
69 | def __len__(self):
70 | return int(self.sizex * self.sample_weight)
71 |
72 | def __getitem__(self, index):
73 | if self.raw_image:
74 | return self.get_raw_image(index)
75 | else:
76 | return self.get_processed_image(index)
77 |
78 | def get_processed_image(self, index):
79 | if self.sample_weight >= 1:
80 | index_ = index % self.sizex
81 | else:
82 | index_ = int(index / self.sample_weight) + random.randint(0, int(1 / self.sample_weight) - 1)
83 |
84 | inp_path = self.inp_filenames[index_]
85 | tar_path = self.tar_filenames[index_]
86 |
87 | inp_img = Image.open(inp_path)
88 | tar_img = Image.open(tar_path)
89 |
90 | width, height = inp_img.size
91 | tar_width, tar_height = tar_img.size
92 | assert tar_width == width and tar_height == height, "Input and target image mismatch"
93 | aspect_ratio = float(width) / float(height)
94 | if width < height:
95 | new_width = self.size
96 | new_height = int(self.size / aspect_ratio)
97 | else:
98 | new_height = self.size
99 | new_width = int(self.size * aspect_ratio)
100 | inp_img = inp_img.resize((new_width, new_height), self.interpolation)
101 | tar_img = tar_img.resize((new_width, new_height), self.interpolation)
102 |
103 | inp_img = np.array(inp_img).astype(np.float32).transpose(2, 0, 1)
104 | inp_img_tensor = torch.tensor((inp_img / 127.5 - 1.0).astype(np.float32))
105 | tar_img = np.array(tar_img).astype(np.float32).transpose(2, 0, 1)
106 | tar_img_tensor = torch.tensor((tar_img / 127.5 - 1.0).astype(np.float32))
107 | crop = torchvision.transforms.RandomCrop(self.size)
108 | flip = torchvision.transforms.RandomHorizontalFlip(float(self.flip_prob))
109 | image_0, image_1 = flip(crop(torch.cat((inp_img_tensor, tar_img_tensor)))).chunk(2)
110 |
111 | prompt = random.choice(self.prompt_list)
112 | if self.instruct:
113 | prompt = "Watermark Removal: " + prompt
114 |
115 | if self.check:
116 | return dict(edited=Image.open(tar_path), edit=dict(source=Image.open(inp_path), instruction=prompt))
117 | return dict(edited=image_1, edit=dict(source=image_0, instruction=prompt))
118 |
119 | def get_raw_image(self, index):
120 | if self.sample_weight >= 1:
121 | index_ = index % self.sizex
122 | else:
123 | index_ = int(index / self.sample_weight) + random.randint(0, int(1 / self.sample_weight) - 1)
124 |
125 | inp_path = self.inp_filenames[index_]
126 | tar_path = self.tar_filenames[index_]
127 |
128 | inp_img = Image.open(inp_path).convert('RGB')
129 | tar_img = Image.open(tar_path).convert('RGB')
130 |
131 | width, height = inp_img.size
132 | tar_width, tar_height = tar_img.size
133 | assert tar_width == width and tar_height == height, "Input and target image mismatch"
134 |
135 | prompt = random.choice(self.prompt_list)
136 | if self.instruct:
137 | prompt = "Watermark Removal: " + prompt
138 |
139 | sources = [{'from': 'human', 'value': '{}: .'.format(prompt)},
140 | {'from': 'gpt', 'value': '{}.'.format(ALL_IMG_TOKENS_STR)}]
141 | return {'input_images': [inp_img], 'output_images': [tar_img], 'output_cond_images': [inp_img], 'conversation': sources, 'image_label_masks': [0,1]}
--------------------------------------------------------------------------------
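Design note on get_processed_image above (the same pattern is reused by the GoPro, REDS, and SIDD datasets below): pixels are rescaled from [0, 255] to [-1, 1] via x / 127.5 - 1, and the input and target tensors are concatenated along the channel dimension before RandomCrop and RandomHorizontalFlip, so both images receive exactly the same random crop window and flip decision. A minimal standalone sketch with illustrative shapes:

    import torch
    import torchvision

    inp = torch.rand(3, 300, 300) * 2 - 1   # stand-ins for the normalized source/target images
    tar = torch.rand(3, 300, 300) * 2 - 1
    crop = torchvision.transforms.RandomCrop(256)
    flip = torchvision.transforms.RandomHorizontalFlip(0.5)
    image_0, image_1 = flip(crop(torch.cat((inp, tar)))).chunk(2)   # identical crop/flip for both
    assert image_0.shape == image_1.shape == (3, 256, 256)
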
/locals/datasets/image_edit/low_level/gopro.py:
--------------------------------------------------------------------------------
1 | # --------------------------------------------------------
2 | # InstructDiffusion
3 | # Based on instruct-pix2pix (https://github.com/timothybrooks/instruct-pix2pix)
4 | # Modified by Chen Li (edward82@stu.xjtu.edu.cn)
5 | # --------------------------------------------------------
6 |
7 | import os
8 | import numpy as np
9 | from torch.utils.data import Dataset
10 | import torch
11 | from PIL import Image
12 | import torchvision.transforms.functional as TF
13 | from pdb import set_trace as stx
14 | import random
15 | import cv2
16 | from PIL import Image
17 | import torchvision
18 | from constants import ALL_IMG_TOKENS_STR
19 |
20 |
21 | def is_image_file(filename):
22 | return any(filename.endswith(extension) for extension in ['jpeg', 'JPEG', 'jpg', 'png', 'JPG', 'PNG', 'gif'])
23 |
24 |
25 | class GoPro(Dataset):
26 | def __init__(self, path, split="train", size=256, interpolation="pil_lanczos", check=False, raw_image=False,
27 | flip_prob=0.5, sample_weight=1.0, instruct=False, prompt_path='datasets/image_edit/prompt/prompt_deblur.txt'):
28 | super(GoPro, self).__init__()
29 |
30 | # inp_files = sorted(os.listdir(os.path.join(path, split, 'input')))
31 | # tar_files = sorted(os.listdir(os.path.join(path, split, 'target')))
32 |
33 | # self.inp_filenames = [os.path.join(path, split, 'input', x) for x in inp_files if is_image_file(x)]
34 | # self.tar_filenames = [os.path.join(path, split, 'target', x) for x in tar_files if is_image_file(x)]
35 |
36 | filenames_splits = sorted(os.listdir(os.path.join(path, split)))
37 |
38 | self.inp_filenames = [os.path.join(path, split, d, 'blur', x) for d in filenames_splits for x in sorted(os.listdir(os.path.join(path, split, d, 'blur'))) if is_image_file(x)]
39 | self.tar_filenames = [os.path.join(path, split, d, 'sharp', x) for d in filenames_splits for x in sorted(os.listdir(os.path.join(path, split, d, 'sharp'))) if is_image_file(x)]
40 |
41 | self.size = size
42 | self.flip_prob = flip_prob
43 | self.sample_weight = sample_weight
44 | self.instruct = instruct
45 | self.check = check
46 | self.raw_image = raw_image
47 | self.sizex = len(self.tar_filenames) # get the size of target
48 |
49 | self.interpolation = {
50 | "cv_nearest": cv2.INTER_NEAREST,
51 | "cv_bilinear": cv2.INTER_LINEAR,
52 | "cv_bicubic": cv2.INTER_CUBIC,
53 | "cv_area": cv2.INTER_AREA,
54 | "cv_lanczos": cv2.INTER_LANCZOS4,
55 | "pil_nearest": Image.NEAREST,
56 | "pil_bilinear": Image.BILINEAR,
57 | "pil_bicubic": Image.BICUBIC,
58 | "pil_box": Image.BOX,
59 | "pil_hamming": Image.HAMMING,
60 | "pil_lanczos": Image.LANCZOS,
61 | }[interpolation]
62 |
63 | # prompt_path='dataset/prompt/prompt_deblur.txt'
64 | self.prompt_list=[]
65 | with open(prompt_path) as f:
66 | line=f.readline()
67 | while line:
68 | line=line.strip('\n')
69 | self.prompt_list.append(line)
70 | line=f.readline()
71 |
72 | print(f"GoPro has {len(self)} samples!!")
73 |
74 | def __len__(self):
75 | return int(self.sizex * self.sample_weight)
76 |
77 | def __getitem__(self, index):
78 | if self.raw_image:
79 | return self.get_raw_image(index)
80 | else:
81 | return self.get_processed_image(index)
82 |
83 | def get_processed_image(self, index):
84 | if self.sample_weight >= 1:
85 | index_ = index % self.sizex
86 | else:
87 | index_ = int(index / self.sample_weight) + random.randint(0, int(1 / self.sample_weight) - 1)
88 |
89 | inp_path = self.inp_filenames[index_]
90 | tar_path = self.tar_filenames[index_]
91 |
92 | inp_img = Image.open(inp_path)
93 | tar_img = Image.open(tar_path)
94 |
95 | width, height = inp_img.size
96 | tar_width, tar_height = tar_img.size
97 | assert tar_width == width and tar_height == height, "Input and target image mismatch"
98 | aspect_ratio = float(width) / float(height)
99 | if width < height:
100 | new_width = self.size
101 | new_height = int(self.size / aspect_ratio)
102 | else:
103 | new_height = self.size
104 | new_width = int(self.size * aspect_ratio)
105 | inp_img = inp_img.resize((new_width, new_height), self.interpolation)
106 | tar_img = tar_img.resize((new_width, new_height), self.interpolation)
107 |
108 | inp_img = np.array(inp_img).astype(np.float32).transpose(2, 0, 1)
109 | inp_img_tensor = torch.tensor((inp_img / 127.5 - 1.0).astype(np.float32))
110 | tar_img = np.array(tar_img).astype(np.float32).transpose(2, 0, 1)
111 | tar_img_tensor = torch.tensor((tar_img / 127.5 - 1.0).astype(np.float32))
112 | crop = torchvision.transforms.RandomCrop(self.size)
113 | flip = torchvision.transforms.RandomHorizontalFlip(float(self.flip_prob))
114 | image_0, image_1 = flip(crop(torch.cat((inp_img_tensor, tar_img_tensor)))).chunk(2)
115 |
116 | prompt = random.choice(self.prompt_list)
117 | if self.instruct:
118 | prompt = "Image Deblurring: " + prompt
119 |
120 | if self.check:
121 | return dict(edited=Image.open(tar_path), edit=dict(source=Image.open(inp_path), instruction=prompt))
122 | return dict(edited=image_1, edit=dict(source=image_0, instruction=prompt))
123 |
124 | def get_raw_image(self, index):
125 | if self.sample_weight >= 1:
126 | index_ = index % self.sizex
127 | else:
128 | index_ = int(index / self.sample_weight) + random.randint(0, int(1 / self.sample_weight) - 1)
129 |
130 | inp_path = self.inp_filenames[index_]
131 | tar_path = self.tar_filenames[index_]
132 |
133 | inp_img = Image.open(inp_path).convert('RGB')
134 | tar_img = Image.open(tar_path).convert('RGB')
135 |
136 | width, height = inp_img.size
137 | tar_width, tar_height = tar_img.size
138 | assert tar_width == width and tar_height == height, "Input and target image mismatch"
139 |
140 | prompt = random.choice(self.prompt_list)
141 | if self.instruct:
142 | prompt = "Image Deblurring: " + prompt
143 |
144 | sources = [{'from': 'human', 'value': '{}: .'.format(prompt)},
145 | {'from': 'gpt', 'value': '{}.'.format(ALL_IMG_TOKENS_STR)}]
146 | return {'input_images': [inp_img], 'output_images': [tar_img], 'output_cond_images': [inp_img], 'conversation': sources, 'image_label_masks': [0, 1]}
--------------------------------------------------------------------------------
/locals/datasets/image_edit/low_level/reds.py:
--------------------------------------------------------------------------------
1 | import os
2 | import numpy as np
3 | from torch.utils.data import Dataset
4 | import torch
5 | from PIL import Image
6 | import torchvision.transforms.functional as TF
7 | from pdb import set_trace as stx
8 | import random
9 | import cv2
10 | from PIL import Image
11 | import torchvision
12 | from constants import ALL_IMG_TOKENS_STR
13 |
14 |
15 | def is_image_file(filename):
16 | return any(filename.endswith(extension) for extension in ['jpeg', 'JPEG', 'jpg', 'png', 'JPG', 'PNG', 'gif'])
17 |
18 |
19 | class REDS(Dataset):
20 | def __init__(self, path, split="train", size=256, interpolation="pil_lanczos", check=False, raw_image=False,
21 | flip_prob=0.5, sample_weight=1.0, instruct=False, prompt_path='datasets/image_edit/prompt/prompt_deblur.txt'):
22 | super(REDS, self).__init__()
23 |
24 | inp_files = sorted(os.listdir(os.path.join(path, split, 'train_blur')))
25 | tar_files = sorted(os.listdir(os.path.join(path, split, 'train_sharp')))
26 |
27 | if split == "train":
28 | self.inp_filenames = [os.path.join(path, split, 'train_blur', d, x) for d in inp_files for x in sorted(os.listdir(os.path.join(path, split, 'train_blur', d))) if is_image_file(x)]
29 | self.tar_filenames = [os.path.join(path, split, 'train_sharp', d, x) for d in tar_files for x in sorted(os.listdir(os.path.join(path, split, 'train_sharp', d))) if is_image_file(x)]
30 | else:
31 | self.inp_filenames = [os.path.join(path, split, 'blur', x) for x in inp_files if is_image_file(x)]
32 | self.tar_filenames = [os.path.join(path, split, 'sharp', x) for x in tar_files if is_image_file(x)]
33 |
34 | self.size = size
35 | self.flip_prob = flip_prob
36 | self.sample_weight = sample_weight
37 | self.instruct = instruct
38 | self.raw_image = raw_image
39 | assert len(self.inp_filenames) == len(self.tar_filenames)
40 | self.check = check
41 | self.sizex = len(self.tar_filenames) # get the size of target
42 |
43 | self.interpolation = {
44 | "cv_nearest": cv2.INTER_NEAREST,
45 | "cv_bilinear": cv2.INTER_LINEAR,
46 | "cv_bicubic": cv2.INTER_CUBIC,
47 | "cv_area": cv2.INTER_AREA,
48 | "cv_lanczos": cv2.INTER_LANCZOS4,
49 | "pil_nearest": Image.NEAREST,
50 | "pil_bilinear": Image.BILINEAR,
51 | "pil_bicubic": Image.BICUBIC,
52 | "pil_box": Image.BOX,
53 | "pil_hamming": Image.HAMMING,
54 | "pil_lanczos": Image.LANCZOS,
55 | }[interpolation]
56 |
57 | self.prompt_list=[]
58 | with open(prompt_path) as f:
59 | line=f.readline()
60 | while line:
61 | line=line.strip('\n')
62 | self.prompt_list.append(line)
63 | line=f.readline()
64 |
65 | print(f"REDS has {len(self)} samples!!")
66 |
67 | def __len__(self):
68 | return int(self.sizex * self.sample_weight)
69 |
70 | def __getitem__(self, index):
71 | if self.raw_image:
72 | return self.get_raw_image(index)
73 | else:
74 | return self.get_processed_image(index)
75 |
76 | def get_processed_image(self, index):
77 | if self.sample_weight >= 1:
78 | index_ = index % self.sizex
79 | else:
80 | index_ = int(index / self.sample_weight) + random.randint(0, int(1 / self.sample_weight) - 1)
81 |
82 | inp_path = self.inp_filenames[index_]
83 | tar_path = self.tar_filenames[index_]
84 |
85 | inp_img = Image.open(inp_path)
86 | tar_img = Image.open(tar_path)
87 |
88 | width, height = inp_img.size
89 | tar_width, tar_height = tar_img.size
90 | assert tar_width == width and tar_height == height, "Input and target image mismatch"
91 | aspect_ratio = float(width) / float(height)
92 | if width < height:
93 | new_width = self.size
94 | new_height = int(self.size / aspect_ratio)
95 | else:
96 | new_height = self.size
97 | new_width = int(self.size * aspect_ratio)
98 | inp_img = inp_img.resize((new_width, new_height), self.interpolation)
99 | tar_img = tar_img.resize((new_width, new_height), self.interpolation)
100 |
101 | inp_img = np.array(inp_img).astype(np.float32).transpose(2, 0, 1)
102 | inp_img_tensor = torch.tensor((inp_img / 127.5 - 1.0).astype(np.float32))
103 | tar_img = np.array(tar_img).astype(np.float32).transpose(2, 0, 1)
104 | tar_img_tensor = torch.tensor((tar_img / 127.5 - 1.0).astype(np.float32))
105 | crop = torchvision.transforms.RandomCrop(self.size)
106 | flip = torchvision.transforms.RandomHorizontalFlip(float(self.flip_prob))
107 | image_0, image_1 = flip(crop(torch.cat((inp_img_tensor, tar_img_tensor)))).chunk(2)
108 |
109 | prompt = random.choice(self.prompt_list)
110 | if self.instruct:
111 | prompt = "Image Deblurring: " + prompt
112 |
113 | if self.check:
114 | return dict(edited=Image.open(tar_path), edit=dict(source=Image.open(inp_path), instruction=prompt))
115 | return dict(edited=image_1, edit=dict(source=image_0, instruction=prompt))
116 |
117 | def get_raw_image(self, index):
118 | if self.sample_weight >= 1:
119 | index_ = index % self.sizex
120 | else:
121 | index_ = int(index / self.sample_weight) + random.randint(0, int(1 / self.sample_weight) - 1)
122 |
123 | inp_path = self.inp_filenames[index_]
124 | tar_path = self.tar_filenames[index_]
125 |
126 | inp_img = Image.open(inp_path).convert('RGB')
127 | tar_img = Image.open(tar_path).convert('RGB')
128 |
129 | width, height = inp_img.size
130 | tar_width, tar_height = tar_img.size
131 | assert tar_width == width and tar_height == height, "Input and target image mismatch"
132 |
133 | prompt = random.choice(self.prompt_list)
134 | if self.instruct:
135 | prompt = "Image Deblurring: " + prompt
136 |
137 | sources = [{'from': 'human', 'value': '{}: .'.format(prompt)},
138 | {'from': 'gpt', 'value': '{}.'.format(ALL_IMG_TOKENS_STR)}]
139 | return {'input_images': [inp_img], 'output_images': [tar_img], 'output_cond_images': [inp_img], 'conversation': sources, 'image_label_masks': [0, 1]}
--------------------------------------------------------------------------------
/locals/datasets/image_edit/low_level/sidd.py:
--------------------------------------------------------------------------------
1 | # --------------------------------------------------------
2 | # InstructDiffusion
3 | # Based on instruct-pix2pix (https://github.com/timothybrooks/instruct-pix2pix)
4 | # Modified by Chen Li (edward82@stu.xjtu.edu.cn)
5 | # --------------------------------------------------------
6 |
7 | import os
8 | import numpy as np
9 | from torch.utils.data import Dataset
10 | import torch
11 | from PIL import Image
12 | import torchvision.transforms.functional as TF
13 | from pdb import set_trace as stx
14 | import random
15 | import cv2
16 | from PIL import Image
17 | import torchvision
18 | from constants import ALL_IMG_TOKENS_STR
19 |
20 | def is_image_file(filename):
21 | return any(filename.endswith(extension) for extension in ['jpeg', 'JPEG', 'jpg', 'png', 'JPG', 'PNG', 'gif'])
22 |
23 |
24 | class SIDD(Dataset):
25 | def __init__(self, path, split="train", size=256, interpolation="pil_lanczos", check=False, raw_image=False,
26 | flip_prob=0.5, sample_weight=1.0, instruct=False, prompt_path='datasets/image_edit/prompt/prompt_denoise.txt'):
27 | super(SIDD, self).__init__()
28 |
29 | # inp_files = sorted(os.listdir(os.path.join(path, split, 'input')))
30 | # tar_files = sorted(os.listdir(os.path.join(path, split, 'gt')))
31 |
32 | # self.inp_filenames = [os.path.join(path, split, 'input', x) for x in inp_files if is_image_file(x)]
33 | # self.tar_filenames = [os.path.join(path, split, 'gt', x) for x in tar_files if is_image_file(x)]
34 |
35 | filenames_splits = sorted(os.listdir(os.path.join(path)))
36 |
37 | self.inp_filenames = [os.path.join(path, d, x) for d in filenames_splits for x in sorted(os.listdir(os.path.join(path, d))) if is_image_file(x) and 'NOISY' in x]
38 | self.tar_filenames = [os.path.join(path, d, x) for d in filenames_splits for x in sorted(os.listdir(os.path.join(path, d))) if is_image_file(x) and 'GT' in x]
39 |
40 | self.size = size
41 | self.flip_prob = flip_prob
42 | self.sample_weight = sample_weight
43 | self.instruct = instruct
44 | self.check = check
45 | self.raw_image = raw_image
46 | self.sizex = len(self.tar_filenames) # get the size of target
47 |
48 | self.interpolation = {
49 | "cv_nearest": cv2.INTER_NEAREST,
50 | "cv_bilinear": cv2.INTER_LINEAR,
51 | "cv_bicubic": cv2.INTER_CUBIC,
52 | "cv_area": cv2.INTER_AREA,
53 | "cv_lanczos": cv2.INTER_LANCZOS4,
54 | "pil_nearest": Image.NEAREST,
55 | "pil_bilinear": Image.BILINEAR,
56 | "pil_bicubic": Image.BICUBIC,
57 | "pil_box": Image.BOX,
58 | "pil_hamming": Image.HAMMING,
59 | "pil_lanczos": Image.LANCZOS,
60 | }[interpolation]
61 |
62 | # prompt_path='dataset/prompt/prompt_denoise.txt'
63 | self.prompt_list=[]
64 | with open(prompt_path) as f:
65 | line=f.readline()
66 | while line:
67 | line=line.strip('\n')
68 | self.prompt_list.append(line)
69 | line=f.readline()
70 | print(f"SIDD has {len(self)} samples!!")
71 |
72 | def __len__(self):
73 | return int(self.sizex * self.sample_weight)
74 |
75 | def __getitem__(self, index):
76 | if self.raw_image:
77 | return self.get_raw_image(index)
78 | else:
79 | return self.get_processed_image(index)
80 |
81 | def get_processed_image(self, index):
82 | if self.sample_weight >= 1:
83 | index_ = index % self.sizex
84 | else:
85 | index_ = int(index / self.sample_weight) + random.randint(0, int(1 / self.sample_weight) - 1)
86 |
87 | inp_path = self.inp_filenames[index_]
88 | tar_path = self.tar_filenames[index_]
89 |
90 | inp_img = Image.open(inp_path)
91 | tar_img = Image.open(tar_path)
92 |
93 | width, height = inp_img.size
94 | tar_width, tar_height = tar_img.size
95 | assert tar_width == width and tar_height == height, "Input and target image mismatch"
96 |
97 | inp_img = np.array(inp_img).astype(np.float32).transpose(2, 0, 1)
98 | inp_img_tensor = torch.tensor((inp_img / 127.5 - 1.0).astype(np.float32))
99 | tar_img = np.array(tar_img).astype(np.float32).transpose(2, 0, 1)
100 | tar_img_tensor = torch.tensor((tar_img / 127.5 - 1.0).astype(np.float32))
101 | crop = torchvision.transforms.RandomCrop(self.size)
102 | flip = torchvision.transforms.RandomHorizontalFlip(float(self.flip_prob))
103 | image_0, image_1 = flip(crop(torch.cat((inp_img_tensor, tar_img_tensor)))).chunk(2)
104 |
105 | prompt = random.choice(self.prompt_list)
106 | if self.instruct:
107 | prompt = "Image Denoising: " + prompt
108 |
109 | if self.check:
110 | return dict(edited=Image.open(tar_path), edit=dict(source=Image.open(inp_path), instruction=prompt))
111 | return dict(edited=image_1, edit=dict(source=image_0, instruction=prompt))
112 |
113 | def get_raw_image(self, index):
114 | if self.sample_weight >= 1:
115 | index_ = index % self.sizex
116 | else:
117 | index_ = int(index / self.sample_weight) + random.randint(0, int(1 / self.sample_weight) - 1)
118 |
119 | inp_path = self.inp_filenames[index_]
120 | tar_path = self.tar_filenames[index_]
121 |
122 | inp_img = Image.open(inp_path).convert('RGB')
123 | tar_img = Image.open(tar_path).convert('RGB')
124 |
125 | width, height = inp_img.size
126 | tar_width, tar_height = tar_img.size
127 | assert tar_width == width and tar_height == height, "Input and target image mismatch"
128 |
129 | prompt = random.choice(self.prompt_list)
130 | if self.instruct:
131 | prompt = "Image Denoising: " + prompt
132 |
133 | sources = [{'from': 'human', 'value': '{}: .'.format(prompt)},
134 | {'from': 'gpt', 'value': '{}.'.format(ALL_IMG_TOKENS_STR)}]
135 | return {'input_images': [inp_img, tar_img], 'output_images': [tar_img], 'output_cond_images': [inp_img], 'conversation': sources, 'image_label_masks': [0, 1]}
--------------------------------------------------------------------------------
/locals/datasets/image_edit/prompt/color_list_train_small.txt:
--------------------------------------------------------------------------------
1 | Red 纯红 #FF0000 255,0,0
2 |
3 | Purple 紫色 #800080 128,0,128
4 |
5 | Blue 纯蓝 #0000FF 0,0,255
6 |
7 | Green 纯绿 #008000 0,128,0
8 |
9 | Yellow 纯黄 #FFFF00 255,255,0
10 |
11 | White 纯白 #FFFFFF 255,255,255
12 |
13 | Black 纯黑 #000000 0,0,0
14 |
15 | Gray 灰色 #808080 128,128,128
16 |
17 |
18 |
--------------------------------------------------------------------------------
/locals/datasets/image_edit/prompt/prompt_deblur.txt:
--------------------------------------------------------------------------------
1 | Sharpen this blurry image
2 | Increase the sharpness of this unclear photo
3 | Correct the lack of focus in this misty picture
4 | Heighten the definition of this smeared image
5 | Clear up this fuzzy picture
6 | Refine this indistinct photograph
7 | Improve the focus of this hazy image
8 | Amend the softness of this out-of-focus photograph
9 | Polish the murkiness of this low-definition photo
10 | Rectify the vagueness of this blurred image
--------------------------------------------------------------------------------
/locals/datasets/image_edit/prompt/prompt_denoise.txt:
--------------------------------------------------------------------------------
1 | Remove noise from this image
2 | Eliminate the noise in this picture
3 | Purify this photo by removing noise
4 | Clear up the image by filtering out noise
5 | Eradicate the noise from this photograph
6 | Minimize the noise present in this picture
7 | Cancel out the noise within this image
8 | Clean this photo by discarding the noise
9 | Suppress the noise in this visual representation
10 | Rectify the noise interference in this image
--------------------------------------------------------------------------------
/locals/datasets/image_edit/prompt/prompt_dewatermark.txt:
--------------------------------------------------------------------------------
1 | Remove watermark from this picture
2 | Erase the watermark from this photograph.
3 | Extract the watermark from this image.
4 | Take out the watermark overlay from this photo.
5 | Wipe off the watermark imprint on this image.
6 | Detach the watermark from this visual representation.
7 | Get rid of the watermarking on this picture.
8 | Withdraw the watermark applied to this photograph.
9 | Clean up this image by deleting the watermark.
10 | Unmark this photo by removing the watermark.
--------------------------------------------------------------------------------
/locals/datasets/image_edit/prompt/prompt_pose.txt:
--------------------------------------------------------------------------------
1 | Circle the {joint} of the people with the color {color},
2 | Use the {color} color to draw circles around the {joint} of the people,
3 | Make {color} circles around the {joint} of the people,
4 | Put {color} circles on the {joint} of the people,
5 | Draw {color} circles over the {joint} of the people,
6 | Surround the {joint} of the people with {color} circles,
7 | Use the color {color} to make circles on the {joint} of the people,
8 | Mark the {joint} of the people with {color} circles,
9 | Create {color} circles around the {joint} of the people,
10 | Use the color {color} to encircle the {joint} of the people,
--------------------------------------------------------------------------------
/locals/datasets/image_edit/prompt/prompt_seg.txt:
--------------------------------------------------------------------------------
1 | Mark the pixels of {object} in {color} and leave the rest unchanged.
2 | Color the {object}'s pixels in {color}, keeping the remaining pixels unaltered.
3 | Apply {color} to the pixels of {object} while maintaining the current state of other pixels.
4 | Assign {color} to the pixels belonging to {object}, preserving the rest as they are.
5 | For {object}, set its pixels to {color} and let the others remain the same.
6 | Modify the pixels of {object} to {color} without affecting any other pixels.
7 | Set the {object} pixels to {color} and keep the other pixels in their original state.
8 | Update the pixels of {object} to {color}, but leave the other pixels untouched.
9 | Fill in the pixels of {object} with {color}, retaining the existing colors of the remaining pixels.
10 | Change the {object} pixels to {color}, while keeping the other pixels constant.
11 | Paint the pixels of {object} in {color} and maintain the current appearance of the other pixels.
--------------------------------------------------------------------------------
/locals/datasets/image_edit/seg/grefcoco_seg.py:
--------------------------------------------------------------------------------
1 | # --------------------------------------------------------
2 | # InstructDiffusion
3 | # Based on instruct-pix2pix (https://github.com/timothybrooks/instruct-pix2pix)
4 | # Modified by Binxin Yang (tennyson@mail.ustc.edu.cn)
5 | # --------------------------------------------------------
6 |
7 | from __future__ import annotations
8 |
9 | import os
10 | import random
11 | import copy
12 | import json
13 | import math
14 | from pathlib import Path
15 | from typing import Any
16 |
17 | import numpy as np
18 | import torch
19 | import torchvision
20 | from einops import rearrange
21 | from PIL import Image
22 | from torch.utils.data import Dataset
23 |
24 | from .grefcoco import G_REFER
25 | from constants import ALL_IMG_TOKENS_STR
26 |
27 |
28 | class GrefCOCODataset(Dataset):
29 | def __init__(
30 | self,
31 | path: str,
32 | split: str = "train",
33 | min_resize_res: int = 256,
34 | max_resize_res: int = 256,
35 | crop_res: int = 256,
36 | flip_prob: float = 0.0,
37 | transparency: float = 0.0,
38 | test: bool = False,
39 | image_path: str = None
40 | ):
41 | assert split in ("train", "val", "test")
42 | self.path = path
43 | self.min_resize_res = min_resize_res
44 | self.max_resize_res = max_resize_res
45 | self.crop_res = crop_res
46 | self.flip_prob = flip_prob
47 | self.G_ref_dataset=G_REFER(data_root=path)
48 | self.IMAGE_DIR = os.path.join(image_path, 'train2014')
49 | self.list_ref=self.G_ref_dataset.getRefIds(split=split)
50 | self.transparency = transparency
51 | self.test = test
52 |
53 | seg_diverse_prompt_path = 'locals/datasets/image_edit/prompt/prompt_seg.txt'
54 | self.seg_diverse_prompt_list=[]
55 | with open(seg_diverse_prompt_path) as f:
56 | line=f.readline()
57 | while line:
58 | line=line.strip('\n')
59 | self.seg_diverse_prompt_list.append(line)
60 | line=f.readline()
61 |
62 | color_list_file_path='locals/datasets/image_edit/prompt/color_list_train_small.txt'
63 | self.color_list=[]
64 | with open(color_list_file_path) as f:
65 | line = f.readline()
66 | while line:
67 | line_split = line.strip('\n').split(" ")
68 | if len(line_split)>1:
69 | temp = []
70 | for i in range(4):
71 | temp.append(line_split[i])
72 | self.color_list.append(temp)
73 | line = f.readline()
74 |
75 | def __len__(self) -> int:
76 | return len(self.list_ref)
77 |
78 | def _augmentation_new(self, image, label):
79 |
80 | # Cropping
81 | h, w = label.shape
82 | if h > w:
83 | start_h = random.randint(0, h - w)
84 | end_h = start_h + w
85 | image = image[start_h:end_h]
86 | label = label[start_h:end_h]
87 | elif h < w:
88 | start_w = random.randint(0, w - h)
89 | end_w = start_w + h
90 | image = image[:, start_w:end_w]
91 | label = label[:, start_w:end_w]
92 | else:
93 | pass
94 | image = Image.fromarray(image).resize((self.min_resize_res, self.min_resize_res), resample=Image.Resampling.LANCZOS)
95 | image = np.asarray(image, dtype=np.uint8)
96 | label = Image.fromarray(label).resize((self.min_resize_res, self.min_resize_res), resample=Image.Resampling.NEAREST)
97 | label = np.asarray(label, dtype=np.int64)
98 | return image, label
99 |
100 | def __getitem__(self, i: int) -> dict[str, Any]:
101 |
102 | ref_ids = self.list_ref[i]
103 | ref = self.G_ref_dataset.loadRefs(ref_ids)[0]
104 | sentences = random.choice(ref['sentences'])['sent']
105 |
106 | prompt = random.choice(self.seg_diverse_prompt_list)
107 |
108 | color = random.choice(self.color_list)
109 | color_name = color[0]
110 | prompt = prompt.format(color=color_name.lower(), object=sentences.lower())
111 |
112 | R, G, B = color[3].split(",")
113 | R = int(R)
114 | G = int(G)
115 | B = int(B)
116 |
117 | image_name = self.G_ref_dataset.loadImgs(ref['image_id'])[0]['file_name']
118 | image_path = os.path.join(self.IMAGE_DIR,image_name)
119 | mask = self.G_ref_dataset.getMaskByRef(ref=ref,merge=True)['mask']
120 |
121 | image = Image.open(image_path).convert("RGB")
122 | image = np.asarray(image)
123 |
124 | # image, mask = self._augmentation_new(image,mask)
125 |
126 | mask = (mask == 1)
127 |
128 | image_0 = Image.fromarray(image)
129 | image_1 = copy.deepcopy(image)
130 | image_1[:,:,0][mask]=self.transparency*image_1[:,:,0][mask]+(1-self.transparency)*R
131 | image_1[:,:,1][mask]=self.transparency*image_1[:,:,1][mask]+(1-self.transparency)*G
132 | image_1[:,:,2][mask]=self.transparency*image_1[:,:,2][mask]+(1-self.transparency)*B
133 | image_1 = Image.fromarray(image_1)
134 |
135 | resize_res = torch.randint(self.min_resize_res, self.max_resize_res + 1, ()).item()
136 | image_0 = image_0.resize((resize_res, resize_res), Image.Resampling.LANCZOS)
137 | image_1 = image_1.resize((resize_res, resize_res), Image.Resampling.LANCZOS)
138 |
139 |
140 | sources = [{'from': 'human', 'value': '{}: {}.'.format(prompt, ALL_IMG_TOKENS_STR)},
141 | {'from': 'gpt', 'value': '{}.'.format(ALL_IMG_TOKENS_STR)}]
142 |
143 | return {'input_images': [image_0], 'output_images': [image_1],
144 | 'output_cond_images': [image_0], 'conversation': sources, 'image_label_masks': [0, 1]}
--------------------------------------------------------------------------------
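Note: the colored overlay above is a per-channel alpha blend, `transparency * pixel + (1 - transparency) * color`, applied only to the masked pixels of the referred object. A small self-contained numpy sketch of that step (the image, mask, and color below are made up):

```python
import numpy as np

transparency = 0.5
R, G, B = 255, 0, 0                               # "red" from the color list

image = np.full((4, 4, 3), 100, dtype=np.uint8)   # dummy gray image
mask = np.zeros((4, 4), dtype=bool)
mask[1:3, 1:3] = True                             # pretend this is the referred object

blended = image.astype(np.float32)
for c, value in enumerate((R, G, B)):
    blended[:, :, c][mask] = transparency * blended[:, :, c][mask] + (1 - transparency) * value
blended = blended.astype(np.uint8)

print(blended[1, 1])   # [177  50  50]: halfway between gray and pure red
print(blended[0, 0])   # [100 100 100]: unmasked pixels are untouched
```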
/locals/datasets/image_edit/seg/refcoco_seg.py:
--------------------------------------------------------------------------------
1 | # --------------------------------------------------------
2 | # InstructDiffusion
3 | # Based on instruct-pix2pix (https://github.com/timothybrooks/instruct-pix2pix)
4 | # Modified by Binxin Yang (tennyson@mail.ustc.edu.cn)
5 | # --------------------------------------------------------
6 |
7 | from __future__ import annotations
8 |
9 | import os
10 | import random
11 | import copy
12 | import json
13 | import math
14 | from pathlib import Path
15 | from typing import Any
16 |
17 | import numpy as np
18 | import torch
19 | import torchvision
20 | from einops import rearrange
21 | from PIL import Image
22 | from torch.utils.data import Dataset
23 |
24 | from .refcoco import REFER
25 | from constants import ALL_IMG_TOKENS_STR
26 |
27 |
28 | class RefCOCODataset(Dataset):
29 | def __init__(
30 | self,
31 | path: str,
32 | split: str = "train",
33 | min_resize_res: int = 256,
34 | max_resize_res: int = 256,
35 | crop_res: int = 256,
36 | flip_prob: float = 0.0,
37 | transparency: float = 0.0,
38 | test: bool = False,
39 | image_path: str = None,
40 | ):
41 | assert split in ("train", "val", "test")
42 | self.path = path
43 | self.min_resize_res = min_resize_res
44 | self.max_resize_res = max_resize_res
45 | self.crop_res = crop_res
46 | self.flip_prob = flip_prob
47 | self.G_ref_dataset=REFER(data_root=path)
48 | self.IMAGE_DIR = os.path.join(image_path, 'train2014')
49 | self.list_ref=self.G_ref_dataset.getRefIds(split=split)
50 | self.transparency = transparency
51 | self.test = test
52 |
53 | seg_diverse_prompt_path = 'locals/datasets/image_edit/prompt/prompt_seg.txt'
54 | self.seg_diverse_prompt_list=[]
55 | with open(seg_diverse_prompt_path) as f:
56 | line=f.readline()
57 | while line:
58 | line=line.strip('\n')
59 | self.seg_diverse_prompt_list.append(line)
60 | line=f.readline()
61 |
62 | color_list_file_path='locals/datasets/image_edit/prompt/color_list_train_small.txt'
63 | self.color_list=[]
64 | with open(color_list_file_path) as f:
65 | line = f.readline()
66 | while line:
67 | line_split = line.strip('\n').split(" ")
68 | if len(line_split)>1:
69 | temp = []
70 | for i in range(4):
71 | temp.append(line_split[i])
72 | self.color_list.append(temp)
73 | line = f.readline()
74 |
75 | def __len__(self) -> int:
76 | return len(self.list_ref)
77 |
78 | def _augmentation_new(self, image, label):
79 |
80 | # Cropping
81 | h, w = label.shape
82 | if h > w:
83 | start_h = random.randint(0, h - w)
84 | end_h = start_h + w
85 | image = image[start_h:end_h]
86 | label = label[start_h:end_h]
87 | elif h < w:
88 | start_w = random.randint(0, w - h)
89 | end_w = start_w + h
90 | image = image[:, start_w:end_w]
91 | label = label[:, start_w:end_w]
92 | else:
93 | pass
94 | image = Image.fromarray(image).resize((self.min_resize_res, self.min_resize_res), resample=Image.Resampling.LANCZOS)
95 | image = np.asarray(image, dtype=np.uint8)
96 | label = Image.fromarray(label).resize((self.min_resize_res, self.min_resize_res), resample=Image.Resampling.NEAREST)
97 | label = np.asarray(label, dtype=np.int64)
98 | return image, label
99 |
100 | def __getitem__(self, i: int) -> dict[str, Any]:
101 |
102 | ref_ids = self.list_ref[i]
103 | ref = self.G_ref_dataset.loadRefs(ref_ids)[0]
104 | sentences = random.choice(ref['sentences'])['sent']
105 |
106 | prompt = random.choice(self.seg_diverse_prompt_list)
107 |
108 | color = random.choice(self.color_list)
109 | color_name = color[0]
110 | prompt = prompt.format(color=color_name.lower(), object=sentences.lower())
111 |
112 | R, G, B = color[3].split(",")
113 | R = int(R)
114 | G = int(G)
115 | B = int(B)
116 |
117 | image_name = self.G_ref_dataset.loadImgs(ref['image_id'])[0]['file_name']
118 | image_path = os.path.join(self.IMAGE_DIR,image_name)
119 | mask = self.G_ref_dataset.getMask(ref=ref)['mask']
120 |
121 | image = Image.open(image_path).convert("RGB")
122 | image = np.asarray(image)
123 |
124 | image, mask = self._augmentation_new(image,mask)
125 |
126 | mask = (mask == 1)
127 |
128 | image_0 = Image.fromarray(image)
129 | image_1 = copy.deepcopy(image)
130 | image_1[:,:,0][mask]=self.transparency*image_1[:,:,0][mask]+(1-self.transparency)*R
131 | image_1[:,:,1][mask]=self.transparency*image_1[:,:,1][mask]+(1-self.transparency)*G
132 | image_1[:,:,2][mask]=self.transparency*image_1[:,:,2][mask]+(1-self.transparency)*B
133 | image_1 = Image.fromarray(image_1)
134 |
135 | resize_res = torch.randint(self.min_resize_res, self.max_resize_res + 1, ()).item()
136 | image_0 = image_0.resize((resize_res, resize_res), Image.Resampling.LANCZOS)
137 | image_1 = image_1.resize((resize_res, resize_res), Image.Resampling.LANCZOS)
138 |
139 | # image_1 = Image.fromarray(image_1)
140 | sources = [{'from': 'human', 'value': '{}: {}.'.format(prompt, ALL_IMG_TOKENS_STR)},
141 | {'from': 'gpt', 'value': '{}.'.format(ALL_IMG_TOKENS_STR)}]
142 |
143 | return {'input_images': [image_0], 'output_images': [image_1],
144 | 'output_cond_images': [image_0], 'conversation': sources, 'image_label_masks': [0, 1]}
--------------------------------------------------------------------------------
/locals/datasets/multimodal_tasks/llava_R.py:
--------------------------------------------------------------------------------
1 | from torch.utils.data import Dataset
2 | import math
3 | import random
4 | from PIL import Image
5 | # load the parquet file with pandas.read_parquet
6 | from pandas import read_parquet
7 | import io
8 |
9 | from constants import ALL_IMG_TOKENS_STR
10 |
11 | class LlavaRInstructDataset(Dataset):
12 | def __init__(self,
13 | path: str,
14 | instruct: bool = False,
15 | min_resize_res: int = 256,
16 | max_resize_res: int = 256,
17 | crop_res: int = 256,
18 | flip_prob: float = 0.5,
19 | sample_weight: float = 1.0,
20 | check: bool = False,
21 | output_mode: str = 'conversation',
22 | shuffle: bool = False,
23 | raw_image: bool = False,
24 | inference: bool = False,
25 | min_size: int = 50,
26 | **kwargs):
27 | self.path = path
28 | self.instruct = instruct
29 | self.inference = inference
30 | self.meta = read_parquet(path)
31 | self.min_resize_res = min_resize_res
32 | self.max_resize_res = max_resize_res
33 | self.crop_res = crop_res
34 | self.check = check
35 | self.raw_image = raw_image
36 | self.output_mode = output_mode
37 | self.shuffle = shuffle
38 | self.min_size = min_size
39 | self.flip_prob = flip_prob
40 | self.sample_weight = sample_weight
41 | print(f"LlavaR Instruct has {len(self)} samples!!")
42 |
43 | def __len__(self):
44 | return int(len(self.meta) * self.sample_weight)
45 |
46 | def get_sampler_index(self,i):
47 | if self.sample_weight >= 1:
48 | i = i % len(self.meta)
49 | else:
50 | remainder = math.ceil(i / self.sample_weight - int(i / self.sample_weight))
51 | i = int(i / self.sample_weight) + random.randint(0, int(1 / self.sample_weight) - 1 + remainder)
52 | return i
53 |
54 | def __getitem__(self, i):
55 | assert self.output_mode == 'conversation'
56 | ini_i = i
57 | i = self.get_sampler_index(i)
58 | item = self.meta.iloc[i]
59 | image_bytes = item['image']['bytes']
60 | tgt_img = [Image.open(io.BytesIO(image_bytes))]
61 | new_conversation = []
62 | assert len(item['user_texts']) == len(item['bot_texts'])
63 | for i in range(len(item['user_texts'])):
64 | if i == 0:
65 | new_conversation.append({'from':"human",'value': ALL_IMG_TOKENS_STR + item['user_texts'][i]})
66 | else:
67 | new_conversation.append({'from':"human",'value':item['user_texts'][i]})
68 | new_conversation.append({'from':"gpt",'value':item['bot_texts'][i]})
69 |
70 | return {'input_images': tgt_img, 'conversation': new_conversation,'id':f'llava_r_{ini_i}','image_label_masks': [0]}
71 |
72 | if __name__ == '__main__':
73 | test_path = '/mnt/bn/yangmin-priv/luoruipu/data/multimodal-datasets/LLaVAR-Instruct-16K/data/train-00000-of-00001-890199abde0ec4ff.parquet'
74 | test = LlavaRInstructDataset(path=test_path)
75 | print(len(test))
76 | print(test[1000])
77 |
78 |
--------------------------------------------------------------------------------
/locals/datasets/multimodal_tasks/lvis_instruct4v.py:
--------------------------------------------------------------------------------
1 | from torch.utils.data import Dataset
2 | import json as js
3 | import math
4 | import random
5 | import os
6 | from PIL import Image
7 | from constants import ALL_IMG_TOKENS_STR
8 | from locals.datasets.multimodal_tasks.single_image_base import SingleImageDataset
9 | from copy import deepcopy
10 |
11 | class LVISinstruct4vDataset(SingleImageDataset):
12 | def __init__(self,
13 | path: str,
14 | image_folder: str,
15 | instruct: bool = False,
16 | min_resize_res: int = 256,
17 | max_resize_res: int = 256,
18 | crop_res: int = 256,
19 | flip_prob: float = 0.5,
20 | sample_weight: float = 1.0,
21 | check: bool = False,
22 | output_mode: str = 'text',
23 | shuffle: bool = False,
24 | raw_image: bool = False,
25 | inference: bool = False,
26 | min_size: int = 50,
27 | **kwargs):
28 | super().__init__(path, image_folder, instruct, min_resize_res, max_resize_res, crop_res, flip_prob, sample_weight, check, output_mode, shuffle, raw_image, inference, min_size, **kwargs)
29 | print(f"LVIS-instruct4v has {len(self)} samples!!")
30 |
31 | def __len__(self):
32 | return int(len(self.meta) * self.sample_weight)
33 |
34 | def __getitem__(self, i):
35 | assert self.output_mode == 'conversation'
36 |
37 | i = self.get_sampler_index(i)
38 |
39 | item = self.meta[i]
40 | if 'image' in item:
41 | image_fn = item['image']
42 | dirname = image_fn.split('/')[0]
43 | image_fn = '/'.join(image_fn.split('/')[1:])
44 | all_sub_images = []
45 | if dirname == 'coco':
46 | dirname = 'COCO2017'
47 | else:
48 | raise ValueError
49 |
50 | tgt_img = [Image.open(os.path.join(self.image_folder, dirname, image_fn)).convert('RGB')]
51 |
52 | new_conversation = deepcopy(item['conversations'])
53 | for conv in new_conversation:
54 | if '<image>' in conv['value']:
55 | conv['value'] = conv['value'].replace('<image>', ALL_IMG_TOKENS_STR)
56 | return {'input_images': tgt_img + all_sub_images, 'conversation': new_conversation,'id':'LVIS_'+item['id'],'image_label_masks': [0]}
57 | else:
58 | tgt_img = None
59 | return {'conversation': item['conversations'],'id':'LVIS_'+item['id']}
60 |
61 | if __name__ == '__main__':
62 | test_path = '/mnt/bn/yangmin-priv/luoruipu/data/multimodal-datasets/LVIS-Instruct4V/lvis_instruct4v_220k.json'
63 | image_folder = '/mnt/bn/yangmin-priv/luoruipu/data/multimodal-datasets/'
64 | test = LVISinstruct4vDataset(path=test_path, image_folder=image_folder, output_mode = 'conversation')
65 | print(len(test))
66 | print(test[1000])
67 |
68 |
69 |
--------------------------------------------------------------------------------
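Note: the `replace` above follows the LLaVA-style convention of marking image positions with an `<image>` placeholder and swapping it for the model's own image-token string. A short sketch of that substitution; `ALL_IMG_TOKENS_STR` is defined in `constants.py`, so a stand-in value is used here:

```python
from copy import deepcopy

ALL_IMG_TOKENS_STR = '<IMG_PLACEHOLDER>'   # stand-in for the constant from constants.py

record = {'conversations': [
    {'from': 'human', 'value': '<image>\nWhat is in the picture?'},
    {'from': 'gpt',   'value': 'A cat sleeping on a sofa.'},
]}

conv = deepcopy(record['conversations'])
for turn in conv:
    if '<image>' in turn['value']:
        turn['value'] = turn['value'].replace('<image>', ALL_IMG_TOKENS_STR)

print(conv[0]['value'])   # the placeholder now sits where <image> was
```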
/locals/datasets/multimodal_tasks/m3it.py:
--------------------------------------------------------------------------------
1 | from torch.utils.data import Dataset
2 | import json as js
3 | import math
4 | import random
5 | import os
6 | from PIL import Image
7 | from typing import List
8 | from constants import *
9 | import re, copy
10 | from base64 import b64encode, b64decode
11 | import io
12 | from datasets import load_dataset, concatenate_datasets
13 |
14 | def byte2image(byte_data):
15 | """
16 | convert byte to PIL image
17 | """
18 | if isinstance(byte_data, str):
19 | byte_data = b64decode(byte_data)
20 | image = Image.open(io.BytesIO(byte_data))
21 | return image
22 |
23 | class M3ITDataset(Dataset):
24 |
25 | def __init__(self,
26 | path: str,
27 | dataset_names: List[str],
28 | split: str = 'train',
29 | crop_res: int = 256,
30 | flip_prob: float = 0.5,
31 | sample_weight: float = 1.0,
32 | check: bool = False,
33 | output_mode: str = 'text',
34 | shuffle: bool = False,
35 | raw_image: bool = False,
36 | inference: bool = False,
37 | min_size: int = 50,
38 | **kwargs):
39 | # load from json ../datasets/gqa-inpaint/meta_info.json
40 | self.path = path
41 | self.inference = inference
42 | tmp_ds = []
43 | for name in dataset_names:
44 | print(name)
45 | tmp_ds.append(load_dataset(os.path.join(path, name))[split])
46 | self.meta = concatenate_datasets(tmp_ds)
47 | self.crop_res = crop_res
48 | self.check = check
49 | self.raw_image = raw_image
50 | self.output_mode = output_mode
51 | self.shuffle = shuffle
52 | self.min_size = min_size
53 |
54 | self.flip_prob = flip_prob
55 | self.sample_weight = sample_weight
56 | print(f"M3IT has {len(self)} samples!!")
57 |
58 | def __len__(self):
59 | return int(len(self.meta) * self.sample_weight)
60 |
61 | def __getitem__(self, i):
62 | assert self.output_mode == 'conversation'
63 |
64 | if self.sample_weight >= 1:
65 | i = i % len(self.meta)
66 | else:
67 | remainder = math.ceil(i / self.sample_weight - int(i / self.sample_weight))
68 | i = int(i / self.sample_weight) + random.randint(0, int(1 / self.sample_weight) - 1 + remainder)
69 |
70 | item = self.meta[i]
71 | image = item['image_base64_str']
72 | image = byte2image(image[0]).convert('RGB')
73 |
74 | if len(item['inputs']):
75 | prompt = item['instruction'] + ' ' + item['inputs']
76 | else:
77 | prompt = item['instruction']
78 |
79 | prompt = ALL_IMG_TOKENS_STR + ' ' + prompt
80 | conversation = [
81 | {'from': 'human', 'value': prompt},
82 | {'from': 'gpt', 'value': item['outputs']}
83 | ]
84 | return {'input_images': [image], 'conversation': conversation, 'image_label_masks': [0]}
85 |
86 |
--------------------------------------------------------------------------------
/locals/datasets/multimodal_tasks/sharegpt4v.py:
--------------------------------------------------------------------------------
1 | import os
2 | from PIL import Image
3 | from locals.datasets.multimodal_tasks.single_image_base import SingleImageDataset
4 | from copy import deepcopy
5 | from constants import ALL_IMG_TOKENS_STR
6 |
7 | class Sharegpt4vDataset(SingleImageDataset):
8 | def __init__(self,
9 | path: str,
10 | image_folder: str,
11 | llava_pretrain_folder: str,
12 | coco_img_folder: str,
13 | instruct: bool = False,
14 | min_resize_res: int = 256,
15 | max_resize_res: int = 256,
16 | crop_res: int = 256,
17 | flip_prob: float = 0.5,
18 | sample_weight: float = 1.0,
19 | check: bool = False,
20 | output_mode: str = 'text',
21 | shuffle: bool = False,
22 | raw_image: bool = False,
23 | inference: bool = False,
24 | min_size: int = 50,
25 | **kwargs):
26 | super().__init__(path, image_folder, instruct, min_resize_res, max_resize_res, crop_res, flip_prob, sample_weight, check, output_mode, shuffle, raw_image, inference, min_size, **kwargs)
27 | self.llava_pretrain_folder = llava_pretrain_folder
28 | self.coco_img_folder = coco_img_folder
29 | print(f"Sharegpt4vDataset has {len(self)} samples!!")
30 |
31 | def __len__(self):
32 | return int(len(self.meta) * self.sample_weight)
33 |
34 | def __getitem__(self, i):
35 | assert self.output_mode == 'conversation'
36 |
37 | i = self.get_sampler_index(i)
38 |
39 | item = self.meta[i]
40 | if 'image' in item:
41 | image_fn = item['image']
42 | sub_dir = image_fn.split('/')[0]
43 |
44 | if sub_dir == 'llava':
45 | image_folder = self.llava_pretrain_folder
46 | image_fn = '/'.join(image_fn.split('/')[3:])
47 | elif sub_dir == 'coco':
48 | image_folder = self.coco_img_folder
49 | image_fn = '/'.join(image_fn.split('/')[1:])
50 | else:
51 | image_folder = self.image_folder
52 | tgt_img = [Image.open(os.path.join(image_folder, image_fn)).convert('RGB')]
53 |
54 | new_conversation = deepcopy(item['conversations'])
55 | for conv in new_conversation:
56 | if '<image>' in conv['value']:
57 | conv['value'] = conv['value'].replace('<image>', ALL_IMG_TOKENS_STR)
58 |
59 | return {'input_images': tgt_img, 'conversation': new_conversation,'id':'Sharegpt4v_'+item['id'],'image_path':item['image'],'image_label_masks': [0]}
60 | else:
61 | tgt_img = None
62 | return {'conversation': item['conversations'],'id':'Sharegpt4v_'+item['id']}
63 |
64 | if __name__ == '__main__':
65 | from tqdm import tqdm
66 | test_path = '/mnt/bn/yangmin-priv/luoruipu/data/multimodal-datasets/ShareGPT4V/sharegpt4v_instruct_gpt4-vision_cap100k.json'
67 | image_folder = '/mnt/bn/yangmin-priv/luoruipu/data/multimodal-datasets/ShareGPT4V/data/'
68 | coco_image_folder = '/mnt/bn/yangmin-priv/luoruipu/data/multimodal-datasets/COCO2017'
69 | llava_pretrain_folder = '/mnt/bn/luoruipu-disk/meta_data/pretrain_data/LLaVA-Pretrain/'
70 | test = Sharegpt4vDataset(path=test_path, image_folder=image_folder, llava_pretrain_folder=llava_pretrain_folder,coco_img_folder=coco_image_folder, output_mode = 'conversation')
71 | print(len(test))
72 | for i in tqdm(range(len(test))):
73 | test[i]
74 |
--------------------------------------------------------------------------------
/locals/datasets/multimodal_tasks/single_image_base.py:
--------------------------------------------------------------------------------
1 | from torch.utils.data import Dataset
2 | import json as js
3 | import math
4 | import random
5 | import os
6 | from PIL import Image
7 |
8 | class SingleImageDataset(Dataset):
9 |
10 | def __init__(self,
11 | path: str,
12 | image_folder: str,
13 | instruct: bool = False,
14 | min_resize_res: int = 256,
15 | max_resize_res: int = 256,
16 | crop_res: int = 256,
17 | flip_prob: float = 0.5,
18 | sample_weight: float = 1.0,
19 | check: bool = False,
20 | output_mode: str = 'text',
21 | shuffle: bool = False,
22 | raw_image: bool = False,
23 | inference: bool = False,
24 | min_size: int = 50,
25 | **kwargs):
26 | # load from json ../datasets/gqa-inpaint/meta_info.json
27 | self.path = path
28 | self.instruct = instruct
29 | self.inference = inference
30 | self.meta = js.load(open(path))
31 | self.image_folder = image_folder
32 | self.min_resize_res = min_resize_res
33 | self.max_resize_res = max_resize_res
34 | self.crop_res = crop_res
35 | self.check = check
36 | self.raw_image = raw_image
37 | self.output_mode = output_mode
38 | self.shuffle = shuffle
39 | self.min_size = min_size
40 | self.flip_prob = flip_prob
41 | self.sample_weight = sample_weight
42 |
43 | def __len__(self):
44 | return int(len(self.meta) * self.sample_weight)
45 |
46 | def get_sampler_index(self,i):
47 | if self.sample_weight >= 1:
48 | i = i % len(self.meta)
49 | else:
50 | remainder = math.ceil(i / self.sample_weight - int(i / self.sample_weight))
51 | i = int(i / self.sample_weight) + random.randint(0, int(1 / self.sample_weight) - 1 + remainder)
52 | return i
53 |
54 | def __getitem__(self, i):
55 | raise NotImplementedError
--------------------------------------------------------------------------------
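Note: concrete loaders subclass `SingleImageDataset`, implement `__getitem__`, and return the sample dict used throughout the repo (`input_images`, `conversation`, `id`, `image_label_masks`). A minimal hypothetical subclass, assuming a LLaVA-style JSON with `image` and `conversations` fields:

```python
import os
from PIL import Image
from locals.datasets.multimodal_tasks.single_image_base import SingleImageDataset

class MyCaptionDataset(SingleImageDataset):
    """Hypothetical example; the schema mirrors the other loaders in this folder."""

    def __getitem__(self, i):
        i = self.get_sampler_index(i)
        item = self.meta[i]
        image = Image.open(os.path.join(self.image_folder, item['image'])).convert('RGB')
        return {
            'input_images': [image],
            'conversation': item['conversations'],
            'id': 'my_caption_{}'.format(i),
            'image_label_masks': [0],   # mirroring the other loaders: 0 marks an input image, 1 an image to generate
        }
```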
/locals/datasets/prompts/prompt_captioning_long.txt:
--------------------------------------------------------------------------------
1 | Write a description of the image, capturing its main components, the relationships between them, and any notable details.
2 | Create a caption that accurately describes the elements in the image provided.
3 | Write a comprehensive, yet brief, description of the image.
4 | Describe the image in a clear and detailed manner.
5 | For the given image, describe the visual content, and try to include every component.
6 | Generate a detailed caption for the picture.
7 | Write a detailed and informative description that highlights the primary subjects and actions occurring in the given image.
8 | Write a clear description of the image, make sure the key features are well covered.
9 | Offer a comprehensive explanation of the picture presented.
--------------------------------------------------------------------------------
/locals/datasets/prompts/prompt_captioning_short.txt:
--------------------------------------------------------------------------------
1 | Describe the image briefly.
2 | Write a succinct description of the image, capturing its main components, the relationships between them, and any notable details.
3 | Create a concise caption that accurately describes the main elements in the image provided.
4 | Write a brief, yet comprehensive, description of the image.
5 | Describe the image in a clear and concise manner.
6 | For the given image, provide a one-sentence summary that captures the most important details.
7 | Generate a short caption for the picture.
8 | Write a short and informative description that highlights the primary subjects and actions occurring in the given image.
9 | Provide a concise and informative caption for the image, focusing on the primary subjects.
10 | Write a clear description of the image, make sure the key features are well covered.
11 | Offer a succinct explanation of the picture presented.
--------------------------------------------------------------------------------
/locals/datasets/prompts/prompt_kosmosg.txt:
--------------------------------------------------------------------------------
1 | Help me to generate this image: {desc}
2 | Generate image according to the following information: {desc}
3 | Based on the provided description and corresponding images of objects: {desc}, generate a complete image.
4 | Create an image of {desc}
5 | I will provide you with descriptions of the images and images of the objects. Help me reconstruct the entire picture: {desc}
--------------------------------------------------------------------------------
/locals/datasets/prompts/prompt_ref_i2t.txt:
--------------------------------------------------------------------------------
1 | Briefly describe the region {region}.
2 | Provide a description for the region {region} to distinguish it from other regions.
3 | Express what is shown in this region: {region}.
4 | What object is in this image segment {region}.
5 | How will you depict this visual region {region} with natural language?
6 | What is in {region}?
7 | Can you describe the region {region} for me?
--------------------------------------------------------------------------------
/locals/datasets/prompts/prompt_ref_t2i.txt:
--------------------------------------------------------------------------------
1 | Locate the region in the image describing: {phrase}.
2 | Find the region or regions that correspond to {phrase}.
3 | Can you find {phrase} in the image?
4 | Please detect {phrase} in this picture.
5 | In this image, which object matches the description "{phrase}"?
6 | Which region in the figure best corresponds to the following description: {phrase}?
7 | Which part of the illustration best matches {phrase}?
8 | Please locate the area described as {phrase} in the picture.
9 | Identify the region in the image referred to as {phrase}.
10 | Please help me find {phrase} in the picture.
--------------------------------------------------------------------------------
/locals/datasets/prompts/prompt_txt2img.txt:
--------------------------------------------------------------------------------
1 | generate an image with caption:
2 | can you give me the image with caption:
3 | help me to generate this image:
4 | generate an image according to the caption:
5 | according to caption, generate image:
6 | an image with caption:
7 | can you visualize this caption:
8 | Please generate an image corresponding to my description:
9 |
--------------------------------------------------------------------------------
/locals/datasets/text/sharegpt.py:
--------------------------------------------------------------------------------
1 | from locals.datasets.text.text_data_base import TextDataset
2 |
3 | class ShareGPTDataset(TextDataset):
4 | def __init__(self,
5 | path: str,
6 | instruct: bool = False,
7 | sample_weight: float = 1,
8 | output_mode: str = 'conversation',
9 | shuffle: bool = False,
10 | inference: bool = False,
11 | **kwargs):
12 | path = [path]
13 | super().__init__(path, instruct, sample_weight, output_mode, shuffle, inference, **kwargs)
14 | print(f"ShareGPTDataset has {len(self)} samples!!")
15 |
16 | def __getitem__(self, i):
17 | assert self.output_mode == 'conversation'
18 | i = self.get_sampler_index(i)
19 | item = self.meta[i]
20 | return {'conversation': item['conversations'],'id':item['id']}
21 |
22 | class ShareGPTCodeDataset(TextDataset):
23 | def __init__(self,
24 | path: str,
25 | instruct: bool = False,
26 | sample_weight: float = 1,
27 | output_mode: str = 'conversation',
28 | shuffle: bool = False,
29 | inference: bool = False,
30 | **kwargs):
31 | path = [path]
32 | super().__init__(path, instruct, sample_weight, output_mode, shuffle, inference, **kwargs)
33 | print(f"ShareGPTCodeDataset has {len(self)} samples!!")
34 |
35 | def __getitem__(self, i):
36 | assert self.output_mode == 'conversation'
37 | i = self.get_sampler_index(i)
38 | item = self.meta[i]
39 | return {'conversation': item['conversations'],'id':'sharegpt_code_'+item['id']}
40 |
41 |
42 |
43 | if __name__ == '__main__':
44 | test_path = '/mnt/bn/yangmin-priv/luoruipu/data/text-dataset/sharegpt_184k_no_repeat.json'
45 | test = ShareGPTDataset(path=test_path)
46 | print(len(test))
47 | print(test[100000])
--------------------------------------------------------------------------------
/locals/datasets/text/text_data_base.py:
--------------------------------------------------------------------------------
1 | from torch.utils.data import Dataset
2 | import json as js
3 | import math
4 | import random
5 | from tqdm import tqdm
6 |
7 | class TextDataset(Dataset):
8 |
9 | def __init__(self,
10 | pathlist: list,
11 | instruct: bool = False,
12 | sample_weight: float = 1.0,
13 | output_mode: str = 'text',
14 | shuffle: bool = False,
15 | inference: bool = False,
16 | **kwargs):
17 | # load from json ../datasets/gqa-inpaint/meta_info.json
18 | self.pathlist = pathlist
19 | self.instruct = instruct
20 | self.inference = inference
21 | self.meta = []
22 | for path in tqdm(pathlist):
23 | if path.endswith('json'):
24 | self.meta += js.load(open(path))
25 | elif path.endswith('jsonl'):
26 | with open(path) as f:
27 | for line in f:
28 | try:
29 | self.meta.append(js.loads(line))
30 | except:
31 | continue
32 | else:
33 | raise ValueError('Only json or jsonl files are supported.')
34 | self.output_mode = output_mode
35 | self.shuffle = shuffle
36 | self.sample_weight = sample_weight
37 |
38 | def __len__(self):
39 | return int(len(self.meta) * self.sample_weight)
40 |
41 | def get_sampler_index(self,i):
42 | if self.sample_weight >= 1:
43 | i = i % len(self.meta)
44 | else:
45 | remainder = math.ceil(i / self.sample_weight - int(i / self.sample_weight))
46 | i = int(i / self.sample_weight) + random.randint(0, int(1 / self.sample_weight) - 1 + remainder)
47 | return i
48 |
49 | def __getitem__(self, i):
50 | raise NotImplementedError
--------------------------------------------------------------------------------
/locals/datasets/text/txt_cot.py:
--------------------------------------------------------------------------------
1 | import os
2 | import random
3 | import copy
4 | import json
5 | import math
6 | from pathlib import Path
7 |
8 | from typing import Any
9 |
10 | import numpy as np
11 | import torch
12 | import torchvision
13 | import pandas as pd
14 | from einops import rearrange
15 | from PIL import Image
16 | from torch.utils.data import Dataset
17 | from utils.util import byte2image
18 | from ..utils.box_utils import *
19 | from datasets import load_dataset
20 |
21 | from constants import *
22 |
23 | class CoTCollectionDataset(Dataset):
24 | def __init__(self,
25 | path:str = None,
26 | sample_weight: float=1.0,
27 | ):
28 | self.meta = load_dataset(path)['train']
29 | self.sample_weight = sample_weight
30 |
31 | def __len__(self):
32 | return int(len(self.meta) * self.sample_weight)
33 |
34 | def __getitem__(self, i):
35 |
36 | if self.sample_weight >= 1:
37 | i = i % len(self.meta)
38 | else:
39 | remainder = math.ceil(i / self.sample_weight - int(i / self.sample_weight))
40 | i = int(i / self.sample_weight) + random.randint(0, int(1 / self.sample_weight) - 1 + remainder)
41 |
42 | item = self.meta[i]
43 | question = item['source'] + ' Answer the question and include the reasoning process.'
44 | thought = item['rationale']
45 | answer = item['target']
46 |
47 | conv = [{'from': 'human', 'value': question},
48 | {'from': 'gpt', 'value': thought},
49 | {'from': 'human', 'value': 'What is your final answer?'},
50 | {'from': 'gpt', 'value': answer}]
51 | return {'conversation': conv}
--------------------------------------------------------------------------------
/locals/datasets/text/ultrachat.py:
--------------------------------------------------------------------------------
1 | from locals.datasets.text.text_data_base import TextDataset
2 | from pathlib import Path
3 | class UltraChatDataset(TextDataset):
4 | def __init__(self,
5 | path: str,
6 | instruct: bool = False,
7 | sample_weight: float = 1,
8 | output_mode: str = 'conversation',
9 | shuffle: bool = False,
10 | inference: bool = False,
11 | **kwargs):
12 | path = [str(p)for p in Path(path).glob('*.jsonl')]
13 | super().__init__(path, instruct, sample_weight, output_mode, shuffle, inference, **kwargs)
14 | self.filter_dataset()
15 | print(f"UltraChat Dataset has {len(self)} samples!!")
16 |
17 | def filter_dataset(self,):
18 | new_data = []
19 | for item in self.meta:
20 | try:
21 | assert len(item['data'])%2==0
22 | new_data.append(item)
23 | except:
24 | continue
25 | self.meta = new_data
26 |
27 | def __getitem__(self, i):
28 | assert self.output_mode == 'conversation'
29 | i = self.get_sampler_index(i)
30 | item = self.meta[i]
31 | conversations = []
32 | for i in range(0,len(item['data']),2):
33 | conversations.append({'from':'human','value':item['data'][i]})
34 | conversations.append({'from':'gpt','value':item['data'][i+1]})
35 | return {'conversation': conversations,'id':'UltraChat'+str(item['id'])}
36 |
37 | class UltraChatJsonDataset(TextDataset):
38 | def __init__(self,
39 | path: str,
40 | instruct: bool = False,
41 | sample_weight: float = 1,
42 | output_mode: str = 'conversation',
43 | shuffle: bool = False,
44 | inference: bool = False,
45 | **kwargs):
46 | path = [path]
47 | super().__init__(path, instruct, sample_weight, output_mode, shuffle, inference, **kwargs)
48 | self.filter_dataset()
49 | print(f"UltraChat Dataset has {len(self)} samples!!")
50 |
51 |
52 | def filter_dataset(self,):
53 | new_data = []
54 | length_list = []
55 | for item in self.meta:
56 | try:
57 | assert len(item['conversations'])%2==0 and len(item['conversations'])!=0
58 | new_data.append(item)
59 | except:
60 | continue
61 | finally:
62 | cur_len = sum(len(conv['value'].split()) for conv in item['conversations'])
63 | cur_len = cur_len if 'image' in item else -cur_len
64 | length_list.append(cur_len)
65 | self.meta = new_data
66 | self.length = length_list
67 |
68 | def __getitem__(self, i):
69 | assert self.output_mode == 'conversation'
70 | item = self.meta[i]
71 | return {'conversation': item['conversations'],'id':'UltraChat'+str(item['id'])}
72 |
73 |
74 | if __name__ == '__main__':
75 | ultrachat_path = '/mnt/bn/yangmin-priv/luoruipu/data/text-dataset/ultrachat/'
76 | test = UltraChatDataset(path=ultrachat_path)
77 | print(len(test))
78 | print(test[0])
--------------------------------------------------------------------------------
/locals/datasets/text2image/kosmosg.py:
--------------------------------------------------------------------------------
1 |
2 | from torch.utils.data import Dataset
3 | import json as js
4 | import math
5 | import random
6 | import os
7 | from PIL import Image
8 | from collections import defaultdict
9 | from constants import *
10 |
11 | def load_prompt(fn):
12 | prompts = []
13 | with open(fn) as f:
14 | line=f.readline()
15 | while line:
16 | line=line.strip('\n')
17 | prompts.append(line)
18 | line=f.readline()
19 | return prompts
20 |
21 | class KosMosGDataset(Dataset):
22 | def __init__(self,
23 | path: str,
24 | image_folder: str,
25 | instruct: bool = False,
26 | min_resize_res: int = 256,
27 | max_resize_res: int = 256,
28 | crop_res: int = 256,
29 | flip_prob: float = 0.5,
30 | sample_weight: float = 1.0,
31 | check: bool = False,
32 | output_mode: str = 'conversation',
33 | raw_image: bool = False,
34 | object_type: str = 'mask',
35 | **kwargs):
36 | # load from json ../datasets/gqa-inpaint/meta_info.json
37 | self.path = path
38 | self.instruct = instruct
39 | self.meta = js.load(open(path))
40 | self.image_folder = image_folder
41 | self.min_resize_res = min_resize_res
42 | self.max_resize_res = max_resize_res
43 | self.crop_res = crop_res
44 | self.check = check
45 | self.raw_image = raw_image
46 | self.output_mode = output_mode
47 | self.flip_prob = flip_prob
48 | self.sample_weight = sample_weight
49 | self.object_type = object_type
50 | self.prompts = load_prompt('locals/datasets/prompts/prompt_kosmosg.txt')
51 | print(f"KosMos-G has {len(self)} samples!!")
52 |
53 | def __len__(self):
54 | return int(len(self.meta) * self.sample_weight)
55 |
56 | def get_image_fn(self, image_id):
57 | return os.path.join(self.image_folder, image_id)
58 |
59 | def __getitem__(self, i):
60 |
61 | if self.sample_weight >= 1:
62 | i = i % len(self.meta)
63 | else:
64 | remainder = math.ceil(i / self.sample_weight - int(i / self.sample_weight))
65 | i = int(i / self.sample_weight) + random.randint(0, int(1 / self.sample_weight) - 1 + remainder)
66 |
67 | item = self.meta[i]
68 | caption = item['source'].replace('', ' '+ALL_IMG_TOKENS_STR)
69 | tgt_img = Image.open(item['target']).convert('RGB')
70 | prompt = random.choice(self.prompts)
71 |
72 | # find the caption
73 | if self.object_type == 'random':
74 | prob = random.random()
75 | if prob > 0.5:
76 | object_type = 'mask'
77 | else:
78 | object_type = 'obj'
79 | else:
80 | object_type = self.object_type
81 | if object_type == 'mask':
82 | object_images = [Image.open(img).convert('RGB') for img in item['masks']]
83 | elif object_type == 'obj':
84 | object_images = [Image.open(img).convert('RGB') for img in item['objects']]
85 | else:
86 | raise ValueError
87 |
88 | if self.output_mode == 'conversation':
89 | sources = [{'from': 'human', 'value': prompt.format(desc=caption)},
90 | {'from': 'gpt', 'value': '{}.'.format(ALL_IMG_TOKENS_STR)}]
91 | return {'input_images': object_images+[tgt_img], 'conversation': sources, 'image_label_masks': [0]*len(object_images)+[1]}
92 | elif self.output_mode == 'text':
93 | text = caption + ' ' + '{}.'.format(ALL_IMG_TOKENS_STR)
94 | return {'input_images': object_images+[tgt_img], 'text': text, 'image_label_masks': [0]*len(object_images)+[1]}
--------------------------------------------------------------------------------
/locals/datasets/text2image/midjourney.py:
--------------------------------------------------------------------------------
1 | import os
2 | from PIL import Image
3 | from locals.datasets.multimodal_tasks.single_image_base import SingleImageDataset
4 | from constants import ALL_IMG_TOKENS_STR
5 | import random
6 | class MidjourneyDataset(SingleImageDataset):
7 | def __init__(self,
8 | path: str,
9 | image_folder: str,
10 | instruct_prompt_path: str = None,
11 | instruct: bool = False,
12 | min_resize_res: int = 256,
13 | max_resize_res: int = 256,
14 | crop_res: int = 256,
15 | flip_prob: float = 0.5,
16 | sample_weight: float = 1.0,
17 | check: bool = False,
18 | shuffle: bool = False,
19 | raw_image: bool = False,
20 | inference: bool = False,
21 | min_size: int = 50,
22 | **kwargs):
23 | output_mode = 'conversation'
24 | super().__init__(path, image_folder, instruct, min_resize_res, max_resize_res, crop_res, flip_prob, sample_weight, check, output_mode, shuffle, raw_image, inference, min_size, **kwargs)
25 | self.prompt_list = open(instruct_prompt_path).readlines()
26 | print(f"MidjourneyDataset has {len(self)} samples!!")
27 |
28 | def __len__(self):
29 | return int(len(self.meta) * self.sample_weight)
30 |
31 | def __getitem__(self, i):
32 | i = self.get_sampler_index(i)
33 |
34 | item = self.meta[i]
35 |
36 | image_fn = item['id']
37 |
38 | tgt_img = [Image.open(os.path.join(self.image_folder, image_fn+'.png')).convert('RGB')]
39 | instruct = random.choice(self.prompt_list)
40 | new_conversation = [{'from': 'human', 'value': instruct + ','.join(item['event']['textPrompt'])},
41 | {'from': 'gpt', 'value': '{}.'.format(ALL_IMG_TOKENS_STR)}]
42 |
43 | return {'output_images': tgt_img, 'conversation': new_conversation,'id':'midjourney_'+item['id'],'image_label_masks': [1]}
44 |
45 |
46 | if __name__ == '__main__':
47 | from tqdm import tqdm
48 | test_path = '/mnt/bn/yangmin-priv/luoruipu/data/multimodal-datasets/midjourney/midjourney_sample_50k_new.json'
49 | image_folder = '/mnt/bn/yangmin-priv/luoruipu/data/multimodal-datasets/midjourney/images'
50 | instruct_prompt_path = '/mnt/bn/yangmin-priv/luoruipu/code/Edit-GPT4/locals/datasets/prompts/prompt_txt2img.txt'
51 | test = MidjourneyDataset(path=test_path, image_folder=image_folder,instruct_prompt_path=instruct_prompt_path)
52 | print(test[0])
53 | # import json as js
54 | # data = js.load(open(test_path))
55 | # new_data = []
56 | # for item in tqdm(data):
57 | # image_fn = item['id']
58 | # try:
59 | # img = Image.open(os.path.join(image_folder, image_fn+'.png'))
60 | # new_data.append(item)
61 | # except:
62 | # continue
63 | # print(len(new_data))
64 | # js.dump(new_data, open('/mnt/bn/yangmin-priv/luoruipu/data/multimodal-datasets/midjourney/midjourney_filtered.json','w'))
--------------------------------------------------------------------------------
/locals/datasets/utils/box_utils.py:
--------------------------------------------------------------------------------
1 |
2 | from constants import *
3 | from utils.eval_util import extract_all_box_str
4 |
5 | def process_thought(thought, mistral=False):
6 | new_thought = thought.replace(' ', '')
7 | # if ' ' in new_thought:
8 |
9 | new_thought = new_thought.replace(' ', '' + ALL_IMG_TOKENS_STR).replace('', '')
10 | all_box = extract_all_box_str(thought, mistral)
11 | return new_thought, all_box
12 |
13 |
14 | def box2str(box, mode='special_tokens', prec=2, space=False):
15 | if mode == 'special_tokens':
16 | # using tokens to represent the locations
17 | num_blocks = len(ALL_LOC_TOKENS)
18 | size_per_block = 1 / num_blocks
19 | block_no = [int(c / size_per_block) if c < 1.0 else len(ALL_LOC_TOKENS)-1 for c in box]
20 | return ''.join([ALL_LOC_TOKENS[i] for i in block_no])
21 | elif mode == 'text':
22 | # using text to represent the box
23 | if space:
24 | sep = ', '
25 | else:
26 | sep = ','
27 | tmp_format = sep.join(['{' + ':.{}f'.format(prec)+'}']*4)
28 | a_box = [float(o) for o in box]
29 | return tmp_format.format(*a_box)
30 | else:
31 | raise NotImplementedError
32 |
33 | def allbox2str(objects):
34 | s = []
35 | for obj in objects:
36 | s.append('{}: [{}]'.format(obj['class'], box2str(obj['bbox'], 'text', 3, True)))
37 | return ', '.join(s)
38 |
39 | def reshape_box(image, box):
40 | width, height = image.size
41 | abs_box = [c*width if i%2==0 else c*height for i,c in enumerate(box)]
42 | if width == height:
43 | return box
44 | elif width > height:
45 | abs_box[1] += (width - height) // 2
46 | abs_box[3] += (width - height) // 2
47 | max_size = width
48 | else:
49 | abs_box[0] += (height - width) // 2
50 | abs_box[2] += (height - width) // 2
51 | max_size = height
52 | norm_box = [c/max_size for c in abs_box]
53 | return norm_box
54 |
55 | def reshape_box_reverse(image, box):
56 | width, height = image.size
57 | max_side = max(width, height)
58 | abs_box = [c*max_side if i%2==0 else c*max_side for i,c in enumerate(box)]
59 | if width == height:
60 | return box
61 | elif width > height:
62 | abs_box[1] -= (width - height) // 2
63 | abs_box[3] -= (width - height) // 2
64 | max_size = width
65 | else:
66 | abs_box[0] -= (height - width) // 2
67 | abs_box[2] -= (height - width) // 2
68 | max_size = height
69 | norm_box = [c/width if i%2 ==0 else c/height for i,c in enumerate(abs_box)]
70 | return norm_box
71 |
72 | def resize_image_to_square(image):
73 | width, height = image.size
74 | max_side = max(width, height)
75 | image = image.resize((max_side, max_side))
76 | return image
77 |
78 |
79 | def expand2square_fn(pil_img, background_color):
80 | from PIL import Image
81 | width, height = pil_img.size
82 | if width == height:
83 | return pil_img
84 | elif width > height:
85 | result = Image.new(pil_img.mode, (width, width), background_color)
86 | result.paste(pil_img, (0, (width - height) // 2))
87 | return result
88 | else:
89 | result = Image.new(pil_img.mode, (height, height), background_color)
90 | result.paste(pil_img, ((height - width) // 2, 0))
91 | return result
--------------------------------------------------------------------------------
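Note: `expand2square_fn` pads an image to a square, and `reshape_box` re-normalizes a box (given relative to the original width/height) so that it points at the same pixels inside the padded square. A quick numeric check, assuming the repo root is on `PYTHONPATH` so the module imports as below:

```python
from PIL import Image
from locals.datasets.utils.box_utils import expand2square_fn, reshape_box

img = Image.new('RGB', (200, 100))         # wide image: width > height
box = [0.25, 0.20, 0.75, 0.80]             # [x1, y1, x2, y2], normalized to the original size

square = expand2square_fn(img, (0, 0, 0))  # pads top and bottom to 200x200
print(square.size)                         # (200, 200)

# y-coordinates are shifted by the padding and everything is re-normalized
# by the square side, giving [0.25, 0.35, 0.75, 0.65]
print(reshape_box(img, box))
```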
/locals/datasets/utils/zip_manager.py:
--------------------------------------------------------------------------------
1 | import zipfile
2 | import os.path as osp
3 | # import lmdb
4 | import logging
5 | from PIL import Image
6 | import pickle
7 | import io
8 | import glob
9 | import os
10 | from pathlib import Path
11 | import time
12 | from threading import Thread
13 | from PIL import ImageFile
14 | ImageFile.LOAD_TRUNCATED_IMAGES = True
15 |
16 | home = str(Path.home())
17 | abs_blob_path=os.path.realpath("/mnt/blob/")
18 | CACHE_FOLDER=os.path.join(home,"caching")
19 | USE_CACHE=True
20 |
21 | def norm(path):
22 | assert "*" not in path
23 | return os.path.realpath(os.path.abspath(path))
24 |
25 | def in_blob(file):
26 | if abs_blob_path in file:
27 | return True
28 | else:
29 | return False
30 |
31 | def map_name(file):
32 | path=norm(file)
33 | path = path[len(abs_blob_path):].lstrip("/") if path.startswith(abs_blob_path) else path
34 | path=path.replace("/","_")
35 | assert len(path)<250
36 | return path
37 |
38 |
39 | def preload(db,sync=False):
40 | if sync:
41 | db.initialize()
42 | else:
43 | p = Thread(target=db.initialize)
44 | p.start()
45 |
46 | def get_keys_from_lmdb(db):
47 | with db.begin(write=False) as txn:
48 | return list(txn.cursor().iternext(values=False))
49 |
50 | def decode_img(byteflow):
51 | try:
52 | img=Image.open(io.BytesIO(byteflow)).convert("RGB")
53 | img.load()
54 | except:
55 | img = Image.open("white.jpeg").convert("RGB")
56 | img.load()
57 | return img
58 |
59 | def decode_text(byteflow):
60 | return pickle.loads(byteflow)
61 |
62 | decode_funcs={
63 | "image": decode_img,
64 | "text": decode_text
65 | }
66 |
67 |
68 | class ZipManager:
69 | def __init__(self, zip_path,data_type,prefix=None) -> None:
70 | self.decode_func=decode_funcs[data_type]
71 | self.zip_path=zip_path
72 | self._init=False
73 | preload(self)
74 |
75 | def deinitialze(self):
76 | self.zip_fd.close()
77 | del self.zip_fd
78 | self._init = False
79 |
80 | def initialize(self,close=True):
81 | self.zip_fd = zipfile.ZipFile(self.zip_path, mode="r")
82 | if not hasattr(self,"_keys"):
83 | self._keys = self.zip_fd.namelist()
84 | self._init = True
85 | if close:
86 | self.deinitialze()
87 |
88 | @property
89 | def keys(self):
90 | while not hasattr(self,"_keys"):
91 | time.sleep(0.1)
92 | return self._keys
93 |
94 | def get(self, name):
95 | if not self._init:
96 | self.initialize(close=False)
97 | byteflow = self.zip_fd.read(name)
98 | return self.decode_func(byteflow)
99 |
100 |
101 | class MultipleZipManager:
102 | def __init__(self, files: list, data_type, sync=True):
103 | self.files = files
104 | self._is_init = False
105 | self.data_type=data_type
106 | if sync:
107 | print("sync",files)
108 | self.initialize()
109 | else:
110 | print("async",files)
111 | preload(self)
112 | print("initialize over")
113 |
114 |
115 | def initialize(self):
116 | self.mapping={}
117 | self.managers={}
118 | for file in self.files:
119 | manager = ZipManager(file, self.data_type)
120 | self.managers[file]=manager
121 |
122 | for file,manager in self.managers.items():
123 | print(file)
124 | # print("loading")
125 | logging.info(f"{file} loading")
126 | keys=manager.keys
127 | for key in keys:
128 | self.mapping[key]=file
129 | logging.info(f"{file} loaded, size = {len(keys)}")
130 | print("loaded")
131 |
132 | self._keys=list(self.mapping.keys())
133 | self._is_init=True
134 |
135 | @property
136 | def keys(self):
137 | while not self._is_init:
138 | time.sleep(0.1)
139 | return self._keys
140 |
141 | def get(self, name):
142 | data = self.managers[self.mapping[name]].get(name)
143 | return data
--------------------------------------------------------------------------------
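Note: `ZipManager` lazily opens one archive and decodes members on demand, while `MultipleZipManager` merges the key spaces of several archives. A hypothetical usage sketch (the zip paths below are placeholders, not files from the repo):

```python
from locals.datasets.utils.zip_manager import MultipleZipManager

# shard0.zip / shard1.zip are placeholder archive names containing image files
images = MultipleZipManager(['shard0.zip', 'shard1.zip'], data_type='image', sync=True)

print(len(images.keys))             # all member names across both archives
first = images.get(images.keys[0])  # decoded to an RGB PIL.Image
first.save('first_sample.png')
```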
/model/front_projector/builder.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import re
4 | from .Qformer import BertConfig, BertLMHeadModel
5 |
6 | class IdentityMap(nn.Module):
7 | def __init__(self):
8 | super().__init__()
9 |
10 | def forward(self, x, *args, **kwargs):
11 | return x
12 |
13 | @property
14 | def config(self):
15 | return {"front_projector_type": 'identity'}
16 |
17 |
18 | class SimpleResBlock(nn.Module):
19 | def __init__(self, channels):
20 | super().__init__()
21 | self.pre_norm = nn.LayerNorm(channels)
22 |
23 | self.proj = nn.Sequential(
24 | nn.Linear(channels, channels),
25 | nn.GELU(),
26 | nn.Linear(channels, channels)
27 | )
28 | def forward(self, x):
29 | x = self.pre_norm(x)
30 | return x + self.proj(x)
31 |
32 | def init_Qformer(num_query_token, vision_width, cross_attention_freq=2):
33 | encoder_config = BertConfig.from_pretrained("bert-base-uncased")
34 | encoder_config.encoder_width = vision_width
35 | # insert cross-attention layer every other block
36 | encoder_config.add_cross_attention = True
37 | encoder_config.cross_attention_freq = cross_attention_freq
38 | encoder_config.query_length = num_query_token
39 | Qformer = BertLMHeadModel(config=encoder_config)
40 | query_tokens = nn.Parameter(
41 | torch.zeros(1, num_query_token, encoder_config.hidden_size)
42 | )
43 | query_tokens.data.normal_(mean=0.0, std=encoder_config.initializer_range)
44 | return Qformer, query_tokens
45 |
46 | def build_front_projector(config, delay_load=False, **kwargs):
47 | projector_type = getattr(config, 'front_projector_type', 'linear')
48 |
49 | if projector_type == 'linear':
50 | return nn.Linear(config.mm_hidden_size, config.hidden_size)
51 | elif projector_type == 'q_former':
52 | Qformer, query_tokens = init_Qformer(
53 | kwargs['num_query_token'], kwargs['visual_encoder'].num_features
54 | )
55 | Qformer.cls = None
56 | Qformer.bert.embeddings.word_embeddings = None
57 | Qformer.bert.embeddings.position_embeddings = None
58 | for layer in Qformer.bert.encoder.layer:
59 | layer.output = None
60 | layer.intermediate = None
61 | checkpoint = torch.load(config.front_projector, map_location="cpu")
62 | state_dict = checkpoint["model"]
63 | new_state_dict = {}
64 | for key in state_dict:
65 | if 'Qformer' in key:
66 | new_state_dict[key.split('Qformer.')[1]] = state_dict[key]
67 | Qformer.load_state_dict(new_state_dict)
68 | return Qformer, query_tokens
69 |
70 | else:
71 | mlp_gelu_match = re.match(r'^mlp(\d+)x_gelu$', projector_type)
72 | if mlp_gelu_match:
73 | mlp_depth = int(mlp_gelu_match.group(1))
74 | modules = [nn.Linear(config.mm_hidden_size, config.hidden_size)]
75 | for _ in range(1, mlp_depth):
76 | modules.append(nn.GELU())
77 | modules.append(nn.Linear(config.hidden_size, config.hidden_size))
78 | return nn.Sequential(*modules)
79 |
80 | if projector_type == 'identity':
81 | return IdentityMap()
82 |
83 | raise ValueError(f'Unknown projector type: {projector_type}')
84 |
--------------------------------------------------------------------------------
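For reference, a minimal sketch (not part of the repository) of how `build_front_projector` can be exercised through its `mlp2x_gelu` branch; the `SimpleNamespace` config and the hidden sizes are illustrative stand-ins for the real model config.

```python
# Illustrative only: exercise the mlp2x_gelu branch of build_front_projector.
from types import SimpleNamespace

import torch

from model.front_projector.builder import build_front_projector

# Hypothetical sizes; the real values come from the vision encoder and the LLM config.
cfg = SimpleNamespace(front_projector_type='mlp2x_gelu', mm_hidden_size=1024, hidden_size=4096)
projector = build_front_projector(cfg)                      # Linear -> GELU -> Linear

patch_features = torch.randn(1, 256, cfg.mm_hidden_size)    # fake ViT patch features
print(projector(patch_features).shape)                      # torch.Size([1, 256, 4096])
```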
/model/load_model.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch.utils.data import ConcatDataset, DataLoader
3 | from typing import Optional, Dict
4 | from dataclasses import dataclass, field
5 | from locals.datasets import SFT_DataCollator, WrappedDataset
6 | from lightning.pytorch import seed_everything
7 | from torchvision import transforms
8 | from constants import *
9 | from PIL import Image
10 | from torch.utils.data.distributed import DistributedSampler
11 | from torch.utils.data import Dataset, DataLoader, RandomSampler, SequentialSampler
12 |
13 | from locals.datasets.preprocessor import VoCoT_InputProcessor
14 | from omegaconf import OmegaConf
15 | from utils.util import instantiate_from_config
16 | from model.language_model.volcano_llama import VolCanoLlamaForCausalLM,VolCanoConfig
17 | from model.language_model.volcano_mistral import VolCanoMistralForCausalLM, VolCanoMistralConfig
18 | from transformers import LlamaTokenizer, AutoTokenizer
19 | import transformers
20 | from peft import PeftConfig, PeftModel
21 | from argparse import ArgumentParser
22 | import os
23 | import torch.distributed as dist
24 | from utils.logger import setup_logger
25 | import json
26 | import tqdm
27 |
28 | def rank0_print(args, res):
29 | if args.local_rank==0 or args.local_rank == -1:
30 | print(res)
31 |
32 | def get_output_name(args, mid_output=True):
33 | if mid_output:
34 | return os.path.join(args.output_dir,
35 | '{}_rank{}.json'.format(args.dataset_name, args.local_rank))
36 | else:
37 | return os.path.join(args.output_dir,
38 | '{}.json'.format(args.dataset_name))
39 |
40 | def get_all_output_names(args):
41 | return [os.path.join(args.output_dir,
42 | '{}_rank{}.json'.format(args.dataset_name, r)) for r in range(args.n_gpus)]
43 |
44 | class CLIPTransform:
45 | def __init__(self, transform, square_size=None):
46 | self.transform = transform
47 | self.square_size = square_size
48 | self.image_mean = transform.image_mean
49 |
50 | def __call__(self, image):
51 | if self.square_size is not None:
52 | image = image.resize((self.square_size, self.square_size))
53 | try:
54 | tmp = torch.tensor(self.transform(image)['pixel_values'][0])
55 |         except Exception:  # fall back to a black placeholder image if preprocessing fails
56 | tmp = torch.tensor(self.transform(Image.new(image.mode, (32, 32), (0,0,0)))['pixel_values'][0])
57 | return tmp
58 |
59 |
60 |
61 | def load_model(model_path, device='cuda:0', precision='bf16'):
62 | config_class = VolCanoMistralConfig
63 | model_class = VolCanoMistralForCausalLM
64 | tokenizer_class = AutoTokenizer
65 | device = torch.device(device)
66 | tokenizer = AutoTokenizer.from_pretrained(
67 | model_path,
68 | cache_dir=None,
69 | use_fast=True,
70 | trust_remote_code=True
71 | )
72 |
73 | llama_config = config_class.from_pretrained(model_path)
74 | model = model_class.from_pretrained(model_path, config=llama_config)
75 |
76 | model.input_img_id = tokenizer.convert_tokens_to_ids(DEFAULT_IMG_TOKEN)
77 | model.eoc_token_id = tokenizer.convert_tokens_to_ids(DEFAULT_EOC_TOKEN)
78 | model.boc_token_id = tokenizer.convert_tokens_to_ids(DEFAULT_BOC_TOKEN)
79 | model.tokenizer = tokenizer
80 | model.sub_image_bind = False
81 |
82 | if precision == 'bf16':
83 | model.to(torch.bfloat16)
84 | elif precision == 'fp16':
85 | model.to(torch.float16)
86 | elif precision == 'fp32':
87 | pass
88 | else:
89 | raise ValueError('precision must be fp16, bf16, or fp32')
90 | model.eval()
91 | model.to(device)
92 |
93 | resize2square = False
94 | output_vis_processor = transforms.Compose(
95 | [
96 | transforms.Resize(1024, interpolation=transforms.InterpolationMode.BILINEAR),
97 | transforms.CenterCrop(1024),
98 | # transforms.RandomHorizontalFlip(), # comment here
99 | transforms.ToTensor(),
100 | transforms.Normalize([0.5], [0.5]),
101 | ]
102 | )
103 | input_vis_processor = transforms.Compose(
104 | [
105 | transforms.Resize((448, 448) if resize2square else 448, interpolation=transforms.InterpolationMode.BILINEAR),
106 | transforms.CenterCrop(448),
107 | # transforms.RandomHorizontalFlip(), comment here
108 | transforms.ToTensor(),
109 | transforms.Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711)),
110 | ]
111 | )
112 | if hasattr(model.vision_encoder, 'image_processor'):
113 | input_vis_processor = model.vision_encoder.image_processor
114 | if resize2square:
115 | tmp_size = input_vis_processor.size['shortest_edge']
116 | else:
117 | tmp_size = None
118 | input_vis_processor = CLIPTransform(input_vis_processor, square_size=tmp_size)
119 | # tokenizer = LlamaTokenizer.from_pretrained('eval/debug/edit_gpt_emu_tokenizer')
120 |
121 | model.image_processor = None
122 | preprocessor = VoCoT_InputProcessor(tokenizer=tokenizer, input_image_processor = input_vis_processor, use_mistral=True,
123 | output_image_processor= output_vis_processor, merge_in_out_image=True, expand2square=True, inference = True)
124 |
125 | return model, preprocessor
126 |
127 | def infer(model, preprocessor, image, query, cot=True, max_new_tokens=1024, temperature=0.0):
128 | if cot:
129 | query = ALL_IMG_TOKENS_STR + DEFAULT_GRD_TOKEN + '\n' + query + COT_ACTIVATION
130 | else:
131 | query = ALL_IMG_TOKENS_STR + '\n' + query
132 | conv = [{'from': 'human', 'value':query}]
133 | item = {'input_images': [image], 'conversation': conv}
134 | input_item = preprocessor(item)
135 | data_collator = SFT_DataCollator(tokenizer=preprocessor.tokenizer, sd_tokenizer=None)
136 | batch = data_collator([input_item])
137 | txt_res, out_imgs, txt_ids = model.condition_completion(batch, avoid_image_gen=True,
138 | max_new_tokens=max_new_tokens, temperature=temperature)
139 |
140 | return txt_res
141 |
142 |
143 | if __name__=='__main__':
144 | from PIL import Image
145 | tmp_image = Image.open('eval/debug/tmp.jpg')
146 | model_path = '/mnt/bn/yangmin-priv/luoruipu/checkpoints/LLaVA-clip336px-obj-represent-Mistral-1e-5-3072-instruct_llava+shikraCoT75per+GPTQTA+lvis-cot/'
147 | model, preprocessor = load_model(model_path,precision='fp16')
148 | res1 = infer(model, preprocessor, tmp_image, 'Is there a event "the cat is below the bed" in this image?', cot=True)
149 | res = infer(model, preprocessor, tmp_image, 'Why is the cat on the bed?', cot=True)
150 |     res_no_cot = infer(model, preprocessor, tmp_image, 'Describe the image.', cot=False)
151 | print(res1)
152 | print(res)
153 | print(res_no_cot)
--------------------------------------------------------------------------------
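A minimal usage sketch of `load_model`/`infer` above; the checkpoint path is a placeholder, the question text is made up, and `figs/sample_input.jpg` is the sample image shipped with the repository.

```python
# Illustrative only: load a VoCoT checkpoint and ask a question with VoCoT-style reasoning.
from PIL import Image

from model.load_model import load_model, infer

model, preprocessor = load_model('/path/to/vocot-checkpoint', device='cuda:0', precision='bf16')
image = Image.open('figs/sample_input.jpg').convert('RGB')

answer = infer(model, preprocessor, image, 'What is the person in the image doing?', cot=True)
print(answer)
```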
/model/vision_encoder/builder.py:
--------------------------------------------------------------------------------
1 | import os
2 | from .clip_encoder import CLIPVisionTower
3 | from .eva_vit import create_eva_vit_g
4 | from .eva_vit_emu import create_eva_vit_emu
5 |
6 | def build_vision_encoder(vision_encoder_cfg, **kwargs):
7 | vision_tower = getattr(vision_encoder_cfg, 'vision_encoder', None)
8 | is_absolute_path_exists = os.path.exists(vision_tower)
9 |
10 | if vision_tower.startswith("openai") or (is_absolute_path_exists and (vision_tower.startswith("OFA-Sys") or vision_tower.startswith("laion"))):
11 | return CLIPVisionTower(vision_tower, args=vision_encoder_cfg, **kwargs)
12 | elif is_absolute_path_exists and 'eva' in vision_tower:
13 | return create_eva_vit_g(vision_encoder_cfg.vision_encoder, **kwargs)
14 | elif vision_tower == 'eva_vit_emu':
15 | # using the emu-pre-trained vit
16 | vision_encoder_path = getattr(vision_encoder_cfg, 'vision_encoder_path', None)
17 | load_encoder_ckpt = not getattr(vision_encoder_cfg, 'skip_load_vision_encoder', False)
18 | assert vision_encoder_path is not None, 'please specify the model path for emu-pre-trained vision encoder'
19 | return create_eva_vit_emu(vision_encoder_path, load_ckpt=load_encoder_ckpt)
20 | raise ValueError(f'Unknown vision tower: {vision_tower}')
21 |
--------------------------------------------------------------------------------
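A minimal sketch of how `build_vision_encoder` routes on `vision_encoder_cfg.vision_encoder`; the CLIP checkpoint name and the select-layer settings below are illustrative, and the call downloads the CLIP weights from the Hugging Face hub.

```python
# Illustrative only: a config whose vision_encoder starts with "openai" routes to CLIPVisionTower.
from types import SimpleNamespace

from model.vision_encoder.builder import build_vision_encoder

cfg = SimpleNamespace(
    vision_encoder='openai/clip-vit-large-patch14-336',  # assumed checkpoint name
    mm_vision_select_layer=-2,
    mm_vision_select_feature='patch',
)
encoder = build_vision_encoder(cfg)               # downloads and freezes the CLIP vision tower
print(encoder.num_patches, encoder.hidden_size)   # 576 1024 for ViT-L/14 @ 336px
```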
/model/vision_encoder/clip_encoder.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 |
4 |
5 |
6 | class CLIPVisionTower(nn.Module):
7 | def __init__(self, vision_tower, args, delay_load=False):
8 | super().__init__()
9 |
10 | self.is_loaded = False
11 |
12 | self.vision_tower_name = vision_tower
13 | self.select_layer = args.mm_vision_select_layer
14 | self.select_feature = getattr(args, 'mm_vision_select_feature', 'patch')
15 | self.language = getattr(args, 'language', 'english')
16 | if self.language == 'chinese':
17 |             from transformers import ChineseCLIPVisionConfig as CLIPVisionConfig
18 | else:
19 | from transformers import CLIPVisionConfig
20 | if not delay_load:
21 | self.load_model()
22 | else:
23 | self.cfg_only = CLIPVisionConfig.from_pretrained(self.vision_tower_name)
24 |         self.num_features = self.embed_dim = self.config.hidden_size  # falls back to cfg_only when delay_load is set
25 |
26 | def load_model(self):
27 | if self.language == 'chinese':
28 | from transformers import ChineseCLIPVisionModel as CLIPVisionModel
29 | from transformers import ChineseCLIPImageProcessor as CLIPImageProcessor
30 | else:
31 | from transformers import CLIPVisionModel, CLIPImageProcessor
32 | self.image_processor = CLIPImageProcessor.from_pretrained(self.vision_tower_name)
33 | self.vision_tower = CLIPVisionModel.from_pretrained(self.vision_tower_name)
34 | self.vision_tower.requires_grad_(False)
35 | self.image_size = self.vision_tower.config.image_size
36 |
37 | self.is_loaded = True
38 |
39 | def feature_select(self, image_forward_outs):
40 | image_features = image_forward_outs.hidden_states[self.select_layer]
41 | if self.select_feature == 'patch':
42 | image_features = image_features[:, 1:]
43 | elif self.select_feature == 'cls_patch':
44 | image_features = image_features
45 | else:
46 | raise ValueError(f'Unexpected select feature: {self.select_feature}')
47 | return image_features
48 |
49 | @torch.no_grad()
50 | def forward(self, images):
51 | if type(images) is list:
52 | image_features = []
53 | for image in images:
54 | image_forward_out = self.vision_tower(image.to(device=self.device, dtype=self.dtype).unsqueeze(0), output_hidden_states=True)
55 | image_feature = self.feature_select(image_forward_out).to(image.dtype)
56 | image_features.append(image_feature)
57 | else:
58 | image_forward_outs = self.vision_tower(images.to(device=self.device, dtype=self.dtype), output_hidden_states=True)
59 | image_features = self.feature_select(image_forward_outs).to(images.dtype)
60 |
61 | return image_features
62 |
63 | @property
64 | def dummy_feature(self):
65 | return torch.zeros(1, self.hidden_size, device=self.device, dtype=self.dtype)
66 |
67 | @property
68 | def dtype(self):
69 | return self.vision_tower.dtype
70 |
71 | @property
72 | def device(self):
73 | return self.vision_tower.device
74 |
75 | @property
76 | def config(self):
77 | if self.is_loaded:
78 | return self.vision_tower.config
79 | else:
80 | return self.cfg_only
81 |
82 | @property
83 | def hidden_size(self):
84 | return self.config.hidden_size
85 |
86 | @property
87 | def num_patches(self):
88 | return (self.config.image_size // self.config.patch_size) ** 2
89 |
--------------------------------------------------------------------------------
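A small sketch of the feature selection above: with `mm_vision_select_feature='patch'` the CLS token is dropped, so a ViT-L/14 @ 336px tower returns 576 patch embeddings per image. The random batch below stands in for properly pre-processed pixels.

```python
# Illustrative only: 'patch' selection drops the CLS token from the chosen hidden layer.
from types import SimpleNamespace

import torch

from model.vision_encoder.clip_encoder import CLIPVisionTower

args = SimpleNamespace(mm_vision_select_layer=-2, mm_vision_select_feature='patch')
tower = CLIPVisionTower('openai/clip-vit-large-patch14-336', args=args)

pixels = torch.randn(2, 3, 336, 336)   # stand-in for CLIPImageProcessor output
features = tower(pixels)
print(features.shape)                  # expected: torch.Size([2, 576, 1024])
```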
/requirements.txt:
--------------------------------------------------------------------------------
1 | aiohttp==3.8.4
2 | aiosignal==1.3.1
3 | async-timeout==4.0.2
4 | attrs==22.2.0
5 | chardet==5.1.0
6 | contourpy==1.0.7
7 | cycler==0.11.0
8 | filelock==3.9.0
9 | fonttools==4.38.0
10 | frozenlist==1.3.3
11 | huggingface-hub
12 | importlib-resources==5.12.0
13 | kiwisolver==1.4.4
14 | matplotlib==3.7.0
15 | multidict==6.0.4
16 | packaging==23.0
17 | psutil==5.9.4
18 | pycocotools==2.0.6
19 | pyparsing==3.0.9
20 | python-dateutil==2.8.2
21 | pyyaml==6.0
22 | regex==2022.10.31
23 | tqdm==4.64.1
24 | timm==0.9.2
25 | spacy==3.5.1
26 | webdataset==0.2.48
27 | scikit-learn
28 | scipy
29 | yarl==1.8.2
30 | zipp==3.14.0
31 | omegaconf==2.3.0
32 | iopath==0.1.10
33 | decord==0.6.0
34 | tenacity==8.2.2
35 | pycocoevalcap
36 | sentence-transformers
37 | umap-learn
38 | notebook
39 | gradio==3.24.1
40 | gradio-client==0.0.8
41 | torch>=2.0
42 | torchvision>=0.15.0
43 | transformers==4.37.2
44 | pyiqa
45 | xformers
46 | accelerate>=0.20.3
47 | peft==0.6.2
48 | tokenizers
49 | lightning>=2.0.2
50 | open_clip_torch
51 | rouge
52 | pytorch-fid
53 | torch-fidelity
54 | torchmetrics
55 | diffusers==0.14.0
56 | opencv-python-headless
57 | prettytable
58 | deepspeed>=0.9.3
59 | datasets
60 | # flash-attn --no-build-isolation
--------------------------------------------------------------------------------
/utils/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/RupertLuo/VoCoT/d8177ca2e8e303f7d35609a000e5eff7bec319cb/utils/__init__.py
--------------------------------------------------------------------------------
/utils/__init__.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/RupertLuo/VoCoT/d8177ca2e8e303f7d35609a000e5eff7bec319cb/utils/__init__.pyc
--------------------------------------------------------------------------------
/utils/count_line.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import json
3 |
4 | parser = argparse.ArgumentParser()
5 | parser.add_argument('--data', type=str, default=None)
6 |
7 | args = parser.parse_args()
8 | res = json.load(open(args.data))
9 | print(len(res))
--------------------------------------------------------------------------------
/utils/format_utils.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import json, os
3 | from PIL import Image
4 | from utils.eval_util import cal_nn_iou, draw_all_box, colors, draw_all_box_colored
5 |
6 | def box2str(box, mode='special_tokens', prec=2, space=False):
7 | if mode == 'text':
8 | # using text to represent the box
9 | if space:
10 | sep = ', '
11 | else:
12 | sep = ','
13 | tmp_format = sep.join(['{' + ':.{}f'.format(prec)+'}']*4)
14 | a_box = [float(o) for o in box]
15 | return tmp_format.format(*a_box)
16 | else:
17 | raise NotImplementedError
18 |
19 | def allbox2str(objects, colored=False):
20 | s = []
21 | for i,obj in enumerate(objects):
22 | if colored:
23 | s.append('{}: {} box [{}]'.format(obj['class'], colors[i], box2str(obj['bbox'], 'text', 3, True)))
24 | else:
25 | s.append('{}: [{}]'.format(obj['class'], box2str(obj['bbox'], 'text', 3, True)))
26 | return ', '.join(s)
27 |
28 | def filter_box(box_list, iou = 0.5):
29 | '''
30 | box_list = []
31 | '''
32 | box_tensor = []
33 | for i in range(len(box_list)):
34 | box_tensor.append(list(map(float,box_list[i]['bbox'])))
35 | iou_matrix = cal_nn_iou(box_tensor)
36 | result_box = []
37 | for i in range(len(iou_matrix)):
38 | flag = 0
39 | for j in result_box:
40 | if iou_matrix[i,j]> iou:
41 | flag = 1
42 | break
43 | if flag==0:
44 | result_box.append(i)
45 | result_box = [box_list[i] for i in result_box]
46 | return result_box
47 |
48 | import random
49 | from collections import defaultdict
50 | def merge_object(info):
51 | class2box = defaultdict(list)
52 | for obj in info:
53 | class2box[obj['class']].append([float(t) for t in obj['bbox']])
54 | return class2box
55 |
56 | class OpenImagesSource:
57 | def __init__(self, path='/mnt/bn/yangmin-priv/luoruipu/data/multimodal-datasets/OpenImages/balance_sample_merge_filter_box.json', image_base='/mnt/bn/yangmin-priv/luoruipu/data/multimodal-datasets/OpenImages/', draw_box=False):
58 | self.meta = json.load(open(path))
59 | self.image_base = image_base
60 | self.draw_box = draw_box
61 |
62 | def __len__(self,):
63 | return len(self.meta)
64 |
65 | def __getitem__(self, index):
66 | item = self.meta[index]
67 | split_name = 'train_{}'.format(item['image'][0])
68 | image = os.path.join(self.image_base, split_name, '{}.jpg'.format(item['image']))
69 | if self.draw_box:
70 | image = draw_all_box_colored(Image.open(image), item['object'])
71 | return {'id': item['image'], 'image': image, 'box_info':allbox2str(item['object'], colored=self.draw_box)}
72 |
73 |
74 | class LvisSource:
75 | def __init__(self, path='/mnt/bn/yangmin-priv/luoruipu/data/multimodal-datasets/LVIS/lvis_v1_train_for_generation.json', image_base='/mnt/bn/yangmin-priv/luoruipu/data/multimodal-datasets/LVIS/train2017', draw_box=False):
76 | self.meta = json.load(open(path))
77 | self.image_base = image_base
78 | self.draw_box = draw_box
79 |
80 | def __len__(self,):
81 | return len(self.meta)
82 |
83 | def __getitem__(self, index):
84 | item = self.meta[index]
85 | # split_name = 'train_{}'.format(item['image'][0])
86 | # image = os.path.join(self.image_base, split_name, '{}.jpg'.format(item['image']))
87 | image = os.path.join(self.image_base, '{:012d}.jpg'.format(int(item['image'])))
88 | if self.draw_box:
89 | image = draw_all_box_colored(Image.open(image), item['object'])
90 | return {'id': item['image'], 'image': image, 'box_info':allbox2str(item['object'], colored=self.draw_box)}
91 |
92 |
93 | if __name__ == "__main__":
94 | d = OpenImagesSource()
95 | print(d[3])
96 |
97 |
--------------------------------------------------------------------------------
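A small sketch of the box formatting and filtering helpers above, assuming (as the usage in `filter_box` suggests) that `cal_nn_iou` returns a pairwise IoU matrix; the boxes are made-up normalized coordinates.

```python
# Illustrative only: textual box serialization and IoU-based de-duplication.
from utils.format_utils import allbox2str, box2str, filter_box

print(box2str([0.1, 0.2, 0.55, 0.8], mode='text', prec=3, space=True))
# -> 0.100, 0.200, 0.550, 0.800

objects = [
    {'class': 'cat', 'bbox': [0.10, 0.10, 0.50, 0.50]},
    {'class': 'cat', 'bbox': [0.12, 0.11, 0.52, 0.49]},   # near-duplicate of the first box
    {'class': 'dog', 'bbox': [0.60, 0.60, 0.95, 0.95]},
]
print(allbox2str(objects))
print(len(filter_box(objects, iou=0.5)))   # the near-duplicate should be dropped -> 2
```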
/utils/llava_flash_attn.py:
--------------------------------------------------------------------------------
1 | from typing import Optional, Tuple
2 | import warnings
3 |
4 | import torch
5 |
6 | import transformers
7 | from transformers.models.llama.modeling_llama import apply_rotary_pos_emb, repeat_kv
8 |
9 | try:
10 | from flash_attn.flash_attn_interface import flash_attn_unpadded_qkvpacked_func
11 | except ImportError:
12 | from flash_attn.flash_attn_interface import flash_attn_varlen_qkvpacked_func as flash_attn_unpadded_qkvpacked_func
13 | from flash_attn.bert_padding import unpad_input, pad_input
14 |
15 |
16 | def forward(
17 | self,
18 | hidden_states: torch.Tensor,
19 | attention_mask: Optional[torch.Tensor] = None,
20 | position_ids: Optional[torch.Tensor] = None,
21 | past_key_value: Optional[Tuple[torch.Tensor]] = None,
22 | output_attentions: bool = False,
23 | use_cache: bool = False,
24 | ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
25 | if output_attentions:
26 | warnings.warn(
27 | "Output attentions is not supported for patched `LlamaAttention`, returning `None` instead."
28 | )
29 |
30 | bsz, q_len, _ = hidden_states.size()
31 |
32 | query_states = (
33 | self.q_proj(hidden_states)
34 | .view(bsz, q_len, self.num_heads, self.head_dim)
35 | .transpose(1, 2)
36 | )
37 | key_states = (
38 | self.k_proj(hidden_states)
39 | .view(bsz, q_len, self.num_key_value_heads, self.head_dim)
40 | .transpose(1, 2)
41 | )
42 | value_states = (
43 | self.v_proj(hidden_states)
44 | .view(bsz, q_len, self.num_key_value_heads, self.head_dim)
45 | .transpose(1, 2)
46 | ) # shape: (b, num_heads, s, head_dim)
47 |
48 | kv_seq_len = key_states.shape[-2]
49 | if past_key_value is not None:
50 | kv_seq_len += past_key_value[0].shape[-2]
51 |
52 | cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
53 | query_states, key_states = apply_rotary_pos_emb(
54 | query_states, key_states, cos, sin, position_ids
55 | )
56 |
57 | if past_key_value is not None:
58 | # reuse k, v
59 | key_states = torch.cat([past_key_value[0], key_states], dim=2)
60 | value_states = torch.cat([past_key_value[1], value_states], dim=2)
61 |
62 | past_key_value = (key_states, value_states) if use_cache else None
63 |
64 | # repeat k/v heads if n_kv_heads < n_heads
65 | key_states = repeat_kv(key_states, self.num_key_value_groups)
66 | value_states = repeat_kv(value_states, self.num_key_value_groups)
67 |
68 | # Transform the data into the format required by flash attention
69 | qkv = torch.stack([query_states, key_states, value_states], dim=2)
70 | qkv = qkv.transpose(1, 3) # shape: [b, s, 3, num_heads, head_dim]
71 | key_padding_mask = attention_mask
72 |
73 | if key_padding_mask is None:
74 | qkv = qkv.reshape(-1, 3, self.num_heads, self.head_dim)
75 | cu_q_lens = torch.arange(
76 | 0, (bsz + 1) * q_len, step=q_len, dtype=torch.int32, device=qkv.device
77 | )
78 | max_s = q_len
79 | output = flash_attn_unpadded_qkvpacked_func(
80 | qkv, cu_q_lens, max_s, 0.0, softmax_scale=None, causal=True
81 | )
82 | output = output.view(bsz, q_len, -1)
83 | else:
84 | qkv = qkv.reshape(bsz, q_len, -1)
85 | qkv, indices, cu_q_lens, max_s = unpad_input(qkv, key_padding_mask)
86 | qkv = qkv.view(-1, 3, self.num_heads, self.head_dim)
87 | output_unpad = flash_attn_unpadded_qkvpacked_func(
88 | qkv, cu_q_lens, max_s, 0.0, softmax_scale=None, causal=True
89 | )
90 | output_unpad = output_unpad.reshape(-1, self.num_heads * self.head_dim)
91 | output = pad_input(output_unpad, indices, bsz, q_len)
92 |
93 | return self.o_proj(output), None, past_key_value
94 |
95 |
96 | # Disable the transformation of the attention mask in LlamaModel as the flash attention
97 | # requires the attention mask to be the same as the key_padding_mask
98 | def _prepare_decoder_attention_mask(
99 | self, attention_mask, input_shape, inputs_embeds, past_key_values_length
100 | ):
101 | # [bsz, seq_len]
102 | return attention_mask
103 |
104 |
105 | def replace_llama_attn_with_flash_attn():
106 | cuda_major, cuda_minor = torch.cuda.get_device_capability()
107 | if cuda_major < 8:
108 | warnings.warn(
109 | "Flash attention is only supported on A100 or H100 GPU during training due to head dim > 64 backward."
110 | "ref: https://github.com/HazyResearch/flash-attention/issues/190#issuecomment-1523359593"
111 | )
112 | transformers.models.llama.modeling_llama.LlamaModel._prepare_decoder_attention_mask = (
113 | _prepare_decoder_attention_mask
114 | )
115 | transformers.models.llama.modeling_llama.LlamaAttention.forward = forward
--------------------------------------------------------------------------------
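A minimal sketch of the intended invocation order for the monkey patch above: it must run before the LLaMA model is instantiated, it needs a CUDA device to query the compute capability, and the checkpoint path is a placeholder.

```python
# Illustrative only: patch LLaMA attention with flash attention before building the model.
import torch
from transformers import LlamaForCausalLM

from utils.llava_flash_attn import replace_llama_attn_with_flash_attn

replace_llama_attn_with_flash_attn()   # swaps LlamaAttention.forward in place

# '/path/to/llama-checkpoint' is a placeholder; fp16 on an Ampere-or-newer GPU is assumed.
model = LlamaForCausalLM.from_pretrained('/path/to/llama-checkpoint',
                                         torch_dtype=torch.float16).cuda()
```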
/utils/logger.py:
--------------------------------------------------------------------------------
1 | import logging
2 | from logging import StreamHandler, Handler, getLevelName
3 | import os
4 | import sys
5 |
6 |
7 | # this class is a copy of logging.FileHandler except that we call self.close()
8 | # at the end of each emit. While closing and reopening the file after each
9 | # write is not efficient, it allows us to see partial logs when writing to
10 | # fused Azure blobs, which is very convenient
11 | class FileHandler(StreamHandler):
12 | """
13 | A handler class which writes formatted logging records to disk files.
14 | """
15 | def __init__(self, filename, mode='a', encoding=None, delay=False):
16 | """
17 | Open the specified file and use it as the stream for logging.
18 | """
19 | # Issue #27493: add support for Path objects to be passed in
20 | filename = os.fspath(filename)
21 | #keep the absolute path, otherwise derived classes which use this
22 | #may come a cropper when the current directory changes
23 | self.baseFilename = os.path.abspath(filename)
24 | self.mode = mode
25 | self.encoding = encoding
26 | self.delay = delay
27 | if delay:
28 | #We don't open the stream, but we still need to call the
29 | #Handler constructor to set level, formatter, lock etc.
30 | Handler.__init__(self)
31 | self.stream = None
32 | else:
33 | StreamHandler.__init__(self, self._open())
34 |
35 | def close(self):
36 | """
37 | Closes the stream.
38 | """
39 | self.acquire()
40 | try:
41 | try:
42 | if self.stream:
43 | try:
44 | self.flush()
45 | finally:
46 | stream = self.stream
47 | self.stream = None
48 | if hasattr(stream, "close"):
49 | stream.close()
50 | finally:
51 | # Issue #19523: call unconditionally to
52 | # prevent a handler leak when delay is set
53 | StreamHandler.close(self)
54 | finally:
55 | self.release()
56 |
57 | def _open(self):
58 | """
59 | Open the current base file with the (original) mode and encoding.
60 | Return the resulting stream.
61 | """
62 | return open(self.baseFilename, self.mode, encoding=self.encoding)
63 |
64 | def emit(self, record):
65 | """
66 | Emit a record.
67 |
68 | If the stream was not opened because 'delay' was specified in the
69 | constructor, open it before calling the superclass's emit.
70 | """
71 | if self.stream is None:
72 | self.stream = self._open()
73 | StreamHandler.emit(self, record)
74 | self.close()
75 |
76 | def __repr__(self):
77 | level = getLevelName(self.level)
78 | return '<%s %s (%s)>' % (self.__class__.__name__, self.baseFilename, level)
79 |
80 |
81 | def setup_logger(name, save_dir, distributed_rank, filename="log.txt"):
82 | logger = logging.getLogger(name)
83 | logger.setLevel(logging.DEBUG)
84 | # don't log results for the non-master process
85 | if distributed_rank > 0:
86 | return logger
87 | ch = logging.StreamHandler(stream=sys.stdout)
88 | ch.setLevel(logging.DEBUG)
89 | # logging.disable(logging.WARNING)
90 | formatter = logging.Formatter("%(asctime)s %(name)s %(levelname)s: %(message)s")
91 | ch.setFormatter(formatter)
92 | logger.addHandler(ch)
93 |
94 | if save_dir:
95 | fh = FileHandler(os.path.join(save_dir, filename))
96 | fh.setLevel(logging.DEBUG)
97 | fh.setFormatter(formatter)
98 | logger.addHandler(fh)
99 |
100 | return logger
--------------------------------------------------------------------------------
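A minimal sketch of `setup_logger`: handlers are only attached on rank 0, and when `save_dir` is given a `log.txt` is reopened on every emit so partial logs stay visible; the logger name and save directory below are illustrative, and the directory is assumed to exist.

```python
# Illustrative only: rank-aware logger with console + file output on rank 0.
import os

from utils.logger import setup_logger

rank = int(os.environ.get('RANK', 0))          # 0 outside of a distributed launch
logger = setup_logger('vocot', save_dir='.', distributed_rank=rank)
logger.info('training started')                # also appended to ./log.txt on rank 0
```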
/utils/logger.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/RupertLuo/VoCoT/d8177ca2e8e303f7d35609a000e5eff7bec319cb/utils/logger.pyc
--------------------------------------------------------------------------------
/utils/time_check.py:
--------------------------------------------------------------------------------
1 | import time
2 | import argparse
3 |
4 | parser = argparse.ArgumentParser()
5 | parser.add_argument('--tgt_time', type=str, default=None)
6 | args = parser.parse_args()
7 | tgt_time = time.strptime(args.tgt_time, '%Y-%m-%d %X')
8 | while True:
9 | current_time = time.localtime()
10 | if current_time >= tgt_time:
11 |         print("it's high noon!")
12 | break
13 |     print('not the right time, currently {}'.format(time.strftime('%Y-%m-%d %X', current_time)))
14 | time.sleep(30)
--------------------------------------------------------------------------------
/utils/util.py:
--------------------------------------------------------------------------------
1 |
2 | import os
3 | import numpy as np
4 | import matplotlib.pyplot as plt
5 | import re
6 | import textwrap
7 | import importlib
8 | from prettytable import PrettyTable
9 | import torch.distributed as dist
10 | import transformers
11 | import torch
12 | from safetensors import safe_open
13 | from PIL import Image
14 | from base64 import b64encode, b64decode
15 | import io
16 |
17 | def disable_torch_init():
18 | """
19 | Disable the redundant torch default initialization to accelerate model creation.
20 | """
21 | import torch
22 | setattr(torch.nn.Linear, "reset_parameters", lambda self: None)
23 | setattr(torch.nn.LayerNorm, "reset_parameters", lambda self: None)
24 |
25 | def load_safetensor(path):
26 | tmp_dict = {}
27 | with safe_open(path, framework='pt', device=0) as f:
28 | for k in f.keys():
29 | tmp_dict[k] = f.get_tensor(k)
30 | return tmp_dict
31 |
32 |
33 | def sanitize_filename(filename):
34 | return re.sub('[^0-9a-zA-Z]+', '_', filename)
35 |
36 | def plot_images_and_text(predicted_image1, predicted_image2, groundtruth_image, generated_text, gt_text, save_dir, task_name, input_texts, input_images):
37 | task_path = os.path.join(save_dir, task_name)
38 | if not os.path.exists(task_path):
39 | os.makedirs(task_path)
40 | max_width = 50 # adjust this value based on your needs
41 |
42 | fig, ax = plt.subplots()
43 | ax.imshow(predicted_image1)
44 | generated_text = generated_text.replace("###", "").replace("[IMG0]", "")
45 | wrapped_generated_text = textwrap.fill(generated_text, max_width)
46 | ax.set_title(wrapped_generated_text, pad=20)
47 | ax.axis('off')
48 | plt.savefig(os.path.join(task_path, f"generated.jpg"), bbox_inches='tight')
49 | plt.close(fig)
50 |
51 | gt_text = gt_text.replace("$", "\$")
52 | wrapped_gt = textwrap.fill(gt_text, max_width)
53 | if predicted_image2 is not None:
54 | fig, ax = plt.subplots()
55 | ax.imshow(predicted_image2)
56 | ax.set_title(wrapped_gt, pad=20)
57 | ax.axis('off')
58 | plt.savefig(os.path.join(task_path, f"sd_baseline.jpg"), bbox_inches='tight')
59 | plt.close(fig)
60 |
61 | if groundtruth_image is not None:
62 | fig, ax = plt.subplots()
63 | groundtruth_image = groundtruth_image.float().cpu().numpy().squeeze()
64 | groundtruth_image = np.transpose(groundtruth_image, (1, 2, 0))
65 | groundtruth_image = np.uint8(groundtruth_image*255)
66 | ax.imshow(groundtruth_image)
67 | ax.set_title(wrapped_gt, pad=20)
68 | ax.axis('off')
69 | plt.savefig(os.path.join(task_path, f"gt.jpg"), bbox_inches='tight')
70 | plt.close(fig)
71 |
72 | if len(input_texts):
73 | max_width = 30
74 | length = len(input_texts)
75 | if length > 1:
76 | fig, ax = plt.subplots(1, length, figsize=(10*length, 10))
77 | for i in range(length):
78 | if i < len(input_images):
79 | ax[i].imshow(input_images[i])
80 | ax[i].set_title(textwrap.fill(input_texts[i], max_width), fontsize=28)
81 | ax[i].axis('off')
82 | else:
83 | ax[i].text(0.5, 0.5, textwrap.fill(input_texts[i], max_width), horizontalalignment='center', verticalalignment='center', fontsize=28)
84 | ax[i].axis('off')
85 | else:
86 | fig, ax = plt.subplots()
87 | ax.imshow(input_images[0])
88 | ax.set_title(textwrap.fill(input_texts[0], max_width), fontsize=28)
89 | ax.axis('off')
90 | plt.savefig(os.path.join(task_path, f"input.jpg"), bbox_inches='tight')
91 | plt.close(fig)
92 |
93 | return None
94 |
95 | def instantiate_from_config(config, inference = False, reload=False):
96 | if not "target" in config:
97 | if config == '__is_first_stage__':
98 | return None
99 | elif config == "__is_unconditional__":
100 | return None
101 | raise KeyError("Expected key `target` to instantiate.")
102 | return get_obj_from_str(config["target"], reload=reload)(**config.get("params", dict()))
103 |
104 |
105 | def get_obj_from_str(string, reload=False):
106 | module, cls = string.rsplit(".", 1)
107 | if reload:
108 | module_imp = importlib.import_module(module)
109 | importlib.reload(module_imp)
110 | return getattr(importlib.import_module(module, package=None), cls)
111 |
112 | def print_trainable_params(model):
113 | if dist.get_rank() == 0:
114 | trainable_params = [k for k,v in model.named_parameters() if v.requires_grad]
115 | trainable_params_group = {}
116 | for para in trainable_params:
117 | layer_num = re.findall(r'layers.(\d+)\.',para)
118 | if layer_num:
119 | cur_layer = int(layer_num[0])
120 | if para.replace('layers.'+layer_num[0],'layers.*') not in trainable_params_group:
121 | trainable_params_group[para.replace('layers.'+layer_num[0],'layers.*')] = layer_num[0]
122 | elif cur_layer > int(trainable_params_group[para.replace('layers.'+layer_num[0],'layers.*')]):
123 | trainable_params_group[para.replace('layers.'+layer_num[0],'layers.*')] = layer_num[0]
124 |
125 | else:
126 | trainable_params_group[para] = '0'
127 | table = PrettyTable(['Parameter Name','Max Layer Number'])
128 | for key in trainable_params_group.keys():
129 | table.add_row([key, str(int(trainable_params_group[key])+1)])
130 |
131 | print(table)
132 | total_num = sum([v.numel() for k,v in model.named_parameters()])
133 | trainable_num = sum([v.numel() for k,v in model.named_parameters() if v.requires_grad])
134 | print('Total: {:.2f}M'.format(total_num/1e6), ' Trainable: {:.2f}M'.format(trainable_num/1e6))
135 |
136 | def rank_0_print(output):
137 | if dist.get_rank() == 0:
138 | print(output)
139 |
140 | def safe_save_model_for_hf_trainer(trainer,
141 | output_dir):
142 | """Collects the state dict and dump to disk."""
143 |
144 | if trainer.args.lora:
145 | if trainer.args.should_save:
146 | trainer.model.save_pretrained(output_dir)
147 |
148 | else:
149 | if trainer.deepspeed:
150 | print('saving deepspeed model...')
151 | torch.cuda.synchronize()
152 | trainer.save_model(output_dir)
153 | return
154 |
155 | state_dict = trainer.model.state_dict()
156 | if trainer.args.should_save:
157 | cpu_state_dict = {
158 | key: value.cpu()
159 | for key, value in state_dict.items()
160 | }
161 | del state_dict
162 | trainer._save(output_dir, state_dict=cpu_state_dict) # noqa
163 |
164 | def byte2image(byte_data):
165 | """
166 | convert byte to PIL image
167 | """
168 | if isinstance(byte_data, str):
169 | byte_data = b64decode(byte_data)
170 | image = Image.open(io.BytesIO(byte_data))
171 | return image
172 |
--------------------------------------------------------------------------------
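A minimal sketch of `instantiate_from_config`, the `{"target", "params"}` pattern that the dataset and experiment YAML configs rely on; `torch.nn.Linear` is just a stand-in for any importable `module.Class` path such as a dataset class under `locals/datasets`.

```python
# Illustrative only: build an object from a {"target", "params"} config dict.
from utils.util import instantiate_from_config

cfg = {
    'target': 'torch.nn.Linear',               # stand-in for a real dataset/model class path
    'params': {'in_features': 8, 'out_features': 2},
}
layer = instantiate_from_config(cfg)
print(layer)                                   # Linear(in_features=8, out_features=2, bias=True)
```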
/vocot_trainer.py:
--------------------------------------------------------------------------------
1 | import copy
2 | import json
3 | import os
4 | import math
5 | import numpy as np
6 | from transformers import (
7 | TrainerCallback,
8 | TrainingArguments,
9 | )
10 | from torch import nn
11 | import datasets
12 | from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple, Union
13 | from torch.utils.data import DataLoader, Dataset, SequentialSampler, DistributedSampler
14 | from transformers.modeling_utils import PreTrainedModel, load_sharded_checkpoint, unwrap_model
15 | from transformers.models.auto.modeling_auto import MODEL_FOR_CAUSAL_LM_MAPPING_NAMES, MODEL_MAPPING_NAMES
16 | from transformers import Trainer, is_datasets_available, PreTrainedModel
17 | from transformers.deepspeed import is_deepspeed_zero3_enabled
18 | from transformers.trainer_callback import TrainerControl, TrainerState
19 | from transformers.training_args import TrainingArguments
20 | from transformers.tokenization_utils_base import PreTrainedTokenizerBase
21 | import torch.distributed as dist
22 | from transformers.trainer_utils import EvalPrediction,seed_worker
23 | import torch
24 | import re
25 | from transformers.data.data_collator import DataCollator
26 |
27 |
28 | class VoCoTTrainer(Trainer):
29 | def __init__(self,
30 | model: Union[PreTrainedModel, nn.Module] = None,
31 | args: TrainingArguments = None,
32 | data_collator: Optional[DataCollator] = None,
33 | train_dataset: Optional[Dataset] = None,
34 | eval_dataset: Optional[Union[Dataset, Dict[str, Dataset]]] = None,
35 | tokenizer: Optional[PreTrainedTokenizerBase] = None,
36 | model_init: Optional[Callable[[], PreTrainedModel]] = None,
37 | compute_metrics: Optional[Callable[[EvalPrediction], Dict]] = None,
38 | callbacks: Optional[List[TrainerCallback]] = None,
39 | optimizers: Tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR] = (None, None),
40 | preprocess_logits_for_metrics: Optional[Callable[[torch.Tensor, torch.Tensor], torch.Tensor]] = None,
41 | regression_text_loss_metrics: Optional[Tuple[float]]=(0.0, 0.0)):
42 | self.regression_text_loss_metrics = regression_text_loss_metrics
43 |
44 | super().__init__(model, args, data_collator, train_dataset, eval_dataset, tokenizer, model_init, compute_metrics, callbacks, optimizers, preprocess_logits_for_metrics,)
45 |
46 | def log(self, logs: Dict[str, float], eval = False) -> None:
47 | """
48 | Log `logs` on the various objects watching training.
49 |
50 | Subclass and override this method to inject custom behavior.
51 |
52 | Args:
53 | logs (`Dict[str, float]`):
54 | The values to log.
55 | """
56 | if not eval:
57 | if self.state.epoch is not None:
58 | logs["epoch"] = round(self.state.epoch, 2)
59 |
60 | # logs['lora_lr'] = self.optimizer.param_groups[0]['lr']
61 | # logs['other_lr'] = self.optimizer.param_groups[1]['lr']
62 | txt_loss, reg_loss = self.regression_text_loss_metrics
63 | if txt_loss is not None:
64 | logs['text_loss'] = txt_loss
65 | if reg_loss is not None:
66 | logs['regression_loss'] = reg_loss
67 | output = {**logs, **{"step": self.state.global_step}}
68 | self.state.log_history.append(output)
69 | self.control = self.callback_handler.on_log(self.args, self.state, self.control, logs)
70 | else:
71 | if self.state.epoch is not None:
72 | logs["epoch"] = round(self.state.epoch, 2)
73 |             self.state.log_history.append({**logs, **{"step": self.state.global_step}})
74 | self.control = self.callback_handler.on_log(self.args, self.state, self.control, logs)
75 |
76 | # def create_optimizer(self,):
77 | # if self.args.lora and self.args.tune_mm_mlp_adapter:
78 | # from transformers.pytorch_utils import ALL_LAYERNORM_LAYERS
79 | # from transformers.trainer_pt_utils import get_parameter_names
80 | # decay_parameters = get_parameter_names(self.model, ALL_LAYERNORM_LAYERS)
81 | # decay_parameters = [name for name in decay_parameters if "bias" not in name]
82 | # optimizer_grouped_parameters = [
83 | # {
84 | # "params": [
85 | # p for n, p in self.model.named_parameters() if ('lora' in n and p.requires_grad)
86 | # ],
87 | # "weight_decay": self.args.weight_decay,
88 | # 'lr': float(self.args.lora_lr) if self.args.lora_lr else self.args.learning_rate
89 | # },
90 | # {
91 | # "params": [
92 | # p for n, p in self.model.named_parameters() if (n in decay_parameters and 'lora' not in n and p.requires_grad)
93 | # ],
94 | # "weight_decay": self.args.weight_decay,
95 | # },
96 | # {
97 | # "params": [
98 | # p for n, p in self.model.named_parameters() if (n not in decay_parameters and p.requires_grad)
99 | # ],
100 | # "weight_decay": 0.0,
101 | # },
102 | # ]
103 | # optimizer_cls, optimizer_kwargs = ValleyTrainer.get_optimizer_cls_and_kwargs(self.args)
104 | # self.optimizer = optimizer_cls(optimizer_grouped_parameters, **optimizer_kwargs)
105 | # else:
106 | # self.optimizer = super().create_optimizer()
107 | # return self.optimizer
108 | def _prepare_inputs(self, inputs: Dict[str, Union[torch.Tensor, Any]]) -> Dict[str, Union[torch.Tensor, Any]]:
109 | if 'gt_label' in inputs:
110 | gt_label = inputs.pop('gt_label')
111 | return super()._prepare_inputs(inputs)
112 |
113 | def compute_loss(self, model, inputs, return_outputs=False):
114 | # return super().compute_loss(model, inputs, return_outputs=False)
115 | loss, outputs = super().compute_loss(model, inputs, return_outputs=True)
116 | # record the regression and text loss
117 | text_loss = outputs['text_loss']
118 | if 'regression_loss' in outputs:
119 | regression_loss = outputs['regression_loss']
120 | else:
121 | regression_loss = None
122 | if self.args.n_gpu > 1:
123 | text_loss = text_loss.mean()
124 | if regression_loss is not None:
125 | regression_loss = regression_loss.mean()
126 | text_loss = text_loss.item()
127 | if regression_loss is not None:
128 | regression_loss = regression_loss.item()
129 | else:
130 | regression_loss = self.regression_text_loss_metrics[1]
131 | self.regression_text_loss_metrics = (text_loss, regression_loss)
132 | return loss
133 |
--------------------------------------------------------------------------------