├── .gitignore
├── LICENSE
├── README.md
├── assets
├── data.png
├── demo_all.gif
├── demo_box.gif
├── demo_point.gif
├── framework.png
├── osprey.png
├── performance.png
├── qmsht.gif
├── qyqx.gif
├── table1.png
├── table2.png
├── table3.png
├── table4.png
├── table5.png
├── table6.png
└── video_cover.png
├── dataset.md
├── demo
├── app.py
├── inference.py
└── osprey_inference.py
├── osprey
├── __init__.py
├── configs
│ ├── stage2.json
│ └── stage3.json
├── constants.py
├── conversation.py
├── data_generation
│ ├── ask_gpt.py
│ ├── concise_qa
│ │ ├── ask_example.txt
│ │ ├── res_example.txt
│ │ └── system_message.txt
│ ├── conversation
│ │ ├── ask_example.txt
│ │ ├── res_example.txt
│ │ └── system_message.txt
│ ├── data_generation_pipeline.sh
│ ├── description
│ │ ├── ask_example.txt
│ │ ├── res_example.txt
│ │ └── system_message.txt
│ ├── generate_gpt_prompt.py
│ └── gpt_data_generation.py
├── datasets
│ ├── data_modules.py
│ ├── osprey_724k.py
│ ├── stage2_data.py
│ ├── vcr.py
│ └── vg.py
├── eval
│ ├── README.md
│ ├── datasets
│ │ ├── README.md
│ │ ├── ade20k_instance_catid_mapping.txt
│ │ ├── prepare_ade20k_ins_seg.py
│ │ ├── prepare_ade20k_pan_seg.py
│ │ └── prepare_ade20k_sem_seg.py
│ ├── description
│ │ ├── answers.json
│ │ ├── prompt.json
│ │ └── questions.json
│ ├── eval_gpt.py
│ ├── eval_open_vocab_seg_detectron2.py
│ ├── ferret-bench
│ │ ├── box_refer_caption.json
│ │ └── box_refer_reason.json
│ ├── ferret_bench_eval.py
│ ├── gpt_eval.sh
│ ├── lvis_paco_eval.py
│ ├── osprey_generate_gpt_description_answer.py
│ ├── pope
│ │ └── evaluate.py
│ ├── pope_eval.py
│ ├── pope_eval.sh
│ ├── refcocog_eval.py
│ ├── rule.json
│ ├── summarize_gpt_score.py
│ └── utils
│ │ ├── __init__.py
│ │ ├── ade20k_150_with_prompt_eng.txt
│ │ ├── cityscapes_with_prompt_eng.txt
│ │ ├── instance_evaluation.py
│ │ ├── openseg_classes.py
│ │ ├── register_ade20k_panoptic.py
│ │ └── register_cityscapes_panoptic.py
├── mm_utils.py
├── model
│ ├── __init__.py
│ ├── consolidate.py
│ ├── language_model
│ │ └── osprey_llama.py
│ ├── layer.py
│ ├── multimodal_encoder
│ │ ├── builder.py
│ │ ├── clip.py
│ │ └── clip_encoder.py
│ ├── multimodal_projector
│ │ └── builder.py
│ └── osprey_arch.py
├── train
│ ├── llama_flash_attn_monkey_patch.py
│ ├── osprey_trainer.py
│ ├── train.py
│ └── train_mem.py
└── utils.py
├── pyproject.toml
└── scripts
├── stage2.sh
├── stage3.sh
├── zero2.json
├── zero3.json
└── zero3_offload.json
/.gitignore:
--------------------------------------------------------------------------------
1 | *__pycache__
2 | .DS_Store
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
6 |
7 | [Paper](https://arxiv.org/pdf/2312.10032.pdf) [🤗 Osprey-724K](https://huggingface.co/datasets/AntGroup-MI/Osprey-724K) [Video](https://youtu.be/YsxqHBBnDfk) [Online Demo](http://111.0.123.204:8000/)
8 |
9 |
10 |
11 |
12 | Demo username & password: osprey
13 |
14 |
15 | ---
16 |
17 |
18 |

19 | ![qmsht](assets/qmsht.gif)
20 |
21 | A part of *Along the River During the Qingming Festival* (清明上河图)
22 |
23 | ![qyqx](assets/qyqx.gif)
24 |
25 | *Spirited Away* (千与千寻)
26 |
27 |
28 |
29 | 💡 Some of our other multimodal-LLM projects may interest you ✨.
30 |
31 |
32 | > [**VideoRefer Suite: Advancing Spatial-Temporal Object Understanding with Video LLM**](https://arxiv.org/abs/2501.00599)
33 | > Yuqian Yuan, Hang Zhang, Wentong Li, Zesen Cheng, Boqiang Zhang, Long Li, Xin Li, Deli Zhao, Wenqiao Zhang, Yueting Zhuang, Jianke Zhu, Lidong Bing
34 | [Code](https://github.com/DAMO-NLP-SG/VideoRefer) [Paper](https://arxiv.org/abs/2501.00599)
35 |
36 | > [**TokenPacker: Efficient Visual Projector for Multimodal LLM**](https://arxiv.org/abs/2407.02392)
37 | > Wentong Li*, Yuqian Yuan*, Jian Liu, Dongqi Tang, Song Wang, Jianke Zhu, Lei Zhang
38 | [Code](https://github.com/CircleRadon/TokenPacker) [Paper](https://arxiv.org/abs/2407.02392)
39 |
40 |
41 | ## Updates 📌
42 |
43 | [2025/4/22]🔥 Our defined metrics (Sem. Sim. & Sem. IoU) on Referring Object Classification have been adopted in [Describe Anything Model](https://arxiv.org/pdf/2504.16072) (NVIDIA & UC Berkeley).
44 |
45 | [2025/2/27]🔥 Our new work, [VideoRefer Suite](https://github.com/DAMO-NLP-SG/VideoRefer), has been accepted to CVPR 2025! This project focuses on video referring.
46 |
47 | [2024/11/27]🔥 Our defined metrics (Sem. Sim. & Sem. IoU) on Referring Object Classification have been adopted in [ChatRex](https://arxiv.org/abs/2411.18363) (IDEA).
48 |
49 | [2024/3/29]🔥 We released the [Osprey-Chat](https://huggingface.co/sunshine-lwt/Osprey-Chat-7b/tree/main) model, which exhibits better conversation and image-level understanding & reasoning capabilities.
50 |
51 | [2024/2/27]🔥 Osprey has been accepted to CVPR2024!
52 |
53 | [2024/1/15]🔥 We released the [evaluation](./osprey/eval/README.md) code.
54 |
55 | [2023/12/29]🔥 We released the training code and [Osprey-724K](https://huggingface.co/datasets/AntGroup-MI/Osprey-724K) dataset.
56 |
57 | [2023/12/18]🔥 We released the code, [osprey-7b model](https://huggingface.co/sunshine-lwt/Osprey-7b/tree/main) and [online demo](http://111.0.123.204:8000/) for Osprey.
58 |
59 |
60 | ## What is Osprey 👀
61 | Osprey is a mask-text instruction tuning approach that extends MLLMs by incorporating pixel-wise mask regions into language instructions, enabling **fine-grained visual understanding**. Given an input mask region, Osprey generates semantic descriptions at two levels of granularity: a **short description** and a **detailed description**.
62 |
63 | Osprey can seamlessly integrate with [SAM](https://github.com/facebookresearch/segment-anything) in point-prompt, box-prompt and segment-everything modes to generate the semantics associated with specific parts or objects.
64 |
65 |
66 |
67 | ## Watch Video Demo 🎥
68 |
69 |
70 |
71 |
72 | ## Try Our Demo 🕹️
73 | ### Online demo
74 | **Click** 👇 **to try our demo online.**
75 |
76 | [**web demo**](http://111.0.123.204:8000/)
77 |
78 | ```
79 | username: osprey
80 | password: osprey
81 | ```
82 |
83 |
84 |
85 | | Point | Box | Everything |
86 | | :---: | :---: | :---: |
87 | | ![point](assets/demo_point.gif) | ![box](assets/demo_box.gif) | ![everything](assets/demo_all.gif) |
97 |
98 |
99 | ### Offline demo
100 | 💻 **Requirements:** This demo needs about `17GB` of GPU memory: Osprey (15GB) and SAM (2GB).
101 |
102 | 1. First install [Gradio-Osprey-Demo](https://github.com/LiWentomng/gradio-osprey-demo).
103 | 2. Install Segment Anything.
104 | ```
105 | pip install git+https://github.com/facebookresearch/segment-anything.git
106 | ```
107 |
108 | 3. Download all the checkpoints:
109 |
110 | - [Osprey-7b](https://huggingface.co/sunshine-lwt/Osprey-7b/tree/main)
111 | - [CLIP-convnext](https://huggingface.co/laion/CLIP-convnext_large_d_320.laion2B-s29B-b131K-ft-soup/blob/main/open_clip_pytorch_model.bin)
112 | - [ViT-B SAM model](https://dl.fbaipublicfiles.com/segment_anything/sam_vit_b_01ec64.pth)
113 |
114 | The default path of all the checkpoints:
115 | ```
116 | ├── demo
117 |     ├── checkpoints
118 |     │   ├── Osprey_7b
119 |     │   └── sam_vit_b_01ec64.pth
120 |     └── open_clip_pytorch_model.bin
121 | ```
122 |
123 | Alternatively, change `"mm_vision_tower"` in the `config.json` of the Osprey-7b model to the absolute path of `open_clip_pytorch_model.bin` (see the sketch after step 4).
124 |
125 | 4. Run `app.py`.
126 | ```
127 | cd demo
128 | python app.py --model checkpoints/Osprey_7b
129 | ```
130 |
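If you go the `config.json` route from step 3, the hedged sketch below (not part of the repo) patches the field programmatically; the paths assume the default checkpoint layout shown above and should be adjusted to your setup.

```python
# Hedged sketch: point "mm_vision_tower" in the Osprey-7b config.json to the absolute
# path of open_clip_pytorch_model.bin. Paths assume the default layout from step 3.
import json
import os

cfg_path = 'demo/checkpoints/Osprey_7b/config.json'              # assumed checkpoint location
clip_path = os.path.abspath('demo/open_clip_pytorch_model.bin')  # assumed CLIP weight location

with open(cfg_path) as f:
    cfg = json.load(f)
cfg['mm_vision_tower'] = clip_path
with open(cfg_path, 'w') as f:
    json.dump(cfg, f, indent=2)
```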
131 | ## Install 🛠️
132 | 1. Clone this repository and navigate to Osprey folder
133 | ```
134 | git clone https://github.com/CircleRadon/Osprey.git
135 | cd Osprey
136 | ```
137 | 2. Install packages
138 | ```
139 | conda create -n osprey python=3.10 -y
140 | conda activate osprey
141 | pip install --upgrade pip # enable PEP 660 support
142 | pip install -e .
143 | ```
144 | 3. Install additional packages for training cases
145 | ```
146 | pip install -e ".[train]"
147 | pip install flash-attn --no-build-isolation
148 | ```
149 |
150 | ## Dataset 🌟
151 | All datasets used for training can be found in [Dataset preparation](./dataset.md).
152 |
153 | **Osprey-724K**: 🤗[Hugging Face](https://huggingface.co/datasets/AntGroup-MI/Osprey-724K)
154 |
155 | `Osprey-724K` is an instruction dataset with mask-text pairs, containing around 724K GPT-generated multimodal dialogues that encourage MLLMs to perform fine-grained, pixel-level image understanding. It contains object-level, part-level and additional instruction samples for robustness and flexibility.
156 |
157 |
158 | ## Training 🚀
159 | - **Stage1: Image-Text Alignment Pre-training**
160 | - The pretrained projector weights for Convnext-large-CLIP can be found in [projector weights](https://huggingface.co/sunshine-lwt/osprey-v1.0-mlp2x-512px-convnext-pretrain-vicuna-7b-v1.5/tree/main).
161 |
162 | - **Stage2: Mask-Text Alignment Pre-training**
163 | - Download [vicuna-7b-v1.5](https://huggingface.co/lmsys/vicuna-7b-v1.5/tree/main).
164 | - Download projector weights trained in stage1: [projector weights](https://huggingface.co/sunshine-lwt/osprey-v1.0-mlp2x-512px-convnext-pretrain-vicuna-7b-v1.5/tree/main).
165 | - Set `model_name_or_path` in `stage2.sh` to the path of `vicuna-7b-v1.5`.
166 | - Set `pretrain_mm_mlp_adapter` in `stage2.sh` to the path of `mm_projector`.
167 | - Set `vision_tower` in `stage2.sh` to the path of [Convnext-large-CLIP-model](https://huggingface.co/laion/CLIP-convnext_large_d_320.laion2B-s29B-b131K-ft-soup/blob/main/open_clip_pytorch_model.bin).
168 | - Run `sh scripts/stage2.sh`.
169 |
170 | - **Stage3: End-to-End Fine-tuning**
171 |
172 | - Set `model_name_or_path` in `stage3.sh` to the path of the `stage2 checkpoint`.
173 | - Set `vision_tower` in `stage3.sh` to the path of [Convnext-large-CLIP-model](https://huggingface.co/laion/CLIP-convnext_large_d_320.laion2B-s29B-b131K-ft-soup/blob/main/open_clip_pytorch_model.bin).
174 | - Run `sh scripts/stage3.sh`.
175 |
176 |
177 | ## Checkpoints 🤖
178 |
179 | Osprey-7b model🤗: [model](https://huggingface.co/sunshine-lwt/Osprey-7b/tree/main)
180 |
181 | We also provide the intermediate stage-2 checkpoint; see [model](https://huggingface.co/sunshine-lwt/Osprey-7b-stage2/tree/main).
182 |
183 |
184 |

185 |
186 |
187 |
188 | ## Evaluation 🔎
189 | See [evaluation](./osprey/eval/README.md) for details.
190 |
191 |
192 | ## TODO List 📝
193 | - [x] Release the checkpoints, inference codes and demo.
194 | - [x] Release the dataset and training scripts.
195 | - [x] Release the evaluation code.
196 | - [x] Release the code for data generation pipeline.
197 |
198 |
199 | ## Acknowledgement 💌
200 | - [LLaVA-v1.5](https://github.com/haotian-liu/LLaVA): the codebase we built upon.
201 | - [SAM](https://github.com/facebookresearch/segment-anything): the demo uses segmentation results from SAM as the input to Osprey.
202 |
203 |
204 | ## BibTeX 🖊️
205 | ```
206 | @misc{Osprey,
207 | title={Osprey: Pixel Understanding with Visual Instruction Tuning},
208 | author={Yuqian Yuan and Wentong Li and Jian Liu and Dongqi Tang and Xinjie Luo and Chi Qin and Lei Zhang and Jianke Zhu},
209 | year={2023},
210 | eprint={2312.10032},
211 | archivePrefix={arXiv},
212 | primaryClass={cs.CV}
213 | }
214 | ```
215 |
--------------------------------------------------------------------------------
/assets/data.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/CircleRadon/Osprey/deb0e57d7e771e8d2a8652236b0e348efb9b5c6a/assets/data.png
--------------------------------------------------------------------------------
/assets/demo_all.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/CircleRadon/Osprey/deb0e57d7e771e8d2a8652236b0e348efb9b5c6a/assets/demo_all.gif
--------------------------------------------------------------------------------
/assets/demo_box.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/CircleRadon/Osprey/deb0e57d7e771e8d2a8652236b0e348efb9b5c6a/assets/demo_box.gif
--------------------------------------------------------------------------------
/assets/demo_point.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/CircleRadon/Osprey/deb0e57d7e771e8d2a8652236b0e348efb9b5c6a/assets/demo_point.gif
--------------------------------------------------------------------------------
/assets/framework.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/CircleRadon/Osprey/deb0e57d7e771e8d2a8652236b0e348efb9b5c6a/assets/framework.png
--------------------------------------------------------------------------------
/assets/osprey.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/CircleRadon/Osprey/deb0e57d7e771e8d2a8652236b0e348efb9b5c6a/assets/osprey.png
--------------------------------------------------------------------------------
/assets/performance.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/CircleRadon/Osprey/deb0e57d7e771e8d2a8652236b0e348efb9b5c6a/assets/performance.png
--------------------------------------------------------------------------------
/assets/qmsht.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/CircleRadon/Osprey/deb0e57d7e771e8d2a8652236b0e348efb9b5c6a/assets/qmsht.gif
--------------------------------------------------------------------------------
/assets/qyqx.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/CircleRadon/Osprey/deb0e57d7e771e8d2a8652236b0e348efb9b5c6a/assets/qyqx.gif
--------------------------------------------------------------------------------
/assets/table1.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/CircleRadon/Osprey/deb0e57d7e771e8d2a8652236b0e348efb9b5c6a/assets/table1.png
--------------------------------------------------------------------------------
/assets/table2.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/CircleRadon/Osprey/deb0e57d7e771e8d2a8652236b0e348efb9b5c6a/assets/table2.png
--------------------------------------------------------------------------------
/assets/table3.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/CircleRadon/Osprey/deb0e57d7e771e8d2a8652236b0e348efb9b5c6a/assets/table3.png
--------------------------------------------------------------------------------
/assets/table4.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/CircleRadon/Osprey/deb0e57d7e771e8d2a8652236b0e348efb9b5c6a/assets/table4.png
--------------------------------------------------------------------------------
/assets/table5.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/CircleRadon/Osprey/deb0e57d7e771e8d2a8652236b0e348efb9b5c6a/assets/table5.png
--------------------------------------------------------------------------------
/assets/table6.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/CircleRadon/Osprey/deb0e57d7e771e8d2a8652236b0e348efb9b5c6a/assets/table6.png
--------------------------------------------------------------------------------
/assets/video_cover.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/CircleRadon/Osprey/deb0e57d7e771e8d2a8652236b0e348efb9b5c6a/assets/video_cover.png
--------------------------------------------------------------------------------
/dataset.md:
--------------------------------------------------------------------------------
1 | # Dataset Preparation
2 |
3 | - Osprey-724K 🤗 [download](https://huggingface.co/datasets/AntGroup-MI/Osprey-724K)
4 |
5 | | Data | Size |
6 | | --- | ---: |
7 | | osprey_short_form.json | 57 MB |
8 | | osprey_conversation.json | 106 MB |
9 | | osprey_detail_description.json | 63.4 MB |
10 | | osprey_part_level.json | 153 MB |
11 | | osprey_lvis_positive_negative.json | 140 MB |
12 |
13 |
14 | - COCO: [train2017](http://images.cocodataset.org/zips/train2017.zip), `imgs` should contain all the images, including both the training and validation sets.
15 | - pascal_part: [train.json](https://huggingface.co/datasets/sunshine-lwt/Osprey-TrainingData/resolve/main/pascalpart_train.json?download=true), [VOCdevkit](http://host.robots.ox.ac.uk/pascal/VOC/voc2010/VOCtrainval_03-May-2010.tar).
16 | - partImagenet: [train_format.json](https://huggingface.co/datasets/sunshine-lwt/Osprey-TrainingData/resolve/main/partImagenet_train_format.json?download=true),
17 | [PartImageNet_OOD](https://drive.google.com/file/d/19kA8-pAxssQI0GD5H8y8KESGaALwChtx/view?usp=sharing).
18 | - refcocos: [refcoco](https://huggingface.co/datasets/sunshine-lwt/Osprey-TrainingData/resolve/main/finetune_refcoco_train_with_mask.json?download=true), [refcoco+](https://huggingface.co/datasets/sunshine-lwt/Osprey-TrainingData/resolve/main/finetune_refcoco%2B_train_with_mask.json?download=true).
19 | - vg: [vg_train_with_mask.json](https://huggingface.co/datasets/sunshine-lwt/Osprey-TrainingData/resolve/main/vg_train_with_mask.json?download=true) (masks are generated with [HQ-SAM](https://github.com/SysCV/sam-hq)); images can be downloaded from [OpenDataLab](https://opendatalab.com/OpenDataLab/Visual_Genome_Dataset_V1_dot_2), and `image` should contain all the VG images (VG_100K and VG_100K_2).
20 | - vcr: [vcr](https://visualcommonsense.com/download/).
21 |
22 | After downloading all of them, organize the data in `./data` as follows:
23 |
24 |
25 | ```
26 | ├── coco
27 | │   ├── annotations
28 | │   │   └── instances_train2017.json
29 | │   └── imgs
30 | ├── part_data
31 | │   ├── pascal_part
32 | │   │   ├── train.json
33 | │   │   └── VOCdevkit
34 | │   └── partImagenet
35 | │       ├── train_format.json
36 | │       └── train
37 | ├── refcocos
38 | │   ├── finetune_refcoco_train_with_mask.json
39 | │   └── finetune_refcoco+_train_with_mask.json
40 | ├── Osprey-724K
41 | │   ├── osprey_short_form.json
42 | │   ├── osprey_conversation.json
43 | │   ├── osprey_detail_description.json
44 | │   ├── osprey_part_level.json
45 | │   └── osprey_lvis_positive_negative.json
46 | ├── vg
47 | │   ├── vg_train_with_mask.json
48 | │   └── image
49 | └── vcr
50 |     ├── train.jsonl
51 |     └── vcr1images
52 | ```
53 |
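As a quick way to catch path mistakes before launching training, the hedged sketch below (not part of the repo) checks a few of the annotation files from the layout above; extend the list as needed.

```python
# Hedged sanity check for the ./data layout above; file names are taken from this page.
import os

expected = [
    'data/coco/annotations/instances_train2017.json',
    'data/refcocos/finetune_refcoco_train_with_mask.json',
    'data/refcocos/finetune_refcoco+_train_with_mask.json',
    'data/Osprey-724K/osprey_short_form.json',
    'data/Osprey-724K/osprey_detail_description.json',
    'data/vg/vg_train_with_mask.json',
    'data/vcr/train.jsonl',
]
missing = [p for p in expected if not os.path.exists(p)]
print('All expected annotation files found.' if not missing else f'Missing: {missing}')
```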
--------------------------------------------------------------------------------
/demo/inference.py:
--------------------------------------------------------------------------------
1 | import gc
2 |
3 | import numpy as np
4 | import torch
5 | from segment_anything import SamPredictor, sam_model_registry, SamAutomaticMaskGenerator
6 |
7 | models = {
8 | 'vit_b': './checkpoints/sam_vit_b_01ec64.pth',
9 | 'vit_l': './checkpoints/sam_vit_l_0b3195.pth',
10 | 'vit_h': './checkpoints/sam_vit_h_4b8939.pth'
11 | }
12 |
13 | def get_sam_predictor(model_type='vit_b', device='cuda'):
14 | # sam model
15 | sam = sam_model_registry[model_type](checkpoint=models[model_type])
16 | sam = sam.to(device)
17 |
18 | predictor = SamPredictor(sam)
19 |
20 | return predictor
21 |
22 | def get_mask_generator(model_type='vit_b', device='cuda'):
23 | sam = sam_model_registry[model_type](checkpoint=models[model_type])
24 | sam = sam.to(device)
25 | mask_generator = SamAutomaticMaskGenerator(
26 | model=sam)
27 | return mask_generator
28 |
29 | def run_inference(predictor: SamPredictor, input_x, selected_points,
30 | multi_object: bool = False):
31 |
32 | if len(selected_points) == 0:
33 | return []
34 |
35 | predictor.set_image(input_x)
36 |
37 | points = torch.Tensor(
38 | [p for p, _ in selected_points]
39 | ).to(predictor.device).unsqueeze(0)
40 |
41 | labels = torch.Tensor(
42 | [int(l) for _, l in selected_points]
43 | ).to(predictor.device).unsqueeze(0)
44 |
45 | transformed_points = predictor.transform.apply_coords_torch(
46 | points, input_x.shape[:2])
47 | # print(transformed_points.shape)
48 | # predict segmentation according to the selected points
49 | masks, scores, logits = predictor.predict_torch(
50 | point_coords=transformed_points,
51 | point_labels=labels,
52 | multimask_output=False,
53 | )
54 | masks = masks.cpu().detach().numpy()
55 |
56 | gc.collect()
57 | torch.cuda.empty_cache()
58 |
59 | return masks
60 |
61 |
62 | def predict_box(predictor: SamPredictor, input_x, input_box):
63 | predictor.set_image(input_x)
64 |
65 | input_boxes = torch.tensor(input_box[None, :], device=predictor.device)
66 | transformed_boxes = predictor.transform.apply_boxes_torch(input_boxes, input_x.shape[:2])
67 |
68 | masks, _, _ = predictor.predict_torch(
69 | point_coords=None,
70 | point_labels=None,
71 | boxes = transformed_boxes,
72 | multimask_output = False
73 | )
74 | masks = masks.cpu().detach().numpy()
75 |
76 | gc.collect()
77 | torch.cuda.empty_cache()
78 | return masks
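
A minimal usage sketch for the helpers above (not part of the file): it assumes the ViT-B SAM checkpoint from the README is in ./checkpoints and uses a hypothetical image path.

```python
# Hedged usage sketch for get_sam_predictor / run_inference / predict_box.
# './examples/demo.jpg' is a hypothetical image path; adjust to a real file.
import cv2
import numpy as np

if __name__ == '__main__':
    image = cv2.cvtColor(cv2.imread('./examples/demo.jpg'), cv2.COLOR_BGR2RGB)

    predictor = get_sam_predictor(model_type='vit_b', device='cuda')

    # Point prompt: a single foreground click (label 1) at pixel (x=200, y=150).
    point_masks = run_inference(predictor, image, selected_points=[((200, 150), 1)])

    # Box prompt: an xyxy box in pixel coordinates.
    box_masks = predict_box(predictor, image, np.array([50, 60, 400, 380]))

    print(point_masks.shape, box_masks.shape)  # (1, 1, H, W) binary masks
```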
--------------------------------------------------------------------------------
/demo/osprey_inference.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from osprey.utils import disable_torch_init
3 | from transformers import AutoTokenizer, CLIPImageProcessor
4 | from osprey.model.language_model.osprey_llama import OspreyLlamaForCausalLM
5 | from osprey.mm_utils import tokenizer_image_token
6 | from osprey.conversation import conv_templates, SeparatorStyle
7 | from osprey.constants import IMAGE_TOKEN_INDEX
8 | from osprey.train.train import DataArguments
9 |
10 | from functools import partial
11 | import os
12 | import numpy as np
13 | import cv2
14 |
15 | data_args = DataArguments()
16 | data_args.mm_use_im_start_end = False
17 | data_args.is_multimodal = True
18 |
19 | def show_mask(mask, image, random_color=True, img_trans=0.9, mask_trans=0.5, return_color=False):
20 | if random_color:
21 | color = np.concatenate([np.random.random(3)*255], axis=0)
22 | else:
23 | color = np.array([30, 144, 255])
24 | h,w = mask.shape[-2:]
25 | mask_image = mask.reshape(h,w,1)*color.reshape(1,1,-1)
26 |
27 | image = cv2.addWeighted(image, img_trans, mask_image.astype('uint8'), mask_trans , 0)
28 | if return_color:
29 | return image, mask_image
30 | else:
31 | return image
32 |
33 | class Osprey():
34 | def __init__(self, model_path, device='cuda'):
35 | disable_torch_init()
36 | model_path = os.path.expanduser(model_path)
37 | self.tokenizer = AutoTokenizer.from_pretrained(
38 | model_path,
39 | model_max_length=2048,
40 | padding_side="right",
41 | use_fast=True
42 | )
43 | self.model = OspreyLlamaForCausalLM.from_pretrained(
44 | model_path,
45 | torch_dtype=torch.bfloat16,
46 | ).to(device)
47 | self.tokenizer.pad_token = self.tokenizer.unk_token
48 |
49 | self.image_processor = CLIPImageProcessor(do_resize=True, size={"shortest_edge":512}, resample=3, do_center_crop=True, crop_size={"height": 512, "width": 512},
50 | do_rescale=True, rescale_factor=0.00392156862745098, do_normalize=True, image_mean=[0.48145466, 0.4578275, 0.40821073],
51 | image_std=[0.26862954, 0.26130258, 0.27577711], do_convert_rgb=True, )
52 |
53 | spi_tokens = ['<mask>', '<pos>']
54 | self.tokenizer.add_tokens(spi_tokens, special_tokens=True)
55 |
56 | for m in self.model.modules():
57 | m.tokenizer = self.tokenizer
58 |
59 | vision_tower = self.model.get_vision_tower()
60 | if not vision_tower.is_loaded:
61 | vision_tower.load_model()
62 | vision_tower.to(dtype=torch.float16, device=device)
63 |
64 | begin_str = """<image>\n\nThis provides an overview of the picture.\n"""
65 |
66 | short_question = 'Please give me a short description of <mask><pos>. Using a short phrase.'
67 |
68 | conv = conv_templates['osprey_v1'].copy()
69 | qs = begin_str+short_question
70 | conv.append_message(conv.roles[0], qs)
71 | conv.append_message(conv.roles[1], None)
72 | prompt = conv.get_prompt()
73 |
74 | self.input_ids_short = tokenizer_image_token(prompt, self.tokenizer, IMAGE_TOKEN_INDEX, return_tensors='pt').unsqueeze(0).to(self.model.device)
75 |
76 | detailed_question = 'Can you give me a detailed description of <mask><pos>?'
77 |
78 | conv = conv_templates['osprey_v1'].copy()
79 | qs = begin_str+detailed_question
80 | conv.append_message(conv.roles[0], qs)
81 | conv.append_message(conv.roles[1], None)
82 | prompt = conv.get_prompt()
83 |
84 | self.input_ids_detailed = tokenizer_image_token(prompt, self.tokenizer, IMAGE_TOKEN_INDEX, return_tensors='pt').unsqueeze(0).to(self.model.device)
85 |
86 | self.stop_str = conv.sep if conv.sep_style != SeparatorStyle.TWO else conv.sep2
87 |
88 |
89 |
90 | def osprey_predict(self, img, mask, type=None):
91 | image = self.image_processor.preprocess(img,
92 | do_center_crop=False,
93 | return_tensors='pt')['pixel_values'][0]
94 |
95 | image = torch.nn.functional.interpolate(image.unsqueeze(0),
96 | size=(512, 512),
97 | mode='bilinear',
98 | align_corners=False).squeeze(0)
99 |
100 | masks = torch.Tensor(mask).unsqueeze(0).to(self.model.device)
101 |
102 |
103 | if type == 'short description':
104 | input_ids = self.input_ids_short
105 | else:
106 | input_ids = self.input_ids_detailed
107 |
108 | # self.model.model.tokenizer = self.tokenizer
109 |
110 | with torch.inference_mode():
111 |
112 | self.model.orig_forward = self.model.forward
113 | self.model.forward = partial(self.model.orig_forward,
114 | img_metas=[None],
115 | masks=[masks.half()])
116 |
117 | output_ids = self.model.generate(
118 | input_ids,
119 | images=image.unsqueeze(0).half().to(self.model.device),
120 | do_sample=True,
121 | temperature=0.2,
122 | max_new_tokens=1024,
123 | use_cache=True,
124 | num_beams=1,
125 | # stopping_criteria=[stopping_criteria]
126 | )
127 |
128 | self.model.forward = self.model.orig_forward
129 |
130 | input_token_len = input_ids.shape[1]
131 | n_diff_input_output = (
132 | input_ids != output_ids[:, :input_token_len]).sum().item()
133 | if n_diff_input_output > 0:
134 | print(
135 | f'[Warning] {n_diff_input_output} output_ids are not the same as the input_ids')
136 | outputs = self.tokenizer.batch_decode(output_ids[:, input_token_len:],
137 | skip_special_tokens=True)[0]
138 |
139 | outputs = outputs.strip()
140 | if outputs.endswith(self.stop_str):
141 | outputs = outputs[:-len(self.stop_str)]
142 | outputs = outputs.strip()
143 | if ':' in outputs:
144 | outputs = outputs.split(':')[1]
145 |
146 | outputs_list = outputs.split('.')
147 | outputs_list_final = []
148 | outputs_str = ''
149 | for output in outputs_list:
150 | if output not in outputs_list_final:
151 | if output=='':
152 | continue
153 | outputs_list_final.append(output)
154 | outputs_str+=output+'.'
155 | else:
156 | break
157 | return outputs_str
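
A minimal end-to-end sketch (not part of the file), wiring the SAM helpers from demo/inference.py to the Osprey wrapper above; the checkpoint and image paths are assumptions following the README's offline-demo layout.

```python
# Hedged end-to-end sketch: SAM point prompt -> binary mask -> Osprey description.
import cv2
from inference import get_sam_predictor, run_inference  # helpers from demo/inference.py

if __name__ == '__main__':
    img = cv2.cvtColor(cv2.imread('./examples/demo.jpg'), cv2.COLOR_BGR2RGB)  # hypothetical image

    predictor = get_sam_predictor(model_type='vit_b', device='cuda')
    masks = run_inference(predictor, img, selected_points=[((200, 150), 1)])

    osprey = Osprey('./checkpoints/Osprey_7b', device='cuda')  # assumed checkpoint path
    # osprey_predict expects the raw RGB image and a single H x W binary mask.
    print(osprey.osprey_predict(img, masks[0][0], type='short description'))
```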
--------------------------------------------------------------------------------
/osprey/__init__.py:
--------------------------------------------------------------------------------
1 | from .model import OspreyLlamaForCausalLM
2 |
--------------------------------------------------------------------------------
/osprey/configs/stage2.json:
--------------------------------------------------------------------------------
1 | [
2 | {
3 | "type": "coco_data",
4 | "ann_file": "./data/coco/annotations/instances_train2017.json",
5 | "img_prefix": "./data/coco/imgs"
6 | },
7 |
8 | {
9 | "type": "RefCOCO",
10 | "ann_file": "./data/refcocos/finetune_refcoco_train_with_mask.json",
11 | "img_prefix": "./data/coco/imgs"
12 | },
13 |
14 | {
15 | "type": "RefCOCOP",
16 | "ann_file": "./data/refcocos/finetune_refcoco+_train_with_mask.json",
17 | "img_prefix": "./data/coco/imgs"
18 | },
19 |
20 | {
21 | "type": "PascalPart",
22 | "ann_file": "./data/part_data/pascal_part/train.json",
23 | "img_prefix": "./data/part_data/pascal_part/VOCdevkit/VOC2010/JPEGImages"
24 | },
25 |
26 | {
27 | "type": "PartImagenet",
28 | "ann_file": "./data/part_data/partImagenet/train_format.json",
29 | "img_prefix": "./data/part_data/partImagenet/train/"
30 | }
31 | ]
32 |
--------------------------------------------------------------------------------
/osprey/configs/stage3.json:
--------------------------------------------------------------------------------
1 | [
2 | {
3 | "type": "OspreyShortForm",
4 | "ann_file": "./data/Osprey-724K/osprey_short_form.json",
5 | "img_prefix": "./data/coco_imgs/"
6 | },
7 |
8 | {
9 | "type": "OspreyLVISPosNeg",
10 | "ann_file": "./data/Osprey-724K/osprey_lvis_positive_negative.json",
11 | "img_prefix": "./data/coco_imgs/"
12 | },
13 |
14 | {
15 | "type": "OspreyConversations",
16 | "ann_file": "./data/Osprey-724K/osprey_conversation.json",
17 | "img_prefix": "./data/coco_imgs/"
18 | },
19 |
20 | {
21 | "type": "OspreyDetailedDescription",
22 | "ann_file": "./data/Osprey-724K/osprey_detailed_description.json",
23 | "img_prefix": "./data/coco_imgs/"
24 | },
25 |
26 | {
27 | "type": "OspreyPartLevel",
28 | "ann_file": "./data/Osprey-724K/osprey_part_level.json",
29 | "img_prefix": "./data/coco_imgs/"
30 | },
31 |
32 | {
33 | "type": "VGDATA",
34 | "ann_file": "./data/vg/vg_train_with_mask.json",
35 | "img_prefix": "./data/vg/image"
36 | },
37 |
38 | {
39 | "type": "vcr",
40 | "ann_file": "./data/vcr/train.jsonl",
41 | "img_prefix": "./data/vcr/vcr1images"
42 | }
43 | ]
44 |
--------------------------------------------------------------------------------
/osprey/constants.py:
--------------------------------------------------------------------------------
1 | CONTROLLER_HEART_BEAT_EXPIRATION = 30
2 | WORKER_HEART_BEAT_INTERVAL = 15
3 |
4 | LOGDIR = "."
5 |
6 | # Model Constants
7 | IGNORE_INDEX = -100
8 | IMAGE_TOKEN_INDEX = -200
9 | DEFAULT_IMAGE_TOKEN = "<image>"
10 | DEFAULT_IMAGE_PATCH_TOKEN = "<im_patch>"
11 | DEFAULT_IM_START_TOKEN = "<im_start>"
12 | DEFAULT_IM_END_TOKEN = "<im_end>"
13 |
--------------------------------------------------------------------------------
/osprey/data_generation/ask_gpt.py:
--------------------------------------------------------------------------------
1 | import openai
2 |
3 | class askGPT():
4 | def __init__(self):
5 | # fill in the api key here
6 | openai.api_key = "YOUR_OPENAI_API_KEY"  # placeholder: replace with your own key
7 |
8 | def ask_gpt(self, question):
9 | with open('description/system_message.txt', 'r') as f:
10 | system_message = f.read()
11 | with open('description/ask_example.txt', 'r') as f:
12 | example_ask = f.read()
13 | with open('description/res_example.txt', 'r') as f:
14 | example_res = f.read()
15 | completion = openai.ChatCompletion.create(
16 | model="gpt-4",
17 | messages=[
18 | {"role": "system", "content": system_message},
19 | {"role": "user", "content": example_ask},
20 | {"role": "assistant", "content": example_res},
21 | {"role": "user", "content": question}
22 | ]
23 | )
24 | return completion.choices[0].message['content']
25 |
26 | def ask_gpt_short_conversation(self, question):
27 | with open('concise_qa/system_message.txt', 'r') as f:
28 | system_message = f.read()
29 | with open('concise_qa/ask_example.txt', 'r') as f:
30 | example_ask = f.read()
31 | with open('concise_qa/res_example.txt', 'r') as f:
32 | example_res = f.read()
33 | completion = openai.ChatCompletion.create(
34 | model="gpt-4",
35 | messages=[
36 | {"role": "system", "content": system_message},
37 | {"role": "user", "content": example_ask},
38 | {"role": "assistant", "content": example_res},
39 | {"role": "user", "content": question}
40 | ]
41 | )
42 | return completion.choices[0].message['content']
43 |
44 | def ask_gpt_conversation(self, question):
45 | with open('conversation/system_message.txt', 'r') as f:
46 | system_message = f.read()
47 | with open('conversation/ask_example.txt', 'r') as f:
48 | example_ask = f.read()
49 | with open('conversation/res_example.txt', 'r') as f:
50 | example_res = f.read()
51 | completion = openai.ChatCompletion.create(
52 | model="gpt-4",
53 | messages=[
54 | {"role": "system", "content": system_message},
55 | {"role": "user", "content": example_ask},
56 | {"role": "assistant", "content": example_res},
57 | {"role": "user", "content": question}
58 | ]
59 | )
60 | return completion.choices[0].message['content']
61 |
62 |
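A hedged usage sketch (not part of the file); the prompt string would normally be assembled by generate_gpt_prompt.py, and the script must be run from osprey/data_generation so the relative *.txt paths resolve.

```python
# Hedged usage sketch for askGPT; requires a valid OpenAI API key and the prompt .txt files.
if __name__ == '__main__':
    gpt = askGPT()
    prompt = "Whole description: ...\n\nDescription of each region is listed below:\n..."  # abbreviated placeholder
    print(gpt.ask_gpt(prompt))               # detailed region descriptions
    print(gpt.ask_gpt_conversation(prompt))  # multi-turn QA
```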
--------------------------------------------------------------------------------
/osprey/data_generation/concise_qa/ask_example.txt:
--------------------------------------------------------------------------------
1 | Whole description: "The image presents a lively market scene with a group of people buying fruits and bags. There are multiple individuals in the market, all browsing through the fresh produce available. A significant variety of fruits are showcased in the market, including bananas, oranges, and apples. Bananas can be seen in several groups, with some green and yellow bananas occupying different areas of the market. Meanwhile, oranges and apples are displayed in smaller sections among the fruits. In addition to fruits, handbags are also being sold at the market, attracting the attention of the customers. Overall, the market bustles with activity as people gather around the fresh fruits and bags, contemplating their purchases."
2 |
3 | Description of each region is listed below:
4 |
5 | <region1> (person: [0.507,0.409,0.698,0.740]):
6 | gray shirt wearing glasses
7 | woman with gray shirt standing next to man
8 | woman in gray shirt facing camera on right
9 | the woman in the grey shirt with a watch on her wrist ..
10 | a short haired woman in jeans shopping
11 |
12 | <region2> (person: [0.243,0.469,0.558,0.746]):
13 | navy blue shirt
14 | woman back in blue
15 | the lady with the blue shirt
16 | the back of an older woman with her hair in a barrette with a blue jacket on
17 | a woman is wearing blue sweater
18 |
19 |
--------------------------------------------------------------------------------
/osprey/data_generation/concise_qa/res_example.txt:
--------------------------------------------------------------------------------
1 | Question 1: What is the woman in <region1> wearing on her wrist?
2 |
3 | Answer 1: Watch.
4 |
5 | Question 2: What is the color of the older woman's shirt in <region2>?
6 |
7 | Answer 2: Navy blue.
8 |
9 | Question 3: What hair accessory does the older woman in <region2> have?
10 |
11 | Answer 3: Barrette.
12 |
13 | Question 4: Which direction is the older woman in <region2> facing?
14 |
15 | Answer 4: Away.
16 |
17 |
--------------------------------------------------------------------------------
/osprey/data_generation/concise_qa/system_message.txt:
--------------------------------------------------------------------------------
1 | You are an AI visual assistant, and you are seeing several object regions in a single image. You are provided with a detailed description of the whole image and of each object region in this image, describing what you are looking at. Answer all questions as if you are seeing the image.
2 | The location of each object region is given in the form of bounding boxes, represented as (x1, y1, x2, y2) with floating numbers ranging from 0 to 1. These values correspond to the top left x, top left y, bottom right x, and bottom right y.
3 | Design a conversation between you and a person asking about each object region of this image. The answers must be in one word or one phrase. Ask diverse questions and give corresponding answers. All the regions given should be mentioned in the questions; when referring to each region, use <region1>, <region2>, etc.
4 | Include questions asking about the visual content of each object region in the image, including the object category, object type, object color, object actions, object locations, relative positions between objects and other attributes etc. Only include questions that have definite answers:
5 | (1) one can see the content in the object region of this image that the question asks about and can answer confidently;
6 | (2) one can determine confidently from the object region of this image that it is not in the image.
7 | Do not ask any question that cannot be answered confidently.
8 | Do not ask any question that is not mentioned.
9 | Do not ask any question that cannot be answered with one word or phrase.
10 |
11 | Most importantly, the answer must be in one word or short phrase.
--------------------------------------------------------------------------------
/osprey/data_generation/conversation/ask_example.txt:
--------------------------------------------------------------------------------
1 | Whole description: "The image presents a lively market scene with a group of people buying fruits and bags. There are multiple individuals in the market, all browsing through the fresh produce available. A significant variety of fruits are showcased in the market, including bananas, oranges, and apples. Bananas can be seen in several groups, with some green and yellow bananas occupying different areas of the market. Meanwhile, oranges and apples are displayed in smaller sections among the fruits. In addition to fruits, handbags are also being sold at the market, attracting the attention of the customers. Overall, the market bustles with activity as people gather around the fresh fruits and bags, contemplating their purchases."
2 |
3 | Description of each region is listed below:
4 |
5 | <region1> (person: [0.507,0.409,0.698,0.740]):
6 | gray shirt wearing glasses
7 | woman with gray shirt standing next to man
8 | woman in gray shirt facing camera on right
9 | the woman in the grey shirt with a watch on her wrist ..
10 | a short haired woman in jeans shopping
11 |
12 | <region2> (person: [0.243,0.469,0.558,0.746]):
13 | navy blue shirt
14 | woman back in blue
15 | the lady with the blue shirt
16 | the back of an older woman with her hair in a barrette with a blue jacket on
17 | a woman is wearing blue sweater
18 |
19 |
--------------------------------------------------------------------------------
/osprey/data_generation/conversation/res_example.txt:
--------------------------------------------------------------------------------
1 | Question 1: Can you describe the woman in <region1> and what she is doing?
2 |
3 | Answer 1: The woman in <region1> is wearing a gray shirt, jeans, and glasses. She has short hair and a watch on her wrist. She is shopping in the lively market, browsing through the fresh produce, such as fruits and bags.
4 |
5 | Question 2: What can you tell me about the woman in <region2>?
6 |
7 | Answer 2: The woman in <region2> is wearing a navy blue shirt or sweater. She has her back towards the camera, with her hair styled in a barrette. She appears to be an older woman, also browsing through the market, looking at the various items for sale.
8 |
9 | Question 3: Can you compare the positions of the women in <region1> and <region2> within the image?
10 |
11 | Answer 3: The woman in <region1> is positioned more towards the right side of the image, facing the camera. On the other hand, the woman in <region2> is located on the left side of the image with her back to the camera. Both women are engaged in shopping at the market, surrounded by fresh fruits and bags.
12 |
13 | Question 4: Which fruits are visible in the image, and are there any other types of merchandise for sale?
14 |
15 | Answer 4: Bananas, oranges, and apples are visible in the image as the showcased fruits. Besides fruits, there are handbags also being sold at the market. These items are attracting the attention of the customers as they explore the market and contemplate their purchases.
16 |
17 | Question 5: Based on the image, how would you describe the overall atmosphere of the market?
18 |
19 | Answer 5: The market has a bustling and lively atmosphere. There's a sense of excitement and energy as people gather around the fresh fruits and bags for sale. The variety of items available, as well as the interactions among customers, contribute to the vibrant ambience of the market scene.
--------------------------------------------------------------------------------
/osprey/data_generation/conversation/system_message.txt:
--------------------------------------------------------------------------------
1 | You are an AI visual assistant, and you are seeing several object regions in a single image. You are provided with a detailed description of the whole image and of each object region in this image, describing what you are looking at. Answer all questions as if you are seeing the image.
2 | The location of each object region is given in the form of bounding boxes, represented as (x1, y1, x2, y2) with floating numbers ranging from 0 to 1. These values correspond to the top left x, top left y, bottom right x, and bottom right y.
3 | Design a conversation between you and a person asking about each object region of this image. The answers should be in a tone that a visual AI assistant is seeing the image and answering the question. Ask diverse questions and give corresponding answers. All the regions given should be mentioned in the questions; when referring to each region, use <region1>, <region2>, etc.
4 | Include questions asking about the visual content of each object region in the image, including the object category, object type, object color, object actions, object locations, relative positions between objects and other attributes etc. Only include questions that have definite answers:
5 | (1) one can see the content in the object region of this image that the question asks about and can answer confidently;
6 | (2) one can determine confidently from the object region of this image that it is not in the image.
7 | Do not ask any question that cannot be answered confidently.
8 | Also include complex questions that are relevant to the content of each object region in the image, for example, asking about background knowledge of the objects, asking to discuss about events happening in the image, etc. Again, do not ask about uncertain details.
9 | Provide detailed answers when answering complex questions. For example, give detailed examples or reasoning steps to make the content more convincing and well-organized. You can include multiple paragraphs if necessary.
--------------------------------------------------------------------------------
/osprey/data_generation/data_generation_pipeline.sh:
--------------------------------------------------------------------------------
1 | # description
2 | python gpt_data_generation.py \
3 | --type description \
4 | --outputfile description_gpt4_data.json
5 |
6 | # conversation
7 | python gpt_data_generation.py \
8 | --type conversation \
9 | --outputfile conversation_gpt4_data.json
10 |
11 | # short-form conversation
12 | python gpt_data_generation.py \
13 | --type short-form \
14 | --outputfile short_form_conversation_gpt4_data.json
15 |
--------------------------------------------------------------------------------
/osprey/data_generation/description/ask_example.txt:
--------------------------------------------------------------------------------
1 | The detailed description of this image:
2 | In the image, a man wearing a white apron and a ball cap is standing in a kitchen. He appears to be a chef or a baker and is actively sliding a pizza into a brick oven. The kitchen is well-equipped with various tools and utensils, such as spoons and bowls, which are placed on the dining table and countertops. A chair is positioned near the table, and multiple bowls can be found throughout the kitchen. The scene conveys a sense of preparation and anticipation for the dish being cooked.
3 | chair: [0.437,0.840,0.808,1.000]
4 | person: [0.530,0.881,0.998,0.987]
5 | person: [0.223,0.202,0.580,0.954]
6 | spoon: [0.068,0.423,0.089,0.435]
7 | Specifically, there are 2 special regions: <region1> and <region2>.
8 | For each region, you receive several sentences as the description of this region in the image you are observing.
9 | For <region1> (person: [0.223,0.202,0.580,0.954]):
10 | man chef
11 | man cooking
12 |
13 | For <region2> (person: [0.530,0.881,0.998,0.987]):
14 | i think its a table infront of the chair
15 | top of head bottom right corner
16 | part of head seen near chair
--------------------------------------------------------------------------------
/osprey/data_generation/description/res_example.txt:
--------------------------------------------------------------------------------
1 | <region1>: A man wearing a white apron, possibly a chef or a baker, is actively cooking in a well-equipped kitchen. He is standing near the brick oven and appears to be sliding a pizza into it.
2 |
3 | <region2>: The top of a person's head is partially visible near the chair; it appears to be a brown-haired individual, though not much else can be discerned from this angle.
--------------------------------------------------------------------------------
/osprey/data_generation/description/system_message.txt:
--------------------------------------------------------------------------------
1 | You are an AI visual assistant that can analyze a single image. You receive a detailed description/several descriptions of this image. In addition, most object locations within the image are given, along with detailed coordinates.
2 | These coordinates are in the form of bounding boxes, represented as (x1, y1, x2, y2) with floating numbers ranging from 0 to 1. These values correspond to the top left x, top left y, bottom right x, and bottom right y.
3 | Your role is to give a detailed description of each special region in the image. Instead of directly mentioning the bounding box coordinates, utilize this data to explain each region using natural language. Include details like object category, object type, object color, attributes of the object, object locations, object state and other attributes.
4 | When using the information from the image and object region captions and coordinates, directly explain the region, and do not mention that the information source is the caption or the bounding box. Always answer as if you are directly looking at each region. Provide a direct answer without mentioning "this region". The answer template is: "<regionX>: ..."
--------------------------------------------------------------------------------
/osprey/data_generation/gpt_data_generation.py:
--------------------------------------------------------------------------------
1 | from ask_gpt import askGPT
2 | from generate_gpt_prompt import GeneratePrompt, COCODataset
3 | import json
4 | from tqdm import tqdm
5 | import re
6 | import argparse
7 |
8 | QUESTIONS = [
9 | 'Can you provide me with a detailed description of the region in the picture marked by <region>?',
10 | "I'm curious about the region represented by <region> in the picture. Could you describe it in detail?",
11 | 'What can you tell me about the region indicated by <region> in the image?',
12 | "I'd like to know more about the area in the photo labeled <region>. Can you give me a detailed description?",
13 | 'Could you describe the region shown as <region> in the picture in great detail?',
14 | 'What details can you give me about the region outlined by <region> in the photo?',
15 | 'Please provide me with a comprehensive description of the region marked with <region> in the image.',
16 | 'Can you give me a detailed account of the region labeled as <region> in the picture?',
17 | "I'm interested in learning more about the region represented by <region> in the photo. Can you describe it in detail?",
18 | 'What is the region outlined by <region> in the picture like? Could you give me a detailed description?',
19 | 'Can you provide me with a detailed description of the region in the picture marked by <region>, please?',
20 | "I'm curious about the region represented by <region> in the picture. Could you describe it in detail, please?",
21 | 'What can you tell me about the region indicated by <region> in the image, exactly?',
22 | "I'd like to know more about the area in the photo labeled <region>, please. Can you give me a detailed description?",
23 | 'Could you describe the region shown as <region> in the picture in great detail, please?',
24 | 'What details can you give me about the region outlined by <region> in the photo, please?',
25 | 'Please provide me with a comprehensive description of the region marked with <region> in the image, please.',
26 | 'Can you give me a detailed account of the region labeled as <region> in the picture, please?',
27 | "I'm interested in learning more about the region represented by <region> in the photo. Can you describe it in detail, please?",
28 | 'What is the region outlined by <region> in the picture like, please? Could you give me a detailed description?',
29 | 'Please describe the region <region> in the image in detail.',
30 | 'Can you offer a thorough analysis of the region <region> in the image?',
31 | 'Could you elaborate on the region highlighted by <region> in the picture provided?',
32 | 'Please share more information about the zone emphasized with <region> in the photo.',
33 | 'What insights can you give about the area denoted by <region> in the image presented?',
34 | 'Can you share a comprehensive rundown of the region denoted by <region> in the presented image?',
35 | "I'd like to know more about the region highlighted by <region> in the picture provided.",
36 | 'Work through the important details of the area <region> in the image.',
37 | 'Illustrate the area represented by <region> through a descriptive explanation.',
38 | 'Examine the region <region> closely and share its details.'
39 | ]
40 |
41 | class COCODataConvert():
42 | def __init__(self):
43 | self.gpt = askGPT()
44 | self.generate_gpt_prompt = GeneratePrompt()
45 | self.coco = COCODataset()
46 |
47 | def generate_conversation(self, output_file):
48 | results = []
49 | imgs = self.coco.img_ids
50 | sum = 0
51 | for id in (tqdm(imgs)):
52 | result = {}
53 | prompt, annotations, num_boxes, height, width = self.generate_gpt_prompt.load_data_and_generate_gpt_prompt_description(id,1)
54 | # print(prompt)
55 | if prompt == None:
56 | # print("none")
57 | continue
58 |
59 | while True:
60 | try:
61 | ret = self.gpt.ask_gpt_conversation(prompt)
62 | description = ret  # ask_gpt_conversation returns the reply text directly
63 |
64 | conversations = []
65 | description_list = description.split('\n\n')
66 | for i, des in enumerate(description_list):
67 | conv = {}
68 | if i%2==0:
69 | conv['from'] = "human"
70 | conv['value'] = re.findall(r"Question.*:\ (.*)",des)[0]
71 | else:
72 | conv['from'] = "gpt"
73 | conv['value'] = re.findall(r"Answer.*:\ (.*)",des)[0]
74 | conversations.append(conv)
75 | break
76 | except:
77 | print(ret)
78 |
79 |
80 | img_info = self.coco.coco.load_imgs([id])[0]
81 | result['id'] = id
82 | result['file_name'] = img_info['file_name']
83 | result['conversations'] = conversations
84 | result['annotation'] = annotations
85 | result['height'] = height
86 | result['width'] = width
87 | results.append(result)
88 |
89 | sum+=1
90 | print("num:", sum)
91 |
92 | if sum%100==0:
93 | f = json.dumps(results)
94 | f2 = open(output_file, 'w')
95 | f2.write(f)
96 | f2.close()
97 |
98 | f = json.dumps(results)
99 | f2 = open(output_file, 'w')
100 | f2.write(f)
101 | f2.close()
102 |
103 |
104 | def generate_short_conversation(self, output_file):
105 | results = []
106 | imgs = self.coco.img_ids
107 | sum = 0
108 |
109 | for id in (tqdm(imgs)):
110 | # print(id)
111 | result = {}
112 | prompt, annotations, num_boxes, height, width = self.generate_gpt_prompt.load_data_and_generate_gpt_prompt_description(id,1)
113 | # print(prompt)
114 | if prompt == None:
115 | # print("none")
116 | continue
117 | while True:
118 | try:
119 | ret = self.gpt.ask_gpt_short_conversation(prompt)
120 | description = ret  # ask_gpt_short_conversation returns the reply text directly
121 |
122 | conversations = []
123 | description_list = description.split('\n\n')
124 |
125 | for i, des in enumerate(description_list):
126 | conv = {}
127 | if i%2==0:
128 | conv['from'] = "human"
129 | conv['value'] = re.findall(r"Question.*:\ (.*)",des)[0]
130 | else:
131 | conv['from'] = "gpt"
132 | conv['value'] = re.findall(r"Answer.*:\ (.*)",des)[0]
133 | conversations.append(conv)
134 | break
135 | except:
136 | print(ret)
137 |
138 | img_info = self.coco.coco.load_imgs([id])[0]
139 | result['id'] = id
140 | result['file_name'] = img_info['file_name']
141 | result['conversations'] = conversations
142 | result['annotation'] = annotations
143 | result['height'] = height
144 | result['width'] = width
145 | results.append(result)
146 |
147 | sum+=1
148 | print("num:",sum)
149 |
150 | if sum%100==0:
151 | f = json.dumps(results)
152 | f2 = open(output_file, 'w')
153 | f2.write(f)
154 | f2.close()
155 |
156 | f = json.dumps(results)
157 | f2 = open(output_file, 'w')
158 | f2.write(f)
159 | f2.close()
160 |
161 |
162 | def generate_descriptions(self, output_file):
163 | results = []
164 | imgs = self.coco.img_ids
165 | sum = 0
166 | for id in tqdm(imgs):
167 | result = {}
168 | prompt, annotations, num_boxes, height, width = self.generate_gpt_prompt.load_data_and_generate_gpt_prompt_description(id)
169 | # print(prompt)
170 |
171 | if prompt == None:
172 | # print("None")
173 | continue
174 |
175 | while True:
176 | try:
177 | description = self.gpt.ask_gpt(prompt)
178 |
179 | description_list = description.split('\n\n')
180 | break
181 | except:
182 | print(description)
183 |
184 |
185 | img_info = self.coco.coco.load_imgs([id])[0]
186 | result['id'] = id
187 | result['file_name'] = img_info['file_name']
188 | result['description'] = description_list
189 | result['annotation'] = annotations
190 | result['height'] = height
191 | result['width'] = width
192 | results.append(result)
193 |
194 | sum+=1
195 | print("num:",sum)
196 |
197 | if sum%100==0:
198 | f = json.dumps(results)
199 | f2 = open(output_file, 'w')
200 | f2.write(f)
201 | f2.close()
202 |
203 | f = json.dumps(results)
204 | f2 = open(output_file, 'w')
205 | f2.write(f)
206 | f2.close()
207 |
208 |
209 | if __name__ == '__main__':
210 | parser = argparse.ArgumentParser(description='data generation pipeline', formatter_class=argparse.RawTextHelpFormatter)
211 | parser.add_argument('--type', help='generate data type', default='description')
212 | parser.add_argument('--outputfile', help='output file name', default='description_gpt4_data.json')
213 | args = parser.parse_args()
214 |
215 | convert = COCODataConvert()
216 |
217 | if args.type=='description':
218 | convert.generate_descriptions(args.outputfile)
219 | elif args.type=='conversation':
220 | convert.generate_conversation(args.outputfile)
221 | elif args.type=='short-form':
222 | convert.generate_short_conversation(args.outputfile)
223 | else:
224 | raise NotImplementedError
225 |
--------------------------------------------------------------------------------
/osprey/datasets/data_modules.py:
--------------------------------------------------------------------------------
1 | from dataclasses import dataclass
2 | import torch
3 | import transformers
4 | from torch.utils.data import ConcatDataset
5 | import json
6 | from osprey.constants import IGNORE_INDEX
7 |
8 | from .stage2_data import COCODataset, RefCOCO, RefCOCOP
9 | from .vcr import VCRDataset
10 | from .vg import VGDATA
11 | from .stage2_data import PascalPart
12 | from .stage2_data import PartImagenet
13 | from .osprey_724k import OspreyDetailedDescription, OspreyConversations, OspreyShortForm, OspreyPartLevel, OspreyLVISPosNeg
14 |
15 | @dataclass
16 | class DataCollatorForDetDataset(object):
17 |
18 | tokenizer: transformers.PreTrainedTokenizer
19 | def __call__(self, instances):
20 |
21 | input_ids, labels, img_metas, masks = tuple([instance.get(key,None) for instance in instances]
22 | for key in ('input_ids',
23 | 'labels',
24 | 'img_metas',
25 | 'masks'))
26 | input_ids = torch.nn.utils.rnn.pad_sequence(
27 | input_ids,
28 | batch_first=True,
29 | padding_value=self.tokenizer.pad_token_id)
30 | labels = torch.nn.utils.rnn.pad_sequence(labels,
31 | batch_first=True,
32 | padding_value=IGNORE_INDEX)
33 |
34 | batch = dict(
35 | input_ids=input_ids,
36 | labels=labels,
37 | attention_mask=input_ids.ne(self.tokenizer.pad_token_id),
38 | img_metas=img_metas,
39 | masks = masks
40 | )
41 |
42 | if 'image' in instances[0]:
43 | images = [instance['image'] for instance in instances]
44 | if all(x is not None and x.shape == images[0].shape for x in images):
45 | batch['images'] = torch.stack(images)
46 | else:
47 | batch['images'] = images
48 |
49 | return batch
50 |
51 | def make_multitask_data_module(tokenizer,
52 | data_args) :
53 | """Make dataset and collator for supervised fine-tuning."""
54 |
55 | if data_args.dataset_config is not None:
56 | dataset_config = json.load(open(data_args.dataset_config))
57 |
58 | train_dataset = build_osprey_dataset(dataset_config,
59 | tokenizer=tokenizer,
60 | data_args=data_args)
61 |
62 | data_collator = DataCollatorForDetDataset(tokenizer=tokenizer)
63 |
64 | return dict(train_dataset=train_dataset,
65 | eval_dataset=None,
66 | data_collator=data_collator)
67 |
68 | def build_osprey_dataset(dataset_config,
69 | tokenizer=None,
70 | data_args=None,
71 | **kwargs):
72 | if isinstance(dataset_config, list):
73 | datasets = []
74 | for cfg in dataset_config:
75 | temp_dataset = build_osprey_dataset(cfg, tokenizer=tokenizer, data_args=data_args, **kwargs)
76 | datasets.append(temp_dataset)
77 |
78 | for dataset in datasets:
79 | print(type(dataset), f'len = {len(dataset)}')
80 |
81 | return ConcatDataset(datasets)
82 |
83 | dataset_type = dataset_config.pop('type')
84 |
85 | if dataset_type == 'coco_data':
86 | dataset = COCODataset(
87 | **dataset_config,
88 | tokenizer=tokenizer,
89 | data_args=data_args,
90 | **kwargs,
91 | )
92 |
93 | elif dataset_type == 'vcr':
94 | dataset = VCRDataset(
95 | **dataset_config,
96 | tokenizer=tokenizer,
97 | data_args=data_args,
98 | **kwargs,
99 | )
100 | elif dataset_type == 'VGDATA':
101 | dataset = VGDATA(
102 | **dataset_config,
103 | tokenizer=tokenizer,
104 | data_args=data_args,
105 | **kwargs,
106 | )
107 | elif dataset_type == 'RefCOCO':
108 | dataset = RefCOCO(
109 | **dataset_config,
110 | tokenizer=tokenizer,
111 | data_args=data_args,
112 | **kwargs,
113 | )
114 | elif dataset_type == 'RefCOCOP':
115 | dataset = RefCOCOP(
116 | **dataset_config,
117 | tokenizer=tokenizer,
118 | data_args=data_args,
119 | **kwargs,
120 | )
121 | elif dataset_type == 'PascalPart':
122 | dataset = PascalPart(
123 | **dataset_config,
124 | tokenizer=tokenizer,
125 | data_args=data_args,
126 | **kwargs,
127 | )
128 | elif dataset_type == 'PartImagenet':
129 | dataset = PartImagenet(
130 | **dataset_config,
131 | tokenizer=tokenizer,
132 | data_args=data_args,
133 | **kwargs,
134 | )
135 | elif dataset_type == 'OspreyDetailedDescription':
136 | dataset = OspreyDetailedDescription(
137 | **dataset_config,
138 | tokenizer=tokenizer,
139 | data_args=data_args,
140 | **kwargs,
141 | )
142 | elif dataset_type == 'OspreyConversations':
143 | dataset = OspreyConversations(
144 | **dataset_config,
145 | tokenizer=tokenizer,
146 | data_args=data_args,
147 | **kwargs,
148 | )
149 | elif dataset_type == 'OspreyShortForm':
150 | dataset = OspreyShortForm(
151 | **dataset_config,
152 | tokenizer=tokenizer,
153 | data_args=data_args,
154 | **kwargs,
155 | )
156 | elif dataset_type == 'OspreyPartLevel':
157 | dataset = OspreyPartLevel(
158 | **dataset_config,
159 | tokenizer=tokenizer,
160 | data_args=data_args,
161 | **kwargs,
162 | )
163 | elif dataset_type == 'OspreyLVISPosNeg':
164 | dataset = OspreyLVISPosNeg(
165 | **dataset_config,
166 | tokenizer=tokenizer,
167 | data_args=data_args,
168 | **kwargs,
169 | )
170 |
171 | else:
172 | raise NotImplementedError
173 |
174 | return dataset
175 |
176 |
177 |
178 | class ConcatDataset(ConcatDataset):  # intentionally shadows torch.utils.data.ConcatDataset imported above
179 | def __init__(self, datasets):
180 | super().__init__(datasets)
181 |
182 | def collater(self, samples):
183 |
184 | all_keys = set()
185 | for s in samples:
186 | all_keys.update(s)
187 |
188 | shared_keys = all_keys
189 | for s in samples:
190 | shared_keys = shared_keys & set(s.keys())
191 |
192 | samples_shared_keys = []
193 | for s in samples:
194 | samples_shared_keys.append({k: s[k] for k in s.keys() if k in shared_keys})
195 |
196 | return self.datasets[0].collater(samples_shared_keys)
197 |
198 |
--------------------------------------------------------------------------------
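
A minimal, self-contained sketch (toy tensors only; a pad id of 0 and an ignore index of -100 are assumed for illustration) of the padding behaviour that `DataCollatorForDetDataset` above relies on: variable-length `input_ids` are right-padded with the tokenizer's pad id, labels are padded with `IGNORE_INDEX`, and the attention mask is recovered by comparing against the pad id.

```python
import torch
from torch.nn.utils.rnn import pad_sequence

PAD_TOKEN_ID = 0     # assumed pad id, for illustration only
IGNORE_INDEX = -100  # assumed ignore index for the LM loss (the usual value)

# two samples of different lengths, as produced by the datasets above
input_ids = [torch.tensor([5, 6, 7]), torch.tensor([8, 9])]
labels = [torch.tensor([5, 6, 7]), torch.tensor([8, 9])]

batch_ids = pad_sequence(input_ids, batch_first=True, padding_value=PAD_TOKEN_ID)
batch_labels = pad_sequence(labels, batch_first=True, padding_value=IGNORE_INDEX)
attention_mask = batch_ids.ne(PAD_TOKEN_ID)  # False at padded positions

print(batch_ids)       # tensor([[5, 6, 7], [8, 9, 0]])
print(batch_labels)    # tensor([[5, 6, 7], [8, 9, -100]])
print(attention_mask)  # padded positions are False
```
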
/osprey/datasets/vcr.py:
--------------------------------------------------------------------------------
1 | """
2 | This code is largely based on https://github.com/jshilong/GPT4RoI/blob/main/gpt4roi/datasets/vcr.py
3 | """
4 | import copy
5 | import json
6 | import os
7 | import random
8 |
9 |
10 | import numpy as np
11 | import torch
12 | from PIL import Image
13 | from torch.utils.data import Dataset
14 | from matplotlib import path
15 | from matplotlib import pyplot as plt
16 | from osprey.train.train import preprocess, preprocess_multimodal
17 |
18 | WHY_QUESTIONS = [
19 | 'why?',
20 | 'why',
21 | "What's the rationale for your decision?",
22 | 'What led you to that conclusion?',
23 | "What's the reasoning behind your opinion?",
24 | 'Why do you believe that to be true?',
25 | 'Can you explain the basis for your thinking?',
26 | 'What factors influenced your perspective?',
27 | 'How did you arrive at that perspective?',
28 | 'What evidence supports your viewpoint?',
29 | 'What makes you think that way?',
30 | "What's the logic behind your argument?",
31 | 'Can you provide some context for your opinion?',
32 | "What's the basis for your assertion?",
33 | 'Why do you hold that belief?',
34 | 'What experiences have shaped your perspective?',
35 | 'What assumptions underlie your reasoning?',
36 | "What's the foundation of your assertion?",
37 | "What's the source of your reasoning?",
38 | "What's the motivation behind your decision?",
39 | "What's the impetus for your belief?",
40 | "What's the driving force behind your conclusion?",
41 | 'Why do you think that?',
42 | "What's your reasoning?",
43 | 'What makes you say that?',
44 | 'Why do you feel that way?',
45 | "What's the story behind that?",
46 | "What's your thought process?",
47 | "What's the deal with that?",
48 | "What's the logic behind it?",
49 | 'Why do you believe that?',
50 | "What's the real deal here?",
51 | "What's the reason behind it?",
52 | "What's the thought process behind your decision?",
53 | "What's the rationale for your opinion?",
54 | 'Why do you have that impression?',
55 | "What's the background to that?",
56 | "What's the evidence that supports your view?",
57 | "What's the explanation for that?"
58 | ]
59 |
60 | Ref_WAY = [
61 | 'There are in the image,',
62 | 'There are some regions ,',
63 | 'Given ,',
64 | 'Given in the image,',
65 | ',',
66 | 'Several regions are in the image,',
67 | ' in the given image,'
68 | ]
69 |
70 | def _spaced_points(low, high,n):
71 | """ We want n points between low and high, but we don't want them to touch either side"""
72 | padding = (high-low)/(n*2)
73 | return np.linspace(low + padding, high-padding, num=n)
74 |
75 | def make_mask(height, width, box, polygons_list):
76 |     """
77 |     height, width: size of the output mask
78 |     box: [x1, y1, x2, y2, conf.]
79 |     polygons_list: list of polygons that go inside the box
80 |     """
81 | mask = np.zeros((height, width), dtype=np.bool_)
82 |
83 | xy = np.meshgrid(_spaced_points(box[0], box[2], n=width),
84 | _spaced_points(box[1], box[3], n=height))
85 | xy_flat = np.stack(xy, 2).reshape((-1, 2))
86 |
87 | for polygon in polygons_list:
88 | polygon_path = path.Path(polygon)
89 | mask |= polygon_path.contains_points(xy_flat).reshape((height, width))
90 | return mask.astype(np.float32)
91 |
92 | class VCRDataset(Dataset):
93 | CLASSES = ('object',)
94 |
95 | def __init__(self,
96 | tokenizer,
97 | data_args=None,
98 | ann_file=None,
99 | img_prefix=None,
100 |
101 | ):
102 | super(VCRDataset, self).__init__()
103 |
104 |
105 | self.img_prefix = img_prefix
106 |
107 | self.tokenizer = tokenizer
108 |
109 | self.data_args = data_args
110 |
111 | self.begin_str = """.\nThis provides an overview of the picture.\n"""
112 | self.data_infos = self.load_annotations(ann_file)
113 | print('normal_vcr', len(self.data_infos))
114 |
115 | def load_annotations(self, ann_file):
116 |
117 | with open(ann_file, 'r') as f:
118 | ann_list = [json.loads(line) for line in f]
119 | data_infos = []
120 |
121 | import re
122 |
123 | def replace_numbers_with_tags(s, class_names):
124 | pattern = r'\b(\d+)\b'
125 | try:
126 | result = re.sub(pattern, lambda match: f'{class_names[int(match.group(1))]} at region{match.group(1)}', s)
127 | except:
128 | # contain number not for instance
129 | return None
130 | return result
131 |
132 |
133 | for ann in ann_list:
134 |
135 | metadata_fn_path = ann['metadata_fn']
136 | img_fn = ann['img_fn']
137 | img_path = os.path.join(self.img_prefix,img_fn)
138 | annotations = json.load(open(os.path.join(self.img_prefix, metadata_fn_path)))
139 | masks = annotations['segms']
140 | bboxes = np.array(annotations['boxes'])
141 |
142 | class_names = ann['objects']
143 | num_objects = len(class_names)
144 | ref_string = ''
145 | for i in range(num_objects):
146 | ref_string = ref_string + f'region{i+1} ' + ','
147 | ref_string = ref_string[:-1]
148 | ref_prefix = random.choice(Ref_WAY)
149 |
150 | begion_string = ref_prefix.replace('', ref_string)
151 | qa_s = []
152 |
153 | q = ann['question_orig']
154 | q = replace_numbers_with_tags(q, class_names)
155 | a = ann['answer_orig']
156 | a = replace_numbers_with_tags(a, class_names)
157 | why = ann['rationale_orig']
158 | why = replace_numbers_with_tags(why, class_names)
159 |             if (q is None) or (a is None) or (why is None):
160 | continue
161 |
162 |
163 | qa_s.append({'from': 'human', 'value': begion_string + q})
164 | qa_s.append({'from': 'gpt', 'value': a})
165 | qa_s.append({'from': 'human', 'value': random.choice(WHY_QUESTIONS)})
166 | qa_s.append({'from': 'gpt', 'value': why})
167 |
168 | data_infos.append(dict(
169 | img_path = img_path,
170 | bboxes = bboxes,
171 | masks = masks,
172 | labels= class_names,
173 | qas = qa_s)
174 | )
175 |
176 |
177 | return data_infos
178 |
179 | def __len__(self):
180 | return len(self.data_infos)
181 |
182 | def __getitem__(self, i):
183 | data_info = self.data_infos[i]
184 |
185 | img_path = data_info['img_path']
186 | masks = data_info['masks']
187 | bboxes = data_info['bboxes']
188 |
189 | qas = data_info['qas']
190 | processor = self.data_args.image_processor
191 | image = Image.open(img_path).convert('RGB')
192 | w, h = image.size
193 | # TODO ablation this
194 |
195 | image_file = img_path
196 |
197 | pred_masks = np.zeros((len(masks), h, w))
198 | for i,mask in enumerate(masks):
199 |
200 | int_box = [round(box) for box in bboxes[i][:-1]]
201 |
202 | height_ = int(int_box[3]-int_box[1])
203 | width_ = int(int_box[2]-int_box[0])
204 | box_mask = make_mask(height_, width_, bboxes[i], mask)
205 |
206 | pred_masks[i, int_box[1]:int_box[3], int_box[0]:int_box[2]] = box_mask
207 |
208 | image = processor.preprocess(image,
209 | do_center_crop=False,
210 | return_tensors='pt')['pixel_values'][0]
211 |
212 | image = torch.nn.functional.interpolate(image.unsqueeze(0),
213 | size=(512, 512),
214 | mode='bilinear',
215 | align_corners=False).squeeze(0)
216 |
217 | cur_token_len = (image.shape[1] // 16) * (image.shape[2] // 16) # FIXME: 16 is hardcoded patch size
218 | qas = copy.deepcopy(qas)
219 | qas[0]['value'] = self.begin_str + qas[0]['value']
220 |
221 | sources = preprocess_multimodal(
222 | copy.deepcopy([qas]),
223 | self.data_args, cur_token_len)
224 |
225 | data_dict = preprocess(
226 | sources,
227 | self.tokenizer,
228 | has_image=True)
229 | if isinstance(i, int):
230 | data_dict = dict(input_ids=data_dict['input_ids'][0],
231 | labels=data_dict['labels'][0])
232 |
233 | data_dict['image'] = image
234 | data_dict['masks'] = torch.Tensor(pred_masks)
235 |
236 | return data_dict
237 |
--------------------------------------------------------------------------------
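
A small, self-contained illustration (toy box and polygon, not taken from VCR) of the rasterisation idea behind `make_mask` above: a grid of sample points is laid out inside the box, and `matplotlib.path.Path.contains_points` marks the points that fall inside each polygon.

```python
import numpy as np
from matplotlib import path

# toy box [x1, y1, x2, y2, conf.] and a triangular polygon inside it
box = [0.0, 0.0, 10.0, 10.0, 1.0]
polygon = [[1, 1], [9, 1], [5, 9]]
height, width = 16, 16  # resolution of the box-local mask

xs = np.linspace(box[0], box[2], num=width)
ys = np.linspace(box[1], box[3], num=height)
xy_flat = np.stack(np.meshgrid(xs, ys), 2).reshape(-1, 2)

mask = path.Path(polygon).contains_points(xy_flat).reshape(height, width)
print(mask.sum(), "of", height * width, "grid points fall inside the polygon")
```
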
/osprey/datasets/vg.py:
--------------------------------------------------------------------------------
1 | import copy
2 | import random
3 | import os
4 | import numpy as np
5 | import torch
6 | from .stage2_data import CustomDataset
7 | from osprey.train.train import preprocess, preprocess_multimodal
8 |
9 | LIMIT = " Answer the question using a short phrase."
10 | QUESTIONS = [
11 | 'Give me a short description of .',
12 | 'Can you give me a short description of ?',
13 | 'Can you provide me with a short description of the region in the picture marked by ?',
14 | "I'm curious about the region represented by in the picture. Could you describe it in few words?",
15 | 'What can you tell me about the region indicated by in the image in few words?',
16 | "I'd like to know more about the area in the photo labeled . Can you give me a concise description?",
17 | 'Could you describe the region shown as in the picture concisely?',
18 | 'What can you give me about the region outlined by in the photo?',
19 | 'Please provide me with a brief description of the region marked with in the image.',
20 | 'Can you give me a brief introduction of the region labeled as in the picture?',
21 | "I'm interested in knowing the region represented by in the photo. Can you describe it in several words?",
22 | 'What is the region outlined by in the picture like? Could you give me a streamlined description?',
23 | 'Can you provide me with a brief description of the region in the picture marked by , please?',
24 | "I'm curious about the region represented by in the picture. Could you describe it in few words, please?",
25 | 'What can you tell me about the region indicated by in the image?',
26 | "I'd like to know more about the area in the photo labeled , please. Can you give me a simple description?",
27 | 'Could you describe the region shown as in the picture in several words?',
28 | 'Please provide me with a simple description of the region marked with in the image, please.',
29 | "I'm interested in learning more about the region represented by in the photo. Can you describe it in few words, please?",
30 | 'What is the region outlined by in the picture like, please? Could you give me a simple and clear description?',
31 | 'Please describe the region in the image concisely.',
32 | 'Can you offer a simple analysis of the region in the image?',
33 | 'Could tell me something about the region highlighted by in the picture briefly?',
34 | 'Can you share a simple rundown of the region denoted by in the presented image?'
35 | ]
36 |
37 | class VGDATA(CustomDataset):
38 | def __init__(self,
39 | tokenizer,
40 | data_args=None,
41 | ann_file=None,
42 | img_prefix=None,
43 | max_gt_per_img=3,
44 | ):
45 |
46 | self.data_args = data_args
47 | self.tokenizer = tokenizer
48 | self.ann_file = ann_file
49 | self.img_prefix = img_prefix
50 | self.max_gt_per_img = max_gt_per_img
51 |
52 | super().__init__(tokenizer, data_args, ann_file, img_prefix, max_gt_per_img)
53 |
54 | self.begin_str = """\nThis provides an overview of the picture.\n"""
55 |
56 |
57 | def get_data_item(self, idx):
58 | data_info = self.data_infos[idx]
59 | ann_info = self.get_ann_info(idx)
60 |
61 | img_path = os.path.join(self.img_prefix, data_info['filename'])
62 | image = self.read_process_image(img_path)
63 |
64 | gt_labels = []
65 | gt_masks_ann = []
66 |
67 | for i, ann in enumerate(ann_info):
68 | if ann.get('ignore', False):
69 | continue
70 | mask = self.annToMask(ann['segmentation'], data_info['height'], data_info['width'])
71 |
72 | gt_labels.append(ann['caption'])
73 | gt_masks_ann.append(mask)
74 |
75 |
76 | data_item = dict(
77 | img = image,
78 | gt_labels=gt_labels,
79 | gt_masks=gt_masks_ann
80 | )
81 | return data_item
82 |
83 |
84 | def process_text(self, data_item):
85 | image = data_item['img']
86 | ori_labels = data_item['gt_labels']
87 | ori_masks = np.array(data_item['gt_masks'])
88 | ori_masks = torch.from_numpy(ori_masks)
89 |
90 | shuffle_ids = torch.randperm(len(ori_labels))
91 | if len(shuffle_ids) > self.max_gt_per_img:
92 | shuffle_ids = shuffle_ids[:self.max_gt_per_img]
93 | ori_masks = ori_masks[shuffle_ids]
94 | ori_labels = [ori_labels[i] for i in shuffle_ids]
95 |
96 | sources = dict()
97 |
98 | sources['conversations'] = []
99 |
100 | for i in range(len(ori_labels)):
101 | question = random.choice(QUESTIONS).strip()
102 | question = question.replace('', '')
103 | if i == 0:
104 | question = self.begin_str + question
105 | question += LIMIT
106 | answer = ori_labels[i]
107 | sources['conversations'].append(
108 | {'from': 'human', 'value': question})
109 | sources['conversations'].append({'from': 'gpt', 'value': answer})
110 |
111 | cur_token_len = (image.shape[1] // 16) * (image.shape[2] // 16)
112 |
113 | sources = preprocess_multimodal(
114 | copy.deepcopy([sources['conversations']]),
115 | self.data_args,
116 | cur_token_len)
117 | # print(sources)
118 |
119 | data_dict = preprocess(
120 | sources,
121 | self.tokenizer,
122 | has_image=True
123 | )
124 |
125 | # get single
126 | if isinstance(i, int):
127 | data_dict = dict(input_ids=data_dict['input_ids'][0],
128 | labels=data_dict['labels'][0])
129 |
130 | data_dict['image'] = image
131 | data_dict['masks'] = ori_masks
132 | return data_dict
133 |
--------------------------------------------------------------------------------
/osprey/eval/README.md:
--------------------------------------------------------------------------------
1 | # Evaluation for Osprey 🔎
2 |
3 | This document provides instructions for evaluating Osprey on four representative tasks: open-vocabulary segmentation, referring object classification, detailed region description, and region-level captioning.
4 |
5 | We have developed two types of models: the first is [Osprey](https://huggingface.co/sunshine-lwt/Osprey-7b/tree/main), and the second is [Osprey-Chat](https://huggingface.co/sunshine-lwt/Osprey-Chat-7b/tree/main) (denoted as `Osprey*` in our paper). Osprey-Chat exhibits better conversation and image-level understanding & reasoning capabilities, as it is trained with additional LLaVA data (llava_v1_5_mix665k.json).
6 |
7 | ## 1. Open-Vocabulary Segmentation
8 | - Download [SentenceBERT model](https://huggingface.co/sentence-transformers/all-MiniLM-L6-v2), which is used for calculating the semantic similarity.
9 | - The evaluation is based on `detectron2`; please install the following dependencies.
10 | ```
11 | git clone https://github.com/facebookresearch/detectron2.git
12 | python -m pip install -e detectron2
13 | pip install git+https://github.com/cocodataset/panopticapi.git
14 | pip install git+https://github.com/mcordts/cityscapesScripts.git
15 | ```
16 | - Prepare datasets, please refer to [Data preparation](./datasets/README.md).
17 |
18 | ### Cityscapes
19 | ```
20 | cd osprey/eval
21 | python eval_open_vocab_seg_detectron2.py --dataset cityscapes --model path/to/osprey-7b --bert path/to/all-MiniLM-L6-v2
22 | ```
23 | ### ADE20K
24 | ```
25 | cd osprey/eval
26 | python eval_open_vocab_seg_detectron2.py --dataset ade --model path/to/osprey-7b --bert path/to/all-MiniLM-L6-v2
27 | ```
28 |
29 |
30 |
31 |

32 |
33 |
34 |
35 | ## 2. Referring Object Classification
36 |
37 | ### LVIS
38 | - Download our generated [lvis_val_1k_category.json](https://huggingface.co/datasets/sunshine-lwt/Osprey-ValData/resolve/main/lvis_val_1k_category.json?download=true) (We randomly sample 1K images with 4,004 objects from the LVIS dataset.)
39 | ```
40 | cd osprey/eval
41 | python lvis_paco_eval.py --model path/to/osprey-7b --bert path/to/all-MiniLM-L6-v2 --img path/to/coco-all-imgs --json lvis_val_1k_category.json
42 | ```
43 | ### PACO
44 | - Download our generated [paco_val_1k_category.json](https://huggingface.co/datasets/sunshine-lwt/Osprey-ValData/resolve/main/paco_val_1k_category.json?download=true) (We randomly sample 1K images with 4,263 objects from the PACO dataset.)
45 | ```
46 | cd osprey/eval
47 | python lvis_paco_eval.py --model path/to/osprey-7b --bert path/to/all-MiniLM-L6-v2 --img path/to/coco-all-imgs --json paco_val_1k_category.json
48 | ```
49 |
50 |
51 |

52 |
53 |
54 | ## 3. Detailed Region Description
55 | - Fill in your GPT interface in `eval_gpt.py`.
56 | - Change the paths in `gpt_eval.sh`.
57 | ```
58 | cd osprey/eval
59 | sh gpt_eval.sh
60 | ```
61 |
62 |
63 |

64 |
65 |
66 | ## 4. Ferret-Bench
67 |
68 | Note that we have converted the boxes in `box_refer_caption.json` and `box_refer_reason.json` to the polygon format stored under the `segmentation` field.
69 | ### Referring Description
70 |
71 | ```
72 | cd osprey/eval
73 | python ferret_bench_eval.py --model_name path/to/osprey-chat-7b --root_path path/to/coco_imgs --json_path ./ferret_bench/box_refer_caption.json
74 | ```
75 |
76 | ### Referring Reasoning
77 |
78 | ```
79 | cd osprey/eval
80 | python ferret_bench_eval.py --model_name path/to/osprey-chat-7b --root_path path/to/coco_imgs --json_path ./ferret_bench/box_refer_reason.json
81 | ```
82 |
83 | Then use GPT-4 to evaluate the result as in [Ferret](https://github.com/apple/ml-ferret).
84 |
85 |
86 | 
87 |
88 |
89 | ## 5. POPE
90 |
91 | - Download the COCO-based POPE files (e.g., `coco_pope_random.json`) and put them under `osprey/eval/pope`.
92 | - Change the paths in `pope_eval.sh`.
93 |
94 | ```
95 | cd osprey/eval
96 | sh pope_eval.sh
97 | ```
98 |
99 |
100 | 
101 |
102 |
103 | ## 6. Region Level Captioning
104 |
105 | - We fine-tune Osprey-7B on the training set of RefCOCOg. The fine-tuned model can be found at [Osprey-7B-refcocog-finetune](https://huggingface.co/sunshine-lwt/Osprey-7b-Refercocog-finetuning/tree/main).
106 | - Download [finetune_refcocog_val_with_mask.json](https://huggingface.co/datasets/sunshine-lwt/Osprey-ValData/resolve/main/finetune_refcocog_val_with_mask.json?download=true).
107 | - Generate output json files:
108 |
109 | ```
110 | cd osprey/eval
111 | python refcocog_eval.py --model path/to/Osprey-7B-refcocog-fintune --img path/to/coco-all-imgs --json finetune_refcocog_val_with_mask.json
112 | ```
113 | - Finally, evaluate the output json file using `CaptionMetrics`.
114 |
115 |
116 | 
117 |
118 |
--------------------------------------------------------------------------------
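
The referring object classification numbers above are computed with two measures: SentenceBERT cosine similarity between the predicted and ground-truth category names, and a word-level semantic IoU (see `SemanticIOU` in `lvis_paco_eval.py` further below). A short sketch of both metrics, assuming `all-MiniLM-L6-v2` is available locally or can be downloaded:

```python
from sentence_transformers import SentenceTransformer, util

def semantic_iou(pred: str, gt: str) -> float:
    # word-level intersection-over-union, mirroring SemanticIOU in lvis_paco_eval.py
    p, g = set(pred.lower().split()), set(gt.lower().split())
    return len(p & g) / len(p | g)

bert = SentenceTransformer('all-MiniLM-L6-v2')  # or a local path to the downloaded model
pred, gt = 'brown teddy bear', 'teddy bear'
sim = util.cos_sim(bert.encode(pred, convert_to_tensor=True),
                   bert.encode(gt, convert_to_tensor=True))[0][0].item()
print(f'cosine similarity: {sim:.3f}, semantic IoU: {semantic_iou(pred, gt):.3f}')
```
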
/osprey/eval/datasets/README.md:
--------------------------------------------------------------------------------
1 | # Data preparation for Open-Vocabulary Segmentation☕️
2 | Dataset preparation follows [Detectron2](https://github.com/facebookresearch/detectron2/blob/main/datasets/README.md) and [Mask2Former](https://github.com/facebookresearch/Mask2Former/blob/main/datasets/README.md).
3 |
4 | The datasets are assumed to exist in a directory specified by the environment variable `DETECTRON2_DATASETS`. Under this directory, detectron2 will look for datasets in the structure described below, if needed.
5 | ```
6 | $DETECTRON2_DATASETS/
7 | ADEChallengeData2016/
8 | cityscapes/
9 | ```
10 | You can set the location of the builtin datasets by exporting `DETECTRON2_DATASETS=/path/to/datasets`. The default is `./datasets` under the eval directory.
11 |
12 | ## ADE20k (A-150)
13 | Dataset structure:
14 | ```
15 | ADEChallengeData2016/
16 | images/
17 | annotations/
18 | objectInfo150.txt
19 | # download instance annotation
20 | annotations_instance/
21 | # generated by prepare_ade20k_sem_seg.py
22 | annotations_detectron2/
23 | # below are generated by prepare_ade20k_pan_seg.py
24 | ade20k_panoptic_{train,val}.json
25 | ade20k_panoptic_{train,val}/
26 | # below are generated by prepare_ade20k_ins_seg.py
27 | ade20k_instance_{train,val}.json
28 | ```
29 | The directory `annotations_detectron2` is generated by running `python datasets/prepare_ade20k_sem_seg.py`.
30 |
31 | Download the instance annotation from http://sceneparsing.csail.mit.edu/:
32 | ```
33 | wget http://sceneparsing.csail.mit.edu/data/ChallengeData2017/annotations_instance.tar
34 | ```
35 | Then, run `python datasets/prepare_ade20k_pan_seg.py` to combine the semantic and instance annotations into panoptic annotations.
36 | 
37 | Finally, run `python datasets/prepare_ade20k_ins_seg.py` to extract instance annotations in COCO format.
38 |
39 | ## Cityscapes
40 | Data structure:
41 | ```
42 | cityscapes/
43 | gtFine/
44 | train/
45 | aachen/
46 | color.png, instanceIds.png, labelIds.png, polygons.json,
47 | labelTrainIds.png
48 | ...
49 | val/
50 | test/
51 | # below are generated Cityscapes panoptic annotation
52 | cityscapes_panoptic_train.json
53 | cityscapes_panoptic_train/
54 | cityscapes_panoptic_val.json
55 | cityscapes_panoptic_val/
56 | cityscapes_panoptic_test.json
57 | cityscapes_panoptic_test/
58 | leftImg8bit/
59 | train/
60 | val/
61 | test/
62 | ```
63 | Install cityscapes scripts by:
64 | ```
65 | pip install git+https://github.com/mcordts/cityscapesScripts.git
66 | ```
67 |
68 | Note: to create `labelTrainIds.png`, first prepare the above structure, then run the cityscapes scripts with:
69 | ```
70 | CITYSCAPES_DATASET=/path/to/abovementioned/cityscapes python cityscapesscripts/preparation/createTrainIdLabelImgs.py
71 | ```
72 |
73 | Note: to generate the Cityscapes panoptic dataset, run the cityscapes scripts with:
74 |
75 | ```
76 | CITYSCAPES_DATASET=/path/to/abovementioned/cityscapes python cityscapesscripts/preparation/createPanopticImgs.py
77 | ```
78 |
--------------------------------------------------------------------------------
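
The preparation scripts that follow resolve the dataset root from the `DETECTRON2_DATASETS` environment variable and fall back to `./datasets`; a minimal sketch of that lookup:

```python
import os
from pathlib import Path

# same lookup used by the prepare_ade20k_* scripts below
dataset_dir = Path(os.getenv("DETECTRON2_DATASETS", "datasets"))
print("expecting ADE20K data under:", dataset_dir / "ADEChallengeData2016")
```
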
/osprey/eval/datasets/ade20k_instance_catid_mapping.txt:
--------------------------------------------------------------------------------
1 | Instance100 SceneParse150 FullADE20K
2 | 1 8 165
3 | 2 9 3055
4 | 3 11 350
5 | 4 13 1831
6 | 5 15 774
7 | 5 15 783
8 | 6 16 2684
9 | 7 19 687
10 | 8 20 471
11 | 9 21 401
12 | 10 23 1735
13 | 11 24 2473
14 | 12 25 2329
15 | 13 28 1564
16 | 14 31 57
17 | 15 32 2272
18 | 16 33 907
19 | 17 34 724
20 | 18 36 2985
21 | 18 36 533
22 | 19 37 1395
23 | 20 38 155
24 | 21 39 2053
25 | 22 40 689
26 | 23 42 266
27 | 24 43 581
28 | 25 44 2380
29 | 26 45 491
30 | 27 46 627
31 | 28 48 2388
32 | 29 50 943
33 | 30 51 2096
34 | 31 54 2530
35 | 32 56 420
36 | 33 57 1948
37 | 34 58 1869
38 | 35 59 2251
39 | 36 63 239
40 | 37 65 571
41 | 38 66 2793
42 | 39 67 978
43 | 40 68 236
44 | 41 70 181
45 | 42 71 629
46 | 43 72 2598
47 | 44 73 1744
48 | 45 74 1374
49 | 46 75 591
50 | 47 76 2679
51 | 48 77 223
52 | 49 79 47
53 | 50 81 327
54 | 51 82 2821
55 | 52 83 1451
56 | 53 84 2880
57 | 54 86 480
58 | 55 87 77
59 | 56 88 2616
60 | 57 89 246
61 | 57 89 247
62 | 58 90 2733
63 | 59 91 14
64 | 60 93 38
65 | 61 94 1936
66 | 62 96 120
67 | 63 98 1702
68 | 64 99 249
69 | 65 103 2928
70 | 66 104 2337
71 | 67 105 1023
72 | 68 108 2989
73 | 69 109 1930
74 | 70 111 2586
75 | 71 112 131
76 | 72 113 146
77 | 73 116 95
78 | 74 117 1563
79 | 75 119 1708
80 | 76 120 103
81 | 77 121 1002
82 | 78 122 2569
83 | 79 124 2833
84 | 80 125 1551
85 | 81 126 1981
86 | 82 127 29
87 | 83 128 187
88 | 84 130 747
89 | 85 131 2254
90 | 86 133 2262
91 | 87 134 1260
92 | 88 135 2243
93 | 89 136 2932
94 | 90 137 2836
95 | 91 138 2850
96 | 92 139 64
97 | 93 140 894
98 | 94 143 1919
99 | 95 144 1583
100 | 96 145 318
101 | 97 147 2046
102 | 98 148 1098
103 | 99 149 530
104 | 100 150 954
--------------------------------------------------------------------------------
/osprey/eval/datasets/prepare_ade20k_ins_seg.py:
--------------------------------------------------------------------------------
1 | """
2 | Reference: https://github.com/facebookresearch/Mask2Former/blob/main/datasets/prepare_ade20k_ins_seg.py
3 | """
4 |
5 | import glob
6 | import json
7 | import os
8 | from collections import Counter
9 |
10 | import numpy as np
11 | import tqdm
12 | from panopticapi.utils import IdGenerator, save_json
13 | from PIL import Image
14 | import pycocotools.mask as mask_util
15 |
16 |
17 | if __name__ == "__main__":
18 | dataset_dir = os.getenv("DETECTRON2_DATASETS", "datasets")
19 |
20 | for name, dirname in [("train", "training"), ("val", "validation")]:
21 | image_dir = os.path.join(dataset_dir, f"ADEChallengeData2016/images/{dirname}/")
22 | instance_dir = os.path.join(
23 | dataset_dir, f"ADEChallengeData2016/annotations_instance/{dirname}/"
24 | )
25 |
26 | # img_id = 0
27 | ann_id = 1
28 |
29 | # json
30 | out_file = os.path.join(dataset_dir, f"ADEChallengeData2016/ade20k_instance_{name}.json")
31 |
32 | # json config
33 | instance_config_file = "datasets/ade20k_instance_imgCatIds.json"
34 | with open(instance_config_file) as f:
35 | category_dict = json.load(f)["categories"]
36 |
37 | # load catid mapping
38 | # it is important to share category id for both instance and panoptic annotations
39 | mapping_file = "datasets/ade20k_instance_catid_mapping.txt"
40 | with open(mapping_file) as f:
41 | map_id = {}
42 | for i, line in enumerate(f.readlines()):
43 | if i == 0:
44 | continue
45 | ins_id, sem_id, _ = line.strip().split()
46 | # shift id by 1 because we want it to start from 0!
47 | # ignore_label becomes 255
48 | map_id[int(ins_id)] = int(sem_id) - 1
49 |
50 | for cat in category_dict:
51 | cat["id"] = map_id[cat["id"]]
52 |
53 | filenames = sorted(glob.glob(os.path.join(image_dir, "*.jpg")))
54 |
55 | ann_dict = {}
56 | images = []
57 | annotations = []
58 |
59 | for idx, filename in enumerate(tqdm.tqdm(filenames)):
60 | image = {}
61 | image_id = os.path.basename(filename).split(".")[0]
62 |
63 | image["id"] = image_id
64 | image["file_name"] = os.path.basename(filename)
65 |
66 | original_format = np.array(Image.open(filename))
67 | image["width"] = original_format.shape[1]
68 | image["height"] = original_format.shape[0]
69 |
70 | images.append(image)
71 |
72 | filename_instance = os.path.join(instance_dir, image_id + ".png")
73 | ins_seg = np.asarray(Image.open(filename_instance))
74 | assert ins_seg.dtype == np.uint8
75 |
76 | instance_cat_ids = ins_seg[..., 0]
77 | # instance id starts from 1!
78 | # because 0 is reserved as VOID label
79 | instance_ins_ids = ins_seg[..., 1]
80 |
81 | # process things
82 | for thing_id in np.unique(instance_ins_ids):
83 | if thing_id == 0:
84 | continue
85 | mask = instance_ins_ids == thing_id
86 | instance_cat_id = np.unique(instance_cat_ids[mask])
87 | assert len(instance_cat_id) == 1
88 |
89 | anno = {}
90 | anno['id'] = ann_id
91 | ann_id += 1
92 | anno['image_id'] = image['id']
93 | anno["iscrowd"] = int(0)
94 | anno["category_id"] = int(map_id[instance_cat_id[0]])
95 |
96 | inds = np.nonzero(mask)
97 | ymin, ymax = inds[0].min(), inds[0].max()
98 | xmin, xmax = inds[1].min(), inds[1].max()
99 | anno["bbox"] = [int(xmin), int(ymin), int(xmax - xmin + 1), int(ymax - ymin + 1)]
100 | # if xmax <= xmin or ymax <= ymin:
101 | # continue
102 | rle = mask_util.encode(np.array(mask[:, :, None], order="F", dtype="uint8"))[0]
103 | rle["counts"] = rle["counts"].decode("utf-8")
104 | anno["segmentation"] = rle
105 | anno["area"] = int(mask_util.area(rle))
106 | annotations.append(anno)
107 |
108 | # save this
109 | ann_dict['images'] = images
110 | ann_dict['categories'] = category_dict
111 | ann_dict['annotations'] = annotations
112 |
113 | save_json(ann_dict, out_file)
--------------------------------------------------------------------------------
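
A quick, self-contained look (toy 4x4 mask, not real data) at the COCO RLE format the script above writes into the `segmentation` field, using the same `pycocotools` calls:

```python
import numpy as np
import pycocotools.mask as mask_util

# toy binary mask with a 2x2 foreground square
mask = np.zeros((4, 4), dtype=np.uint8)
mask[1:3, 1:3] = 1

# encode expects a Fortran-ordered (H, W, N) uint8 array, as in the script above
rle = mask_util.encode(np.array(mask[:, :, None], order="F", dtype="uint8"))[0]
print("area:", int(mask_util.area(rle)))       # -> 4
rle["counts"] = rle["counts"].decode("utf-8")  # bytes -> str so it is JSON-serialisable
print(rle)                                     # {'size': [4, 4], 'counts': '...'}
```
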
/osprey/eval/datasets/prepare_ade20k_sem_seg.py:
--------------------------------------------------------------------------------
1 | """
2 | Reference: https://github.com/facebookresearch/Mask2Former/blob/main/datasets/prepare_ade20k_sem_seg.py
3 | """
4 |
5 | import os
6 | from pathlib import Path
7 |
8 | import numpy as np
9 | import tqdm
10 | from PIL import Image
11 |
12 |
13 | def convert(input, output):
14 | img = np.asarray(Image.open(input))
15 | assert img.dtype == np.uint8
16 | img = img - 1 # 0 (ignore) becomes 255. others are shifted by 1
17 | Image.fromarray(img).save(output)
18 |
19 |
20 | if __name__ == "__main__":
21 | dataset_dir = (
22 | Path(os.getenv("DETECTRON2_DATASETS", "datasets")) / "ade" / "ADEChallengeData2016"
23 | )
24 | for name in ["training", "validation"]:
25 | annotation_dir = dataset_dir / "annotations" / name
26 | output_dir = dataset_dir / "annotations_detectron2" / name
27 | output_dir.mkdir(parents=True, exist_ok=True)
28 | for file in tqdm.tqdm(list(annotation_dir.iterdir())):
29 | output_file = output_dir / file.name
30 | convert(file, output_file)
--------------------------------------------------------------------------------
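
The `img - 1` in `convert` above relies on `uint8` wrap-around: the ignore label 0 becomes 255 while every other class id shifts down by one. A quick check:

```python
import numpy as np

labels = np.array([0, 1, 2, 150], dtype=np.uint8)
print(labels - np.uint8(1))  # [255   0   1 149]: 0 (ignore) wraps to 255, the rest shift down by 1
```
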
/osprey/eval/eval_gpt.py:
--------------------------------------------------------------------------------
1 | """
2 | Reference: https://github.com/haotian-liu/LLaVA/blob/main/llava/eval/eval_gpt_review.py
3 | """
4 |
5 | import argparse
6 | import json
7 | import os
8 |
9 | import openai
10 | import time
11 | from tqdm import tqdm
12 | import requests
13 |
14 | def get_eval(content: str, max_tokens: int):
15 | while True:
16 | try:
17 | messages=[{
18 | 'role': 'system',
19 | 'content': 'You are a helpful and precise assistant for checking the quality of the answer.'
20 | }, {
21 | 'role': 'user',
22 | 'content': content,
23 | }]
24 | ##########
25 |
26 |             # call your GPT interface here and assign the reply string to ret
27 | # ret = gpt_answer
28 |
29 | ##########
30 | break
31 |
32 | except openai.error.RateLimitError:
33 | pass
34 | except Exception as e:
35 | print(e)
36 | time.sleep(1)
37 |
38 | return ret
39 |
40 |
41 | def parse_score(review):
42 | try:
43 | score_pair = review.split('\n')[0]
44 | score_pair = score_pair.replace(',', ' ')
45 | sp = score_pair.split(' ')
46 | if len(sp) == 2:
47 | return [float(sp[0]), float(sp[1])]
48 | else:
49 | print('error', review)
50 | return [-1, -1]
51 | except Exception as e:
52 | print(e)
53 | print('error', review)
54 | return [-1, -1]
55 |
56 |
57 | if __name__ == '__main__':
58 | parser = argparse.ArgumentParser(description='ChatGPT-based QA evaluation.')
59 | parser.add_argument('--question', help='path to question file')
60 | parser.add_argument('--context', help='path to gpt prompt file')
61 | parser.add_argument('--answer-list', nargs='+', default=[], help='gpt answer and model answer json files')
62 | parser.add_argument('--rule', help='gpt rule')
63 | parser.add_argument('--output', help='output json dir')
64 | parser.add_argument('--max-tokens', type=int, default=1024, help='maximum number of tokens produced in the output')
65 | args = parser.parse_args()
66 |
67 | f_q = json.load(open(os.path.expanduser(args.question)))
68 | f_ans1 = json.load(open(os.path.expanduser(args.answer_list[0])))
69 | f_ans2 = json.load(open(os.path.expanduser(args.answer_list[1])))
70 | rule_dict = json.load(open(os.path.expanduser(args.rule), 'r'))
71 |
72 | os.makedirs('./result', exist_ok=True)
73 |
74 | if os.path.isfile(os.path.expanduser(args.output)):
75 | cur_reviews = [json.loads(line) for line in open(os.path.expanduser(args.output))]
76 | else:
77 | cur_reviews = []
78 |
79 | review_file = open(f'{args.output}', 'a')
80 |
81 | context_list = json.load(open(os.path.expanduser(args.context)))
82 |
83 | image_to_context = {context['image']: context for context in context_list}
84 |
85 | handles = []
86 | idx = 0
87 |
88 | for ques, ans1, ans2 in tqdm(zip(f_q, f_ans1, f_ans2)):
89 |
90 | inst = image_to_context[ques['image']]
91 |
92 | category = ques['category']
93 | if category in rule_dict:
94 | rule = rule_dict[category]
95 | else:
96 | assert False, f"category not found in rule file: {category}."
97 |
98 | prompt = rule['prompt']
99 | role = rule['role']
100 |         content = (f'[Context]\n{inst["prompt"]}\n\n'
101 | f'[Question]\n{ques["text"]}\n\n'
102 | f'[{role} 1]\n{ans1["text"]}\n\n[End of {role} 1]\n\n'
103 | f'[{role} 2]\n{ans2["text"]}\n\n[End of {role} 2]\n\n'
104 | f'[System]\n{prompt}\n\n')
105 |
106 | cur_js = {
107 | 'id': idx+1,
108 | 'question_id': ques['question_id'],
109 | 'answer1_id': ans1.get('answer_id', ans1['question_id']),
110 | 'answer2_id': ans2.get('answer_id', ans2['question_id']),
111 | 'category': category
112 | }
113 | if idx >= len(cur_reviews):
114 | review = get_eval(content, args.max_tokens)
115 | print(review)
116 |
117 | scores = parse_score(review)
118 | cur_js['content'] = review
119 | cur_js['tuple'] = scores
120 | cur_js['answer1'] = ans1["text"]
121 | cur_js['answer2'] = ans2["text"]
122 | review_file.write(json.dumps(cur_js) + '\n')
123 | review_file.flush()
124 | else:
125 | print(f'Skipping {idx} as we already have it.')
126 |
127 | idx += 1
128 | print(idx)
129 |
130 | review_file.close()
--------------------------------------------------------------------------------
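
`parse_score` above expects the GPT review to begin with the two scores on its first line; a toy example of the format it parses (hypothetical review text):

```python
# first line: "<score for Assistant 1> <score for Assistant 2>", as parsed by parse_score
review = "8 9\nAssistant 1 gives a more detailed and accurate description of the region."
score_pair = review.split('\n')[0].replace(',', ' ')
print([float(s) for s in score_pair.split(' ')])  # [8.0, 9.0]
```
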
/osprey/eval/ferret_bench_eval.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import torch
3 | import os
4 | import json
5 | from tqdm import tqdm
6 | import shortuuid
7 | from torch import nn
8 | import copy
9 | from functools import partial
10 | from transformers import AutoTokenizer, CLIPImageProcessor
11 | from osprey.constants import IMAGE_TOKEN_INDEX
12 | from osprey.conversation import conv_templates, SeparatorStyle
13 | from osprey.utils import disable_torch_init
14 | from osprey.mm_utils import tokenizer_image_token, KeywordsStoppingCriteria
15 | from osprey.train.train import preprocess, preprocess_multimodal
16 | from osprey.train.train import DataArguments
17 | from osprey.model.language_model.osprey_llama import OspreyLlamaForCausalLM
18 | from pycocotools import mask as maskUtils
19 | import numpy as np
20 | from PIL import Image
21 | import cv2
22 | import re
23 | data_args = DataArguments()
24 | data_args.mm_use_im_start_end = False
25 | data_args.is_multimodal = True
26 |
27 | def annToMask(ann, h, w):
28 | rles = maskUtils.frPyObjects(ann, h, w)
29 | rle = maskUtils.merge(rles)
30 | m = maskUtils.decode(rle)
31 | return m
32 |
33 | class GPT_EVAL(nn.Module):
34 | def __init__(self, model_path, model_base=None):
35 | super().__init__()
36 | disable_torch_init()
37 | model_path = os.path.expanduser(model_path)
38 |
39 | self.tokenizer = AutoTokenizer.from_pretrained(
40 | model_path,
41 | model_max_length=2048,
42 | padding_side="right",
43 | use_fast=True
44 | )
45 | self.model = OspreyLlamaForCausalLM.from_pretrained(
46 | model_path,
47 | torch_dtype=torch.bfloat16,
48 | ).cuda()
49 | self.tokenizer.pad_token = self.tokenizer.unk_token
50 |
51 | self.image_processor = CLIPImageProcessor(do_resize=True, size={"shortest_edge":512}, resample=3, do_center_crop=True, crop_size={"height": 512, "width": 512},
52 | do_rescale=True, rescale_factor=0.00392156862745098, do_normalize=True, image_mean=[0.48145466, 0.4578275, 0.40821073],
53 | image_std=[0.26862954, 0.26130258, 0.27577711], do_convert_rgb=True, )
54 |
55 | spi_tokens = ['', '']
56 | self.tokenizer.add_tokens(spi_tokens, special_tokens=True)
57 |
58 | for m in self.model.modules():
59 | m.tokenizer = self.tokenizer
60 |
61 | vision_tower = self.model.get_vision_tower()
62 | if not vision_tower.is_loaded:
63 | vision_tower.load_model()
64 | vision_tower.to(dtype=torch.float16, device='cuda')
65 |
66 |
67 | def forward(self, root_path, ann_file):
68 | final = []
69 | anns = json.load(open(ann_file))
70 |
71 | for i, ann in enumerate(anns):
72 | print(i)
73 |
74 | model_answer = {}
75 | model_answer["question_id"] = ann["question_id"]
76 | model_answer["image"] = ann["image"]
77 | model_answer["category"] = ann["category"]
78 | img_path = os.path.join(root_path, ann['image'])
79 | img = cv2.imread(img_path)
80 | question = ann['text']
81 |
82 | question = re.sub(r'', r'', question)
83 | # question += 'Answer the question in detail.'
84 | idx = 1
85 |
86 | mask_r = ann['annotation']['segmentation']
87 | height, width = img.shape[:2]
88 |
89 | if isinstance(mask_r, list):
90 | mask = annToMask(mask_r, height, width)
91 | else:
92 | mask = maskUtils.decode(mask_r)
93 | mask = torch.from_numpy(mask).unsqueeze(0)
94 |
95 | x1, y1, w, h = ann['annotation']['bbox']
96 | bbox = np.array([x1, y1, x1 + w, y1 + h])
97 | bbox = torch.from_numpy(bbox)
98 |
99 |
100 | init_inputs = get_init_inputs(img_path,
101 | self.image_processor,
102 | self.tokenizer,
103 | pred_bboxes=bbox,
104 | mask=mask,
105 | question=question,
106 | round_ids=0,
107 | last_round_source={},
108 | )
109 |
110 | masks = init_inputs['masks'].cuda()
111 | image = init_inputs['image']
112 | conv = conv_templates['osprey_v1'].copy()
113 | qs = init_inputs['sources'][0][0]['value']
114 | conv.append_message(conv.roles[0], qs)
115 | conv.append_message(conv.roles[1], None)
116 | prompt = conv.get_prompt()
117 |
118 | input_ids = tokenizer_image_token(prompt, self.tokenizer, IMAGE_TOKEN_INDEX, return_tensors='pt').unsqueeze(0).cuda()
119 |
120 | stop_str = conv.sep if conv.sep_style != SeparatorStyle.TWO else conv.sep2
121 | keywords = [stop_str]
122 | stopping_criteria = KeywordsStoppingCriteria(keywords, self.tokenizer, input_ids)
123 |
124 | self.model.model.tokenizer = self.tokenizer
125 |
126 |
127 | with torch.inference_mode():
128 |
129 | self.model.orig_forward = self.model.forward
130 | self.model.forward = partial(self.model.orig_forward,
131 | img_metas=[None],
132 | masks=[masks.half()])
133 |
134 | output_ids = self.model.generate(
135 | input_ids,
136 | images=image.unsqueeze(0).half().cuda(),
137 | do_sample=True,
138 | # masks=[masks.half()],
139 | temperature=0.2,
140 | max_new_tokens=1024,
141 | use_cache=True,
142 | num_beams=1,
143 | # stopping_criteria=[stopping_criteria]
144 | )
145 |
146 | self.model.forward = self.model.orig_forward
147 |
148 | input_token_len = input_ids.shape[1]
149 | n_diff_input_output = (
150 | input_ids != output_ids[:, :input_token_len]).sum().item()
151 | if n_diff_input_output > 0:
152 | print(
153 | f'[Warning] {n_diff_input_output} output_ids are not the same as the input_ids')
154 | outputs = self.tokenizer.batch_decode(output_ids[:, input_token_len:],
155 | skip_special_tokens=True)[0]
156 |
157 | outputs = outputs.strip()
158 | if outputs.endswith(stop_str):
159 | outputs = outputs[:-len(stop_str)]
160 | outputs = outputs.strip()
161 |
162 | model_answer['text'] = outputs
163 |
164 | print(outputs)
165 | final.append(model_answer)
166 |
167 | final_ = json.dumps(final)
168 | with open('ferret_bench/osprey_refer_reason_original_3.json','w') as fw:
169 | fw.write(final_)
170 | fw.close()
171 |
172 |
173 | def get_init_inputs(img_path,
174 | processor,
175 | tokenizer,
176 | pred_bboxes,
177 | mask,
178 | question=None,
179 | round_ids=0,
180 | last_round_source=None):
181 |
182 | if round_ids == 0:
183 |
184 | image = Image.open(img_path).convert('RGB')
185 |
186 | image = processor.preprocess(image,
187 | do_center_crop=False,
188 | return_tensors='pt')['pixel_values'][0]
189 |
190 | image = torch.nn.functional.interpolate(image.unsqueeze(0),
191 | size=(512, 512),
192 | mode='bilinear',
193 | align_corners=False).squeeze(0)
194 |
195 | else:
196 | image = last_round_source['image']
197 |
198 | cur_token_len = (image.shape[1] // 16) * (image.shape[2] // 16)
199 |
200 | mask = mask.to(image.device)
201 |
202 | begin_str = """.\nThis provides an overview of the picture.\n"""
203 |
204 | sources = dict()
205 | sources['conversations'] = []
206 |
207 | sources['conversations'].append({'from': 'human', 'value': begin_str+question})
208 |
209 | sources = preprocess_multimodal([sources['conversations']], data_args, cur_token_len)
210 |
211 | data_dict = {}
212 | data_dict['sources'] = sources
213 | data_dict['image'] = image
214 | data_dict['bboxes'] = pred_bboxes
215 | data_dict['masks'] = mask
216 | data_dict['img_metas'] = dict(filename=img_path)
217 |
218 | return data_dict
219 |
220 |
221 | if __name__ == "__main__":
222 |     parser = argparse.ArgumentParser(description='ferret-bench evaluation', formatter_class=argparse.RawTextHelpFormatter)
223 |     parser.add_argument('--model_name', help='path to osprey model', default='/path/to/Osprey-Chat-7b')
224 |     parser.add_argument('--root_path', help='path to coco imgs', default='/path/to/coco-imgs')
225 |     parser.add_argument('--json_path', help='path to ferret-bench json file', default='./ferret_bench/box_refer_reason.json')
226 |     args = parser.parse_args()
227 | 
228 |     ferret_eval = GPT_EVAL(args.model_name)
229 |     ferret_eval(args.root_path, args.json_path)
230 | 
--------------------------------------------------------------------------------
/osprey/eval/gpt_eval.sh:
--------------------------------------------------------------------------------
1 | NAME='osprey'
2 | TYPE='description'
3 |
4 | python osprey_generate_gpt_description_answer.py\
5 | --model /path/to/osprey_7b\
6 | --coco-img /path/to/coco_imgs\
7 | --json ${TYPE}/questions.json
8 |
9 | python eval_gpt.py\
10 | --question ${TYPE}/questions.json\
11 | --context ${TYPE}/prompt.json\
12 | --answer-list ${TYPE}/answers.json\
13 | ${TYPE}/${NAME}_answer.json\
14 | --rule rule.json\
15 | --output result/gpt_score_${NAME}_${TYPE}.jsonl
16 |
17 | python summarize_gpt_score.py --dir result
18 |
19 |
--------------------------------------------------------------------------------
/osprey/eval/lvis_paco_eval.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import torch
3 | import os
4 | import json
5 | from tqdm import tqdm
6 | from functools import partial
7 | from transformers import AutoTokenizer, CLIPImageProcessor
8 | from osprey.constants import IMAGE_TOKEN_INDEX
9 | from osprey.conversation import conv_templates, SeparatorStyle
10 | from osprey.mm_utils import tokenizer_image_token
11 | from osprey.train.train import preprocess_multimodal
12 | from osprey.train.train import DataArguments
13 | from osprey.model.language_model.osprey_llama import OspreyLlamaForCausalLM
14 | from pycocotools import mask as maskUtils
15 | import numpy as np
16 | from PIL import Image
17 | from sentence_transformers import SentenceTransformer, util
18 | import argparse
19 |
20 | data_args = DataArguments()
21 | data_args.mm_use_im_start_end = False
22 | data_args.is_multimodal = True
23 |
24 | def annToMask(ann, h, w):
25 | rles = maskUtils.frPyObjects(ann, h, w)
26 | rle = maskUtils.merge(rles)
27 | m = maskUtils.decode(rle)
28 | return m
29 |
30 | class LVIS_PACO_EVAL():
31 | def __init__(self, model_path, bert_model):
32 | model_path = os.path.expanduser(model_path)
33 |
34 | self.tokenizer = AutoTokenizer.from_pretrained(
35 | model_path,
36 | model_max_length=2048,
37 | padding_side="right",
38 | use_fast=True
39 | )
40 | self.model = OspreyLlamaForCausalLM.from_pretrained(
41 | model_path,
42 | torch_dtype=torch.bfloat16,
43 | ).cuda()
44 | self.tokenizer.pad_token = self.tokenizer.unk_token
45 |
46 | self.image_processor = CLIPImageProcessor(do_resize=True, size={"shortest_edge":512}, resample=3, do_center_crop=True, crop_size={"height": 512, "width": 512},
47 | do_rescale=True, rescale_factor=0.00392156862745098, do_normalize=True, image_mean=[0.48145466, 0.4578275, 0.40821073],
48 | image_std=[0.26862954, 0.26130258, 0.27577711], do_convert_rgb=True, )
49 |
50 | spi_tokens = ['', '']
51 | self.tokenizer.add_tokens(spi_tokens, special_tokens=True)
52 |
53 | for m in self.model.modules():
54 | m.tokenizer = self.tokenizer
55 |
56 | vision_tower = self.model.get_vision_tower()
57 | if not vision_tower.is_loaded:
58 | vision_tower.load_model()
59 | vision_tower.to(dtype=torch.float16, device='cuda')
60 |
61 | self.bert_model = SentenceTransformer(bert_model)
62 |
63 |
64 | def eval(self, root_path, ann_file):
65 | data_all = json.load(open(ann_file))
66 | all_sim = 0
67 | all_num = 0
68 | all_iou = 0
69 | for data in tqdm(data_all):
70 | img_path = os.path.join(root_path, data['file_name'])
71 | height = data['height']
72 | width = data['width']
73 | round_ids = 0
74 | last_source = dict()
75 | for i in range(len(data['categories'])):
76 | category = data['categories'][i].replace('_', ' ')
77 | category = category.replace(':', ' ')
78 |
79 | mask_r = data['annotations'][i]['segmentation']
80 |
81 | if isinstance(mask_r, list):
82 | mask = annToMask(mask_r, height, width)
83 | else:
84 | mask = maskUtils.decode(mask_r)
85 | mask = torch.from_numpy(mask).unsqueeze(0)
86 |
87 | init_inputs = get_init_inputs(img_path,
88 | self.image_processor,
89 | mask=mask,
90 | round_ids=round_ids,
91 | last_round_source=last_source,
92 | )
93 |
94 | round_ids += 1
95 | last_source = init_inputs
96 |
97 | image = init_inputs['image']
98 | masks = init_inputs['masks'].cuda()
99 |
100 | conv = conv_templates['osprey_v1'].copy()
101 | qs = init_inputs['sources'][0][0]['value']
102 |
103 | conv.append_message(conv.roles[0], qs)
104 | conv.append_message(conv.roles[1], None)
105 | prompt = conv.get_prompt()
106 |
107 | input_ids = tokenizer_image_token(prompt, self.tokenizer, IMAGE_TOKEN_INDEX, return_tensors='pt').unsqueeze(0).cuda()
108 |
109 | stop_str = conv.sep if conv.sep_style != SeparatorStyle.TWO else conv.sep2
110 |
111 | self.model.model.tokenizer = self.tokenizer
112 |
113 | with torch.inference_mode():
114 |
115 | self.model.orig_forward = self.model.forward
116 | self.model.forward = partial(self.model.orig_forward,
117 | img_metas=[None],
118 | masks=[masks.half()])
119 |
120 | output_ids = self.model.generate(
121 | input_ids,
122 | images=image.unsqueeze(0).half().cuda(),
123 | do_sample=True,
124 | temperature=0.2,
125 | max_new_tokens=1024,
126 | use_cache=True,
127 | num_beams=1,
128 | )
129 |
130 | self.model.forward = self.model.orig_forward
131 |
132 | input_token_len = input_ids.shape[1]
133 | n_diff_input_output = (
134 | input_ids != output_ids[:, :input_token_len]).sum().item()
135 | if n_diff_input_output > 0:
136 | print(
137 | f'[Warning] {n_diff_input_output} output_ids are not the same as the input_ids')
138 | outputs = self.tokenizer.batch_decode(output_ids[:, input_token_len:],
139 | skip_special_tokens=True)[0]
140 |
141 | outputs = outputs.strip()
142 | if outputs.endswith(stop_str):
143 | outputs = outputs[:-len(stop_str)]
144 | outputs = outputs.strip()
145 | if ':' in outputs:
146 | outputs = outputs.split(':')[1]
147 |
148 | outputs = outputs.replace('.', ' ')
149 | outputs = outputs.replace(':', ' ')
150 | outputs = outputs.replace(',', ' ')
151 |
152 | print("[prediction]: ", outputs)
153 | print("[gt category]:", category)
154 |
155 | outputs_embeddings = self.bert_model.encode(outputs, convert_to_tensor=True)
156 | class_sentence_embeddings = self.bert_model.encode(category, convert_to_tensor=True)
157 | cosine_scores = util.cos_sim(outputs_embeddings, class_sentence_embeddings)
158 |
159 | semantic_iou = SemanticIOU(outputs.lower(), category.lower())
160 |
161 | all_sim += cosine_scores[0][0]
162 | all_iou += semantic_iou
163 | all_num += 1
164 |
165 | print("sim:{}, iou:{}".format(all_sim/all_num, all_iou/all_num))
166 |
167 | print("final sim:{}, semantic iou:{}".format(all_sim/all_num, all_iou/all_num))
168 |
169 |
170 | def SemanticIOU(value: str, target: str) -> float:
171 |
172 | intersection = len(set(value.split()) & set(target.split()))
173 | union = len(set(value.split()) | set(target.split()))
174 |
175 | return intersection / union
176 |
177 | def get_init_inputs(img_path,
178 | processor,
179 | mask,
180 | round_ids=0,
181 | last_round_source=None):
182 |
183 | if round_ids == 0:
184 |
185 | image = Image.open(img_path).convert('RGB')
186 |
187 | image = processor.preprocess(image,
188 | do_center_crop=False,
189 | return_tensors='pt')['pixel_values'][0]
190 |
191 | image = torch.nn.functional.interpolate(image.unsqueeze(0),
192 | size=(512, 512),
193 | mode='bilinear',
194 | align_corners=False).squeeze(0)
195 |
196 | else:
197 | image = last_round_source['image']
198 |
199 | cur_token_len = (image.shape[1] // 16) * (image.shape[2] // 16)
200 |
201 | mask = mask.to(image.device)
202 |
203 | begin_str = """\nThis provides an overview of the picture.\n"""
204 |
205 | sources = dict()
206 | sources['conversations'] = []
207 |
208 | question = 'What is the category of ? Using only one word or phrase.'
209 |
210 | sources['conversations'].append({'from': 'human', 'value': begin_str+question})
211 |
212 | sources = preprocess_multimodal([sources['conversations']], data_args, cur_token_len)
213 |
214 | data_dict = {}
215 | data_dict['sources'] = sources
216 | data_dict['image'] = image
217 | data_dict['masks'] = mask
218 | return data_dict
219 |
220 |
221 | if __name__ == "__main__":
222 | parser = argparse.ArgumentParser(description='osprey demo', formatter_class=argparse.RawTextHelpFormatter)
223 | parser.add_argument('--model', help='path to osprey model', default='/path/to/osprey-7b')
224 | parser.add_argument('--bert', help='path to bert model', default='./all-MiniLM-L6-v2')
225 | parser.add_argument('--img', help='path to coco imgs', default='/path/to/all_coco_imgs')
226 | parser.add_argument('--json', help='path to lvis/paco val json file', default='./paco_val_1k_category.json')
227 | args = parser.parse_args()
228 |
229 | lvis_paco_eval = LVIS_PACO_EVAL(args.model, args.bert)
230 | lvis_paco_eval.eval(args.img, args.json)
231 |
232 |
--------------------------------------------------------------------------------
/osprey/eval/osprey_generate_gpt_description_answer.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import torch
3 | import os
4 | import json
5 | from functools import partial
6 | from transformers import AutoTokenizer, CLIPImageProcessor
7 | from osprey.constants import IMAGE_TOKEN_INDEX
8 | from osprey.conversation import conv_templates, SeparatorStyle
9 | from osprey.mm_utils import tokenizer_image_token
10 | from osprey.train.train import preprocess_multimodal
11 | from osprey.train.train import DataArguments
12 | from osprey.model.language_model.osprey_llama import OspreyLlamaForCausalLM
13 | from pycocotools import mask as maskUtils
14 | import numpy as np
15 | from PIL import Image
16 | import cv2
17 |
18 | data_args = DataArguments()
19 | data_args.mm_use_im_start_end = False
20 | data_args.is_multimodal = True
21 |
22 | def annToMask(ann, h, w):
23 | rles = maskUtils.frPyObjects(ann, h, w)
24 | rle = maskUtils.merge(rles)
25 | m = maskUtils.decode(rle)
26 | return m
27 |
28 | class GPT_EVAL():
29 | def __init__(self, model_path):
30 | model_path = os.path.expanduser(model_path)
31 |
32 | self.tokenizer = AutoTokenizer.from_pretrained(
33 | model_path,
34 | model_max_length=2048,
35 | padding_side="right",
36 | use_fast=True
37 | )
38 | self.model = OspreyLlamaForCausalLM.from_pretrained(
39 | model_path,
40 | torch_dtype=torch.bfloat16,
41 | ).cuda()
42 | self.tokenizer.pad_token = self.tokenizer.unk_token
43 |
44 | self.image_processor = CLIPImageProcessor(do_resize=True, size={"shortest_edge":512}, resample=3, do_center_crop=True, crop_size={"height": 512, "width": 512},
45 | do_rescale=True, rescale_factor=0.00392156862745098, do_normalize=True, image_mean=[0.48145466, 0.4578275, 0.40821073],
46 | image_std=[0.26862954, 0.26130258, 0.27577711], do_convert_rgb=True, )
47 |
48 | spi_tokens = ['', '']
49 | self.tokenizer.add_tokens(spi_tokens, special_tokens=True)
50 |
51 | for m in self.model.modules():
52 | m.tokenizer = self.tokenizer
53 |
54 | vision_tower = self.model.get_vision_tower()
55 | if not vision_tower.is_loaded:
56 | vision_tower.load_model()
57 | vision_tower.to(dtype=torch.float16, device='cuda')
58 |
59 |
60 | def eval(self, root_path, ann_file):
61 | final = []
62 | anns = json.load(open(ann_file))
63 |
64 | for i, ann in enumerate(anns):
65 | print(i)
66 |
67 | model_answer = {}
68 | model_answer["question_id"] = ann["question_id"]
69 | model_answer["image"] = ann["image"]
70 | model_answer["category"] = ann["category"]
71 | img_path = os.path.join(root_path, ann["image"])
72 | img = cv2.imread(img_path)
73 | question = ann["text"]
74 |
75 | mask_r = ann['annotation']['segmentation']
76 | height, width = img.shape[:2]
77 |
78 | if isinstance(mask_r, list):
79 | mask = annToMask(mask_r, height, width)
80 | else:
81 | mask = maskUtils.decode(mask_r)
82 | mask = torch.from_numpy(mask).unsqueeze(0)
83 |
84 | x1, y1, w, h = ann['annotation']['bbox']
85 | bbox = np.array([x1, y1, x1 + w, y1 + h])
86 | bbox = torch.from_numpy(bbox)
87 |
88 | init_inputs = get_init_inputs(img_path,
89 | self.image_processor,
90 | mask=mask,
91 | question=question,
92 | round_ids=0,
93 | last_round_source={},
94 | )
95 |
96 | masks = init_inputs['masks'].cuda()
97 | image = init_inputs['image']
98 | conv = conv_templates['osprey_v1'].copy()
99 | qs = init_inputs['sources'][0][0]['value']
100 |
101 | conv.append_message(conv.roles[0], qs)
102 | conv.append_message(conv.roles[1], None)
103 | prompt = conv.get_prompt()
104 |
105 | input_ids = tokenizer_image_token(prompt, self.tokenizer, IMAGE_TOKEN_INDEX, return_tensors='pt').unsqueeze(0).cuda()
106 |
107 | stop_str = conv.sep if conv.sep_style != SeparatorStyle.TWO else conv.sep2
108 |
109 | self.model.model.tokenizer = self.tokenizer
110 |
111 |
112 | with torch.inference_mode():
113 |
114 | self.model.orig_forward = self.model.forward
115 | self.model.forward = partial(self.model.orig_forward,
116 | img_metas=[None],
117 | masks=[masks.half()])
118 |
119 | output_ids = self.model.generate(
120 | input_ids,
121 | images=image.unsqueeze(0).half().cuda(),
122 | do_sample=True,
123 | temperature=0.2,
124 | max_new_tokens=1024,
125 | use_cache=True,
126 | num_beams=1,
127 | )
128 |
129 | self.model.forward = self.model.orig_forward
130 |
131 | input_token_len = input_ids.shape[1]
132 | n_diff_input_output = (
133 | input_ids != output_ids[:, :input_token_len]).sum().item()
134 | if n_diff_input_output > 0:
135 | print(
136 | f'[Warning] {n_diff_input_output} output_ids are not the same as the input_ids')
137 | outputs = self.tokenizer.batch_decode(output_ids[:, input_token_len:],
138 | skip_special_tokens=True)[0]
139 |
140 | outputs = outputs.strip()
141 | if outputs.endswith(stop_str):
142 | outputs = outputs[:-len(stop_str)]
143 | outputs = outputs.strip()
144 | if ':' in outputs:
145 | outputs = outputs.split(':')[1]
146 |
147 | model_answer['text'] = outputs
148 |
149 | print(outputs)
150 | final.append(model_answer)
151 |
152 | final_ = json.dumps(final)
153 | with open('description/osprey_answer.json','w') as fw:
154 | fw.write(final_)
155 | fw.close()
156 |
157 |
158 | def get_init_inputs(img_path,
159 | processor,
160 | mask,
161 | question=None,
162 | round_ids=0,
163 | last_round_source=None):
164 |
165 | if round_ids == 0:
166 |
167 | image = Image.open(img_path).convert('RGB')
168 |
169 | image = processor.preprocess(image,
170 | do_center_crop=False,
171 | return_tensors='pt')['pixel_values'][0]
172 |
173 | image = torch.nn.functional.interpolate(image.unsqueeze(0),
174 | size=(512, 512),
175 | mode='bilinear',
176 | align_corners=False).squeeze(0)
177 |
178 | else:
179 | image = last_round_source['image']
180 |
181 | cur_token_len = (image.shape[1] // 16) * (image.shape[2] // 16)
182 |
183 | mask = mask.to(image.device)
184 |
185 | begin_str = """.\nThis provides an overview of the picture.\n"""
186 |
187 | sources = dict()
188 | sources['conversations'] = []
189 | question = question.replace('<image>','')
190 |
191 | sources['conversations'].append({'from': 'human', 'value': begin_str+question})
192 |
193 | sources = preprocess_multimodal([sources['conversations']], data_args, cur_token_len)
194 |
195 | data_dict = {}
196 | data_dict['sources'] = sources
197 | data_dict['image'] = image
198 | data_dict['masks'] = mask
199 |
200 | return data_dict
201 |
202 |
203 | if __name__ == "__main__":
204 | parser = argparse.ArgumentParser(description='osprey generate gpt answer', formatter_class=argparse.RawTextHelpFormatter)
205 | parser.add_argument('--model', help='path to osprey model', default='/path/to/osprey-7b')
206 | parser.add_argument('--coco-img', help='path to coco imgs', default='/path/to/coco_all_imgs/')
207 | parser.add_argument('--json', help='path to question json file', default='./description/questions.json')
208 | args = parser.parse_args()
209 |
210 | gpt_eval = GPT_EVAL(args.model)
211 | gpt_eval.eval(args.coco_img, args.json)
212 |
--------------------------------------------------------------------------------
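Note on osprey_generate_gpt_description_answer.py above: each record appended to final carries question_id, image, category and the generated text, and the list is dumped to description/osprey_answer.json. A minimal sketch (assuming the default output path hard-coded in eval()) for inspecting the generated answers:

import json

# Load the answers written by GPT_EVAL.eval() above and print a few records.
with open('description/osprey_answer.json') as f:
    answers = json.load(f)

for ans in answers[:3]:
    print(ans['question_id'], ans['image'], ans['category'])
    print(ans['text'])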
/osprey/eval/pope/evaluate.py:
--------------------------------------------------------------------------------
1 | import json
2 | import argparse
3 |
4 | parser = argparse.ArgumentParser(description='evaluate pope', formatter_class=argparse.RawTextHelpFormatter)
5 | parser.add_argument('--ans-file', default='pope/coco_pope_random_answers.json')
6 | parser.add_argument('--label-file', default='pope/coco_pope_random.json')
7 |
8 | args = parser.parse_args()
9 |
10 | answers = [json.loads(q) for q in open(args.ans_file, 'r')]
11 | label_list = [json.loads(q)['label'] for q in open(args.label_file, 'r')]
12 |
13 | for answer in answers:
14 | text = answer['answer']
15 |
16 | # Only keep the first sentence
17 | if text.find('.') != -1:
18 | text = text.split('.')[0]
19 |
20 | text = text.replace(',', '')
21 | words = text.split(' ')
22 | if 'No' in words or 'not' in words or 'no' in words:
23 | answer['answer'] = 'no'
24 | else:
25 | answer['answer'] = 'yes'
26 |
27 | for i in range(len(label_list)):
28 | if label_list[i] == 'no':
29 | label_list[i] = 0
30 | else:
31 | label_list[i] = 1
32 |
33 | pred_list = []
34 | for answer in answers:
35 | if answer['answer'] == 'no':
36 | pred_list.append(0)
37 | else:
38 | pred_list.append(1)
39 |
40 | pos = 1
41 | neg = 0
42 | yes_ratio = pred_list.count(1) / len(pred_list)
43 |
44 | TP, TN, FP, FN = 0, 0, 0, 0
45 | for pred, label in zip(pred_list, label_list):
46 | if pred == pos and label == pos:
47 | TP += 1
48 | elif pred == pos and label == neg:
49 | FP += 1
50 | elif pred == neg and label == neg:
51 | TN += 1
52 | elif pred == neg and label == pos:
53 | FN += 1
54 |
55 | print('TP\tFP\tTN\tFN\t')
56 | print('{}\t{}\t{}\t{}'.format(TP, FP, TN, FN))
57 |
58 | precision = float(TP) / float(TP + FP)
59 | recall = float(TP) / float(TP + FN)
60 | f1 = 2*precision*recall / (precision + recall)
61 | acc = (TP + TN) / (TP + TN + FP + FN)
62 | print('Accuracy: {}'.format(acc))
63 | print('Precision: {}'.format(precision))
64 | print('Recall: {}'.format(recall))
65 | print('F1 score: {}'.format(f1))
66 | print('Yes ratio: {}'.format(yes_ratio))
--------------------------------------------------------------------------------
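Note on pope/evaluate.py above: both inputs are JSON-Lines files; every line of --ans-file needs an "answer" field (pope_eval.py below writes "question" and "answer") and every line of --label-file needs a "label" field whose value is "yes" or "no". A minimal sketch, using hypothetical toy file names, of the expected format:

import json

# Hypothetical one-line files; the field names match what evaluate.py reads.
answer_line = {"question": "Is there a dog in the image?", "answer": "yes, there is a dog."}
label_line = {"question": "Is there a dog in the image?", "label": "yes"}

with open('toy_answers.json', 'w') as f:
    f.write(json.dumps(answer_line) + '\n')
with open('toy_labels.json', 'w') as f:
    f.write(json.dumps(label_line) + '\n')

# python pope/evaluate.py --ans-file toy_answers.json --label-file toy_labels.json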
/osprey/eval/pope_eval.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import torch
3 | import os
4 | import json
5 | from tqdm import tqdm
6 | from functools import partial
7 | from transformers import AutoTokenizer, CLIPImageProcessor
8 | from osprey.constants import IMAGE_TOKEN_INDEX
9 | from osprey.conversation import conv_templates, SeparatorStyle
10 | from osprey.mm_utils import tokenizer_image_token
11 | from osprey.train.train import preprocess_multimodal
12 | from osprey.train.train import DataArguments
13 | from osprey.model.language_model.osprey_llama import OspreyLlamaForCausalLM
14 | import numpy as np
15 | from PIL import Image
16 | import argparse
17 |
18 | data_args = DataArguments()
19 | data_args.mm_use_im_start_end = False
20 | data_args.is_multimodal = True
21 |
22 |
23 | class POPE_EVAL():
24 | def __init__(self, model_path):
25 | model_path = os.path.expanduser(model_path)
26 |
27 | self.tokenizer = AutoTokenizer.from_pretrained(
28 | model_path,
29 | model_max_length=2048,
30 | padding_side="right",
31 | use_fast=True
32 | )
33 | self.model = OspreyLlamaForCausalLM.from_pretrained(
34 | model_path,
35 | torch_dtype=torch.bfloat16,
36 | ).cuda()
37 | self.tokenizer.pad_token = self.tokenizer.unk_token
38 |
39 | self.image_processor = CLIPImageProcessor(do_resize=True, size={"shortest_edge":512}, resample=3, do_center_crop=True, crop_size={"height": 512, "width": 512},
40 | do_rescale=True, rescale_factor=0.00392156862745098, do_normalize=True, image_mean=[0.48145466, 0.4578275, 0.40821073],
41 | image_std=[0.26862954, 0.26130258, 0.27577711], do_convert_rgb=True, )
42 |
43 | spi_tokens = ['<mask>', '<pos>']
44 | self.tokenizer.add_tokens(spi_tokens, special_tokens=True)
45 |
46 | for m in self.model.modules():
47 | m.tokenizer = self.tokenizer
48 |
49 | vision_tower = self.model.get_vision_tower()
50 | if not vision_tower.is_loaded:
51 | vision_tower.load_model()
52 | vision_tower.to(dtype=torch.float16, device='cuda')
53 |
54 |
55 | def eval(self, root_path, ann_file, answer_file):
56 | data_all = [json.loads(l) for l in open(ann_file, 'r')]
57 | ans_file = open(answer_file, 'w')
58 |
59 | for data in tqdm(data_all):
60 | try:
61 | img_path = os.path.join(root_path, data['image'])
62 | image = Image.open(img_path).convert('RGB')
63 | except:
64 | img_path = os.path.join(root_path, data['image'].split('_')[-1])
65 | image = Image.open(img_path).convert('RGB')
66 |
67 | init_inputs = get_init_inputs(image,
68 | self.image_processor,
69 | data['text']
70 | )
71 |
72 | image = init_inputs['image']
73 |
74 | conv = conv_templates['osprey_v1'].copy()
75 | qs = init_inputs['sources'][0][0]['value']
76 |
77 | conv.append_message(conv.roles[0], qs)
78 | conv.append_message(conv.roles[1], None)
79 | prompt = conv.get_prompt()
80 |
81 | input_ids = tokenizer_image_token(prompt, self.tokenizer, IMAGE_TOKEN_INDEX, return_tensors='pt').unsqueeze(0).cuda()
82 |
83 | self.model.model.tokenizer = self.tokenizer
84 |
85 | with torch.inference_mode():
86 |
87 | self.model.orig_forward = self.model.forward
88 | self.model.forward = partial(self.model.orig_forward,
89 | img_metas=[None],
90 | )
91 |
92 | output_ids = self.model.generate(
93 | input_ids,
94 | images=image.unsqueeze(0).half().cuda(),
95 | do_sample=True,
96 | temperature=0.2,
97 | max_new_tokens=1024,
98 | use_cache=True,
99 | num_beams=1,
100 | )
101 |
102 | self.model.forward = self.model.orig_forward
103 |
104 | input_token_len = input_ids.shape[1]
105 | n_diff_input_output = (
106 | input_ids != output_ids[:, :input_token_len]).sum().item()
107 | if n_diff_input_output > 0:
108 | print(
109 | f'[Warning] {n_diff_input_output} output_ids are not the same as the input_ids')
110 | outputs = self.tokenizer.batch_decode(output_ids[:, input_token_len:],
111 | skip_special_tokens=True)[0]
112 |
113 |
114 | ans_file.write(json.dumps({"question": data['text'],
115 | "answer": outputs.lower()})+"\n")
116 |
117 |
118 | def get_init_inputs(image,
119 | processor,
120 | input_question):
121 |
122 | image = processor.preprocess(image,
123 | do_center_crop=False,
124 | return_tensors='pt')['pixel_values'][0]
125 |
126 | image = torch.nn.functional.interpolate(image.unsqueeze(0),
127 | size=(512, 512),
128 | mode='bilinear',
129 | align_corners=False).squeeze(0)
130 |
131 |
132 | cur_token_len = (image.shape[1] // 16) * (image.shape[2] // 16)
133 |
134 | sources = dict()
135 | sources['conversations'] = []
136 |
137 | question = '<image>\n'+input_question
138 |
139 |
140 | sources['conversations'].append({'from': 'human', 'value': question})
141 |
142 | sources = preprocess_multimodal([sources['conversations']], data_args, cur_token_len)
143 |
144 | data_dict = {}
145 | data_dict['sources'] = sources
146 | data_dict['image'] = image
147 | return data_dict
148 |
149 |
150 | if __name__ == "__main__":
151 | parser = argparse.ArgumentParser(description='osprey demo', formatter_class=argparse.RawTextHelpFormatter)
152 | parser.add_argument('--model', help='path to osprey model', default='osprey-7b')
153 | parser.add_argument('--img', help='path to coco imgs', default='/path/to/coco-imgs')
154 | parser.add_argument('--json', help='path to pope val json file', default='pope/coco_pope_random.json') #'pope/coco_pope_adversarial.json', 'pope/coco_pope_popular.json', 'pope/coco_pope_random.json'
155 | parser.add_argument('--answer', help='path to answer json file', default='./osprey_pope_random_answer.json')
156 |
157 | args = parser.parse_args()
158 |
159 | pope_eval = POPE_EVAL(args.model)
160 | pope_eval.eval(args.img, args.json, args.answer)
161 |
162 |
--------------------------------------------------------------------------------
/osprey/eval/pope_eval.sh:
--------------------------------------------------------------------------------
1 |
2 | for type in random popular adversarial
3 | do
4 | python pope_eval.py --model path/to/osprey-chat-7b \
5 | --img path/to/coco_imgs --json pope/coco_pope_${type}.json \
6 | --answer pope/coco_pope_${type}_answers.json
7 | done
8 |
9 | for type in random popular adversarial
10 | do
11 | echo "Evaluating pope on ${type} data..."
12 | python pope/evaluate.py --ans-file pope/coco_pope_${type}_answers.json \
13 | --label-file pope/coco_pope_${type}.json
14 | done
--------------------------------------------------------------------------------
/osprey/eval/refcocog_eval.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import torch
3 | import os
4 | import json
5 | from tqdm import tqdm
6 | from functools import partial
7 | from transformers import AutoTokenizer, CLIPImageProcessor
8 | from osprey.constants import IMAGE_TOKEN_INDEX
9 | from osprey.conversation import conv_templates, SeparatorStyle
10 | from osprey.mm_utils import tokenizer_image_token
11 | from osprey.train.train import preprocess_multimodal
12 | from osprey.train.train import DataArguments
13 | from osprey.model.language_model.osprey_llama import OspreyLlamaForCausalLM
14 | from pycocotools.coco import COCO
15 | from pycocotools import mask as maskUtils
16 | import numpy as np
17 | from PIL import Image
18 | data_args = DataArguments()
19 | data_args.mm_use_im_start_end = False
20 | data_args.is_multimodal = True
21 |
22 | def annToMask(ann, h, w):
23 | rles = maskUtils.frPyObjects(ann, h, w)
24 | rle = maskUtils.merge(rles)
25 | m = maskUtils.decode(rle)
26 | return m
27 |
28 | class REFCOCOG_EVAL():
29 | def __init__(self, model_path,):
30 | model_path = os.path.expanduser(model_path)
31 |
32 | self.tokenizer = AutoTokenizer.from_pretrained(
33 | model_path,
34 | model_max_length=2048,
35 | padding_side="right",
36 | use_fast=True
37 | )
38 | self.model = OspreyLlamaForCausalLM.from_pretrained(
39 | model_path,
40 | torch_dtype=torch.bfloat16,
41 | ).cuda()
42 | self.tokenizer.pad_token = self.tokenizer.unk_token
43 |
44 | self.image_processor = CLIPImageProcessor(do_resize=True, size={"shortest_edge":512}, resample=3, do_center_crop=True, crop_size={"height": 512, "width": 512},
45 | do_rescale=True, rescale_factor=0.00392156862745098, do_normalize=True, image_mean=[0.48145466, 0.4578275, 0.40821073],
46 | image_std=[0.26862954, 0.26130258, 0.27577711], do_convert_rgb=True, )
47 |
48 | spi_tokens = ['<mask>', '<pos>']
49 | self.tokenizer.add_tokens(spi_tokens, special_tokens=True)
50 |
51 | for m in self.model.modules():
52 | m.tokenizer = self.tokenizer
53 |
54 | vision_tower = self.model.get_vision_tower()
55 | if not vision_tower.is_loaded:
56 | vision_tower.load_model()
57 | vision_tower.to(dtype=torch.float16, device='cuda')
58 |
59 | def forward(self, root_path, ann_file, gt_file='captions_refcocog_gt.json', caption_file='captions_refcocog_osprey.json'):
60 | self.captions_all = []
61 | self.gt_all = {}
62 | self.gt_all['images'] = []
63 | self.gt_all['annotations'] = []
64 | self.root_path = root_path
65 | self.coco = COCO(ann_file)
66 | self.img_ids = self.coco.getImgIds()
67 | for i, img in enumerate(tqdm(self.img_ids)):
68 | data = self.coco.loadImgs([img])[0]
69 | self.forward_single(data)
70 |
71 | final = json.dumps(self.captions_all)
72 | with open(caption_file,'w') as fw:
73 | fw.write(final)
74 | fw.close()
75 |
76 | final = json.dumps(self.gt_all)
77 | with open(gt_file,'w') as fw:
78 | fw.write(final)
79 | fw.close()
80 |
81 |
82 | def forward_single(self, inputs):
83 |
84 | img_path = os.path.join(self.root_path, inputs['file_name'].split('_')[-1])
85 | height = inputs['height']
86 | width = inputs['width']
87 | round_ids = 0
88 | last_source = dict()
89 | annotations_ids = self.coco.getAnnIds([inputs['id']])
90 | annotations = self.coco.loadAnns(annotations_ids)
91 | for i in range(len(annotations)):
92 | caption = {}
93 | gt = {}
94 | ann = annotations[i]
95 | mask_r = ann['segmentation']
96 |
97 | if isinstance(mask_r, list):
98 | mask = annToMask(mask_r, height, width)
99 | else:
100 | mask = maskUtils.decode(mask_r)
101 | mask = torch.from_numpy(mask).unsqueeze(0)
102 |
103 | init_inputs = get_init_inputs(img_path,
104 | self.image_processor,
105 | self.tokenizer,
106 | mask=mask,
107 | round_ids=round_ids,
108 | last_round_source=last_source,
109 | )
110 |
111 | round_ids += 1
112 | last_source = init_inputs
113 |
114 | image = init_inputs['image']
115 |
116 | masks = init_inputs['masks'].cuda()
117 |
118 | conv = conv_templates['osprey_v1'].copy()
119 | qs = init_inputs['sources'][0][0]['value']
120 |
121 | conv.append_message(conv.roles[0], qs)
122 | conv.append_message(conv.roles[1], None)
123 | prompt = conv.get_prompt()
124 |
125 | input_ids = tokenizer_image_token(prompt, self.tokenizer, IMAGE_TOKEN_INDEX, return_tensors='pt').unsqueeze(0).cuda()
126 |
127 | stop_str = conv.sep if conv.sep_style != SeparatorStyle.TWO else conv.sep2
128 |
129 | self.model.model.tokenizer = self.tokenizer
130 |
131 |
132 | with torch.inference_mode():
133 |
134 | self.model.orig_forward = self.model.forward
135 | self.model.forward = partial(self.model.orig_forward,
136 | img_metas=[None],
137 | masks=[masks.half()])
138 |
139 | output_ids = self.model.generate(
140 | input_ids,
141 | images=image.unsqueeze(0).half().cuda(),
142 | do_sample=True,
143 | temperature=0.2,
144 | max_new_tokens=1024,
145 | use_cache=True,
146 | num_beams=1,
147 | )
148 |
149 | self.model.forward = self.model.orig_forward
150 |
151 | input_token_len = input_ids.shape[1]
152 | n_diff_input_output = (
153 | input_ids != output_ids[:, :input_token_len]).sum().item()
154 | if n_diff_input_output > 0:
155 | print(
156 | f'[Warning] {n_diff_input_output} output_ids are not the same as the input_ids')
157 | outputs = self.tokenizer.batch_decode(output_ids[:, input_token_len:],
158 | skip_special_tokens=True)[0]
159 |
160 | outputs = outputs.strip()
161 | if outputs.endswith(stop_str):
162 | outputs = outputs[:-len(stop_str)]
163 | outputs = outputs.strip()
164 | if ':' in outputs:
165 | outputs = outputs.split(':')[1]
166 |
167 | print(outputs)
168 | outputs = outputs.replace('.', '.\n')
169 | caption['image_id'] = str(ann['id'])
170 | caption['caption'] = outputs
171 | gt['id'] = str(ann['id'])
172 | gt['image_id'] = str(ann['id'])
173 | gt['caption'] = inputs['caption']
174 |
175 | self.captions_all.append(caption)
176 | self.gt_all['annotations'].append(gt)
177 | self.gt_all['images'].append({'id':str(ann['id'])})
178 |
179 |
180 |
181 | def get_init_inputs(img_path,
182 | processor,
183 | pred_bboxes,
184 | mask,
185 | round_ids=0,
186 | last_round_source=None):
187 |
188 | if round_ids == 0:
189 |
190 | image = Image.open(img_path).convert('RGB')
191 |
192 | image = processor.preprocess(image,
193 | do_center_crop=False,
194 | return_tensors='pt')['pixel_values'][0]
195 |
196 | image = torch.nn.functional.interpolate(image.unsqueeze(0),
197 | size=(512, 512),
198 | mode='bilinear',
199 | align_corners=False).squeeze(0)
200 |
201 | else:
202 | image = last_round_source['image']
203 |
204 | cur_token_len = (image.shape[1] // 16) * (image.shape[2] // 16)
205 |
206 | mask = mask.to(image.device)
207 |
208 | begin_str = """.\nThis provides an overview of the picture.\n"""
209 |
210 | sources = dict()
211 | sources['conversations'] = []
212 | question = 'Please give me a short description of <mask><pos>.'
213 |
214 | sources['conversations'].append({'from': 'human', 'value': begin_str+question})
215 |
216 | sources = preprocess_multimodal([sources['conversations']], data_args, cur_token_len)
217 |
218 | data_dict = {}
219 | data_dict['sources'] = sources
220 | data_dict['image'] = image
221 | data_dict['masks'] = mask
222 |
223 | return data_dict
224 |
225 |
226 | if __name__ == "__main__":
227 | parser = argparse.ArgumentParser(description='osprey demo', formatter_class=argparse.RawTextHelpFormatter)
228 | parser.add_argument('--model', help='path to osprey model', default='path/to/Osprey-7B-refcocog-fintune')
229 | parser.add_argument('--img', help='path to coco imgs', default='path/to/coco_all_imgs/')
230 | parser.add_argument('--json', help='path to refcocog val json file', default='./finetune_refcocog_val_with_mask.json')
231 | args = parser.parse_args()
232 |
233 | refcocog_eval = REFCOCOG_EVAL(args.model)
234 | refcocog_eval.forward(args.img, args.json)
235 |
236 |
--------------------------------------------------------------------------------
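Note on refcocog_eval.py above: forward() writes two COCO-caption-style files, captions_refcocog_osprey.json (image_id/caption pairs) and captions_refcocog_gt.json (images plus reference annotations). A minimal scoring sketch, assuming the third-party pycocoevalcap package is used to compute the standard caption metrics (the repo's own scoring step may differ):

from pycocotools.coco import COCO
from pycocoevalcap.eval import COCOEvalCap

# Files produced by REFCOCOG_EVAL.forward() above.
coco_gt = COCO('captions_refcocog_gt.json')
coco_res = coco_gt.loadRes('captions_refcocog_osprey.json')

coco_eval = COCOEvalCap(coco_gt, coco_res)
coco_eval.params['image_id'] = coco_res.getImgIds()
coco_eval.evaluate()

for metric, score in coco_eval.eval.items():
    print(f'{metric}: {score:.3f}')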
/osprey/eval/rule.json:
--------------------------------------------------------------------------------
1 | {
2 | "description": {"role": "Assistant", "prompt": "We would like to request your feedback on the performance of two AI assistants in response to the user question displayed above. The user asks the question on observing an image. For your reference, the visual content in the image is represented with a few sentences describing the image. In addition, specific object locations within the image are given, along with detailed coordinates. These coordinates are in the form of bounding boxes, represented as (x1, y1, x2, y2) with floating numbers ranging from 0 to 1. These values correspond to the top left x, top left y, bottom right x, and bottom right y. Also, several region description are given, each describing a box region of image, with detailed coordinates. \nPlease rate the helpfulness, relevance, accuracy, level of details of their responses. Each assistant receives an overall score on a scale of 1 to 10, where a higher score indicates better overall performance.\nPlease first output a single line containing only two values indicating the scores for Assistant 1 and 2, respectively. The two scores are separated by a space.\nIn the subsequent line, please provide a comprehensive explanation of your evaluation, avoiding any potential bias and ensuring that the order in which the responses were presented does not affect your judgment."}
3 | }
--------------------------------------------------------------------------------
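Note on rule.json above: the "description" entry is the grading prompt used for pairwise GPT scoring of two assistants' answers about the same image. A minimal sketch of how such a request could be assembled; the context, question and answers below are made-up placeholders, and the repo's eval_gpt.py is the authoritative version of this step:

import json

rule = json.load(open('rule.json'))['description']

# Hypothetical inputs: textual image context, the question, and the two answers to compare.
context = 'A man rides a brown horse along a beach. region1: (0.12, 0.30, 0.55, 0.90)'
question = 'Describe the highlighted region in detail.'
answer1 = 'A man in a red jacket riding a brown horse near the water.'
answer2 = 'A horse.'

content = (f'[Context]\n{context}\n\n'
           f'[Question]\n{question}\n\n'
           f'[Assistant 1]\n{answer1}\n[End of Assistant 1]\n\n'
           f'[Assistant 2]\n{answer2}\n[End of Assistant 2]\n\n'
           f'[System]\n{rule["prompt"]}\n')

# `content` is then sent to GPT-4; the first line of the reply is expected to hold
# the two scores, which downstream tooling stores as the 'tuple' field (see below).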
/osprey/eval/summarize_gpt_score.py:
--------------------------------------------------------------------------------
1 | """
2 | Reference: https://github.com/haotian-liu/LLaVA/blob/main/llava/eval/summarize_gpt_review.py
3 | """
4 |
5 | import json
6 | import os
7 | from collections import defaultdict
8 |
9 | import numpy as np
10 |
11 | import argparse
12 |
13 | def parse_args():
14 | parser = argparse.ArgumentParser(description='ChatGPT-based QA evaluation.')
15 | parser.add_argument('-d', '--dir', default=None)
16 | parser.add_argument('-f', '--files', nargs='*', default=None)
17 | parser.add_argument('-i', '--ignore', nargs='*', default=None)
18 | parser.add_argument('-s', '--save', action='store_true')
19 | return parser.parse_args()
20 |
21 |
22 | if __name__ == '__main__':
23 | args = parse_args()
24 |
25 | review_files = [x for x in os.listdir(args.dir) if x.endswith('.jsonl')]
26 |
27 | metrics = []
28 | for review_file in sorted(review_files):
29 | config = os.path.basename(review_file).replace('gpt4_text_', '').replace('.jsonl', '')
30 | scores = defaultdict(list)
31 | print(config)
32 | with open(os.path.join(args.dir, review_file) if args.dir is not None else review_file) as f:
33 | for review_str in f:
34 | review = json.loads(review_str)
35 | if args.ignore is not None and review['question_id'] in args.ignore:
36 | continue
37 | if 'category' in review:
38 | scores[review['category']].append(review['tuple'])
39 | scores['all'].append(review['tuple'])
40 | else:
41 | if 'tuple' in review:
42 | scores['all'].append(review['tuple'])
43 | else:
44 | scores['all'].append(review['score'])
45 | summ_dict = defaultdict(list)
46 | for k, v in sorted(scores.items()):
47 | stats = np.asarray(v).mean(0).tolist()
48 | stats = [round(x, 3) for x in stats]
49 | # print(k, stats, round(stats[1]/stats[0]*100, 1))
50 | print(k, round(stats[1]/stats[0]*100, 2))
51 | summ_dict[k] = round(stats[1]/stats[0]*100, 2)
52 | print('=================================')
53 | metrics.append(summ_dict)
54 |
55 | if args.save:
56 | with open(os.path.join(args.dir, 'metric.json'), 'w') as f:
57 | json.dump(metrics, f, indent=2)
58 |
59 |
60 |
61 |
--------------------------------------------------------------------------------
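Note on summarize_gpt_score.py above: each line of a gpt4_text_*.jsonl review file is expected to carry a 'tuple' of two scores (reference assistant first, evaluated model second), and the script reports the relative score stats[1]/stats[0]*100 per category. A tiny sketch, with a hypothetical file name, of the expected input:

import json

# Hypothetical review lines: 'tuple' = [assistant1_score, assistant2_score].
reviews = [
    {'question_id': 1, 'category': 'description', 'tuple': [8.0, 7.0]},
    {'question_id': 2, 'category': 'description', 'tuple': [9.0, 9.0]},
]
with open('gpt4_text_example.jsonl', 'w') as f:
    for r in reviews:
        f.write(json.dumps(r) + '\n')

# python summarize_gpt_score.py -d .
# would report description: 8.0 / 8.5 * 100 = 94.12 (mean of column 2 over mean of column 1).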
/osprey/eval/utils/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/CircleRadon/Osprey/deb0e57d7e771e8d2a8652236b0e348efb9b5c6a/osprey/eval/utils/__init__.py
--------------------------------------------------------------------------------
/osprey/eval/utils/ade20k_150_with_prompt_eng.txt:
--------------------------------------------------------------------------------
1 | 0:invalid_class_id
2 | 1:wall,walls,brick wall,stone wall,interior wall
3 | 2:building,buildings,edifice,edifices
4 | 3:sky,clouds
5 | 4:floor,flooring
6 | 5:tree,trees
7 | 6:ceiling
8 | 7:road,route,street,roads,streets,routes
9 | 8:bed,beds
10 | 9:windowpane,window,windows
11 | 10:grass,grass field
12 | 11:cabinet,cabinets,wall mounted cabine
13 | 12:sidewalk,pavement
14 | 13:person,child,girl,boy,woman,man,people,children,girls,boys,women,men
15 | 14:earth,ground
16 | 15:door,double door,doors
17 | 16:table,tables,tablecloth
18 | 17:mountain,mount,mountains
19 | 18:plant,flora,plant life,plants,bushes
20 | 19:curtain,drape,drapery,mantle,pall
21 | 20:chair,chairs
22 | 21:car,automobile,cars
23 | 22:water
24 | 23:painting,picture,paintings,pictures,wallart,framed canvas
25 | 24:sofa,couch,sofas,couches
26 | 25:shelf,shelves
27 | 26:house exterior
28 | 27:sea,ocean
29 | 28:mirror,mirrors
30 | 29:rug,carpet,carpeting
31 | 30:field
32 | 31:armchair,armchairs
33 | 32:seat,seats
34 | 33:fence,fencing
35 | 34:desk,desks
36 | 35:rock,stone,rocks,stones
37 | 36:wardrobe,closet,press,wardrobes,closets
38 | 37:lamp,lamps
39 | 38:bathtub,bathing tub,bath,tub
40 | 39:railing,rail
41 | 40:cushion,cushions
42 | 41:pedestal
43 | 42:box,boxes
44 | 43:column,pillar
45 | 44:signboard,sign,signboards,signs
46 | 45:chest of drawers,chest,bureau,dresser
47 | 46:counter
48 | 47:sand
49 | 48:sink
50 | 49:skyscraper,skyscrapers
51 | 50:fireplace,hearth,open fireplace
52 | 51:refrigerator,icebox
53 | 52:grandstand,covered stand
54 | 53:path
55 | 54:stairs,steps
56 | 55:runway
57 | 56:case,display case,showcase,vitrine
58 | 57:pool table,billiard table,snooker table
59 | 58:pillow,pillows
60 | 59:screen door,shower door
61 | 60:stairway,staircase
62 | 61:river
63 | 62:bridge,span
64 | 63:bookcase
65 | 64:window screen,door screen
66 | 65:coffee table,cocktail table
67 | 66:toilet,commode,crapper,potty
68 | 67:flower,flowers
69 | 68:book,books
70 | 69:hill
71 | 70:bench,benches
72 | 71:countertop,counter top,worktop
73 | 72:stove,kitchen stove,kitchen range,kitchen range,cooking stove
74 | 73:palm tree,palm trees
75 | 74:kitchen island
76 | 75:computer,computing machine,computing device,data processor,electronic computer,information processing system
77 | 76:swivel chair
78 | 77:boat
79 | 78:bar
80 | 79:arcade machine,arcade machines
81 | 80:hovel,hut,hutch,shack,shanty
82 | 81:bus,autobus,double-decker,jitney,motorbus,motorcoach,omnibus,passenger vehicle
83 | 82:towel
84 | 83:light bulb,lightbulb,bulb,incandescent lamp,electric light,electric-light bulb
85 | 84:truck,motortruck
86 | 85:tower,towers
87 | 86:chandelier,pendant,pendent
88 | 87:awning,sunshade,sunblind
89 | 88:streetlight,street lamp
90 | 89:booth,cubicle,stall,kiosk
91 | 90:television receiver,television,television set,tv,tv set
92 | 91:airplane,aeroplane,airplanes,aeroplanes
93 | 92:dirt track
94 | 93:apparel,wearing apparel,dress,clothes
95 | 94:pole
96 | 95:land,soil
97 | 96:bannister,banister,balustrade,balusters,handrail
98 | 97:escalator,moving staircase,moving stairway
99 | 98:ottoman,pouf,pouffe,puff,hassock
100 | 99:bottle,bottles,water bottle
101 | 100:buffet,sideboard
102 | 101:poster,posting,placard,notice,bill,card
103 | 102:stage
104 | 103:van
105 | 104:ship
106 | 105:fountain
107 | 106:conveyer belt,conveyor belt,conveyer,conveyor,transporter
108 | 107:canopy
109 | 108:washer,automatic washer,washing machine
110 | 109:plaything,toy,toys
111 | 110:swimming pool,swimming bath
112 | 111:stool,stools
113 | 112:barrel,cask,barrels,casks
114 | 113:basket,handbasket
115 | 114:waterfall,falls
116 | 115:tent,collapsible shelter
117 | 116:bag,bags,gift bag,paper bag
118 | 117:minibike,motorbike
119 | 118:cradle
120 | 119:oven
121 | 120:ball,balls
122 | 121:food,solid food
123 | 122:step,stair
124 | 123:tank,storage tank
125 | 124:trade name,brand name,brand,marque
126 | 125:microwave,microwave oven
127 | 126:plant pots,plant pot,flower pot,flowerpot,planter
128 | 127:animal,animate being,dog,cat,horse,cow,sheep,zebra,girraffe,bird
129 | 128:bicycle,bike
130 | 129:lake
131 | 130:dishwasher,dish washer,dishwashing machine
132 | 131:projection screen
133 | 132:blanket,cover
134 | 133:sculpture,sculptures
135 | 134:exhaust hood
136 | 135:sconce,sconce lamp,sconce light
137 | 136:vase,vases
138 | 137:traffic light,traffic signal,traffic lights
139 | 138:tray,trays
140 | 139:ashcan,trash can,garbage can,wastebin,ash bin,ash-bin,ashbin,dustbin,trash barrel,trash bin
141 | 140:ceiling fan,floor fan
142 | 141:pier,wharf,wharfage,dock
143 | 142:crt screen
144 | 143:plate,plates
145 | 144:monitor,monitoring device,monitors
146 | 145:bulletin board,notice board
147 | 146:shower
148 | 147:radiator
149 | 148:cup,cups,drinking glass,drinking glasses
150 | 149:clock
151 | 150:flag,flags
--------------------------------------------------------------------------------
/osprey/eval/utils/cityscapes_with_prompt_eng.txt:
--------------------------------------------------------------------------------
1 | 0:road,railroad
2 | 1:sidewalk,pavement
3 | 2:building,buildings,edifice,edifices,house,ceiling
4 | 3:wall,walls,brick wall,stone wall,tile wall,wood wall
5 | 4:fence,fences
6 | 5:pole,poles
7 | 6:traffic light,traffic lights
8 | 7:traffic sign,stop sign
9 | 8:vegetation,tree,trees,palm tree,bushes
10 | 9:terrain,river,sand,sea,snow,water,mountain,grass,dirt,rock
11 | 10:sky,clouds
12 | 11:person
13 | 12:rider
14 | 13:car,cars
15 | 14:truck,trucks
16 | 15:bus,buses
17 | 16:train,trains,locomotive,locomotives,freight train
18 | 17:motorcycle,motorcycles
19 | 18:bicycle,bicycles,bike,bikes
--------------------------------------------------------------------------------
/osprey/eval/utils/instance_evaluation.py:
--------------------------------------------------------------------------------
1 | """
2 | Reference: https://github.com/facebookresearch/Mask2Former/blob/main/mask2former/evaluation/instance_evaluation.py
3 | """
4 |
5 | import contextlib
6 | import copy
7 | import io
8 | import itertools
9 | import json
10 | import logging
11 | import numpy as np
12 | import os
13 | import pickle
14 | from collections import OrderedDict
15 | import pycocotools.mask as mask_util
16 | import torch
17 | from pycocotools.coco import COCO
18 | from pycocotools.cocoeval import COCOeval
19 | from tabulate import tabulate
20 |
21 | import detectron2.utils.comm as comm
22 | from detectron2.config import CfgNode
23 | from detectron2.data import MetadataCatalog
24 | from detectron2.data.datasets.coco import convert_to_coco_json
25 | from detectron2.evaluation.coco_evaluation import COCOEvaluator, _evaluate_predictions_on_coco
26 | from detectron2.evaluation.fast_eval_api import COCOeval_opt
27 | from detectron2.structures import Boxes, BoxMode, pairwise_iou
28 | from detectron2.utils.file_io import PathManager
29 | from detectron2.utils.logger import create_small_table
30 |
31 |
32 | # modified from COCOEvaluator for instance segmentation
33 | class InstanceSegEvaluator(COCOEvaluator):
34 | """
35 | Evaluate AR for object proposals, AP for instance detection/segmentation, AP
36 | for keypoint detection outputs using COCO's metrics.
37 | See http://cocodataset.org/#detection-eval and
38 | http://cocodataset.org/#keypoints-eval to understand its metrics.
39 | The metrics range from 0 to 100 (instead of 0 to 1), where a -1 or NaN means
40 | the metric cannot be computed (e.g. due to no predictions made).
41 |
42 | In addition to COCO, this evaluator is able to support any bounding box detection,
43 | instance segmentation, or keypoint detection dataset.
44 | """
45 |
46 | def _eval_predictions(self, predictions, img_ids=None):
47 | """
48 | Evaluate predictions. Fill self._results with the metrics of the tasks.
49 | """
50 | self._logger.info("Preparing results for COCO format ...")
51 | coco_results = list(itertools.chain(*[x["instances"] for x in predictions]))
52 | tasks = self._tasks or self._tasks_from_predictions(coco_results)
53 |
54 | # unmap the category ids for COCO
55 | if hasattr(self._metadata, "thing_dataset_id_to_contiguous_id"):
56 | dataset_id_to_contiguous_id = self._metadata.thing_dataset_id_to_contiguous_id
57 | # all_contiguous_ids = list(dataset_id_to_contiguous_id.values())
58 | # num_classes = len(all_contiguous_ids)
59 | # assert min(all_contiguous_ids) == 0 and max(all_contiguous_ids) == num_classes - 1
60 |
61 | reverse_id_mapping = {v: k for k, v in dataset_id_to_contiguous_id.items()}
62 | for result in coco_results:
63 | category_id = result["category_id"]
64 | # assert category_id < num_classes, (
65 | # f"A prediction has class={category_id}, "
66 | # f"but the dataset only has {num_classes} classes and "
67 | # f"predicted class id should be in [0, {num_classes - 1}]."
68 | # )
69 | assert category_id in reverse_id_mapping, (
70 | f"A prediction has class={category_id}, "
71 | f"but the dataset only has class ids in {dataset_id_to_contiguous_id}."
72 | )
73 | result["category_id"] = reverse_id_mapping[category_id]
74 |
75 | if self._output_dir:
76 | file_path = os.path.join(self._output_dir, "coco_instances_results.json")
77 | self._logger.info("Saving results to {}".format(file_path))
78 | with PathManager.open(file_path, "w") as f:
79 | f.write(json.dumps(coco_results))
80 | f.flush()
81 |
82 | if not self._do_evaluation:
83 | self._logger.info("Annotations are not available for evaluation.")
84 | return
85 |
86 | self._logger.info(
87 | "Evaluating predictions with {} COCO API...".format(
88 | "unofficial" if self._use_fast_impl else "official"
89 | )
90 | )
91 | for task in sorted(tasks):
92 | assert task in {"bbox", "segm", "keypoints"}, f"Got unknown task: {task}!"
93 | coco_eval = (
94 | _evaluate_predictions_on_coco(
95 | self._coco_api,
96 | coco_results,
97 | task,
98 | kpt_oks_sigmas=self._kpt_oks_sigmas,
99 | #use_fast_impl=self._use_fast_impl,
100 | img_ids=img_ids,
101 | max_dets_per_image=self._max_dets_per_image,
102 | )
103 | if len(coco_results) > 0
104 | else None # cocoapi does not handle empty results very well
105 | )
106 |
107 | res = self._derive_coco_results(
108 | coco_eval, task, class_names=self._metadata.get("thing_classes")
109 | )
110 | self._results[task] = res
111 |
--------------------------------------------------------------------------------
/osprey/eval/utils/register_ade20k_panoptic.py:
--------------------------------------------------------------------------------
1 | """
2 | Reference: https://github.com/facebookresearch/Mask2Former/blob/main/mask2former/data/datasets/register_ade20k_panoptic.py
3 | """
4 |
5 | import json
6 | import os
7 |
8 | from detectron2.data import DatasetCatalog, MetadataCatalog
9 | from detectron2.utils.file_io import PathManager
10 | from detectron2.data.datasets.coco import load_sem_seg
11 |
12 |
13 | from . import openseg_classes
14 |
15 | ADE20K_150_CATEGORIES = openseg_classes.get_ade20k_categories_with_prompt_eng()
16 |
17 | ADE20k_COLORS = [k["color"] for k in ADE20K_150_CATEGORIES]
18 |
19 | MetadataCatalog.get("openvocab_ade20k_sem_seg_train").set(
20 | stuff_colors=ADE20k_COLORS[:],
21 | )
22 |
23 | MetadataCatalog.get("openvocab_ade20k_sem_seg_val").set(
24 | stuff_colors=ADE20k_COLORS[:],
25 | )
26 |
27 |
28 | def load_ade20k_panoptic_json(json_file, image_dir, gt_dir, semseg_dir, meta):
29 | """
30 | Args:
31 | image_dir (str): path to the raw dataset. e.g., "~/coco/train2017".
32 | gt_dir (str): path to the raw annotations. e.g., "~/coco/panoptic_train2017".
33 | json_file (str): path to the json file. e.g., "~/coco/annotations/panoptic_train2017.json".
34 | Returns:
35 | list[dict]: a list of dicts in Detectron2 standard format. (See
36 | `Using Custom Datasets `_ )
37 | """
38 |
39 | def _convert_category_id(segment_info, meta):
40 | if segment_info["category_id"] in meta["thing_dataset_id_to_contiguous_id"]:
41 | segment_info["category_id"] = meta["thing_dataset_id_to_contiguous_id"][
42 | segment_info["category_id"]
43 | ]
44 | segment_info["isthing"] = True
45 | else:
46 | segment_info["category_id"] = meta["stuff_dataset_id_to_contiguous_id"][
47 | segment_info["category_id"]
48 | ]
49 | segment_info["isthing"] = False
50 | return segment_info
51 |
52 | with PathManager.open(json_file) as f:
53 | json_info = json.load(f)
54 |
55 | ret = []
56 | for ann in json_info["annotations"]:
57 | image_id = ann["image_id"]
58 | # TODO: currently we assume image and label has the same filename but
59 | # different extension, and images have extension ".jpg" for COCO. Need
60 | # to make image extension a user-provided argument if we extend this
61 | # function to support other COCO-like datasets.
62 | image_file = os.path.join(image_dir, os.path.splitext(ann["file_name"])[0] + ".jpg")
63 | label_file = os.path.join(gt_dir, ann["file_name"])
64 | sem_label_file = os.path.join(semseg_dir, ann["file_name"])
65 | segments_info = [_convert_category_id(x, meta) for x in ann["segments_info"]]
66 | ret.append(
67 | {
68 | "file_name": image_file,
69 | "image_id": image_id,
70 | "pan_seg_file_name": label_file,
71 | "sem_seg_file_name": sem_label_file,
72 | "segments_info": segments_info,
73 | }
74 | )
75 | assert len(ret), f"No images found in {image_dir}!"
76 | assert PathManager.isfile(ret[0]["file_name"]), ret[0]["file_name"]
77 | assert PathManager.isfile(ret[0]["pan_seg_file_name"]), ret[0]["pan_seg_file_name"]
78 | assert PathManager.isfile(ret[0]["sem_seg_file_name"]), ret[0]["sem_seg_file_name"]
79 | return ret
80 |
81 |
82 | def register_ade20k_panoptic(
83 | name, metadata, image_root, panoptic_root, semantic_root, panoptic_json, instances_json=None
84 | ):
85 | """
86 | Register a "standard" version of ADE20k panoptic segmentation dataset named `name`.
87 | The dictionaries in this registered dataset follows detectron2's standard format.
88 | Hence it's called "standard".
89 | Args:
90 | name (str): the name that identifies a dataset,
91 | e.g. "ade20k_panoptic_train"
92 | metadata (dict): extra metadata associated with this dataset.
93 | image_root (str): directory which contains all the images
94 | panoptic_root (str): directory which contains panoptic annotation images in COCO format
95 | panoptic_json (str): path to the json panoptic annotation file in COCO format
96 | sem_seg_root (none): not used, to be consistent with
97 | `register_coco_panoptic_separated`.
98 | instances_json (str): path to the json instance annotation file
99 | """
100 | panoptic_name = name
101 | DatasetCatalog.register(
102 | panoptic_name,
103 | lambda: load_ade20k_panoptic_json(
104 | panoptic_json, image_root, panoptic_root, semantic_root, metadata
105 | ),
106 | )
107 | MetadataCatalog.get(panoptic_name).set(
108 | panoptic_root=panoptic_root,
109 | image_root=image_root,
110 | panoptic_json=panoptic_json,
111 | json_file=instances_json,
112 | evaluator_type="ade20k_panoptic_seg",
113 | ignore_label=255,
114 | label_divisor=1000,
115 | **metadata,
116 | )
117 |
118 |
119 | _PREDEFINED_SPLITS_ADE20K_PANOPTIC = {
120 | "openvocab_ade20k_panoptic_train": (
121 | "ADEChallengeData2016/images/training",
122 | "ADEChallengeData2016/ade20k_panoptic_train",
123 | "ADEChallengeData2016/ade20k_panoptic_train.json",
124 | "ADEChallengeData2016/annotations_detectron2/training",
125 | "ADEChallengeData2016/ade20k_instance_train.json",
126 | ),
127 | "openvocab_ade20k_panoptic_val": (
128 | "ADEChallengeData2016/images/validation",
129 | "ADEChallengeData2016/ade20k_panoptic_val",
130 | "ADEChallengeData2016/ade20k_panoptic_val.json",
131 | "ADEChallengeData2016/annotations_detectron2/validation",
132 | "ADEChallengeData2016/ade20k_instance_val.json",
133 | ),
134 | }
135 |
136 |
137 | def get_metadata():
138 | meta = {}
139 | # The following metadata maps contiguous id from [0, #thing categories +
140 | # #stuff categories) to their names and colors. We have to replicate the
141 | # same name and color under "thing_*" and "stuff_*" because the current
142 | # visualization function in D2 handles thing and stuff classes differently
143 | # due to some heuristic used in Panoptic FPN. We keep the same naming to
144 | # enable reusing existing visualization functions.
145 | thing_classes = [k["name"] for k in ADE20K_150_CATEGORIES if k["isthing"] == 1]
146 | thing_colors = [k["color"] for k in ADE20K_150_CATEGORIES if k["isthing"] == 1]
147 | stuff_classes = [k["name"] for k in ADE20K_150_CATEGORIES]
148 | stuff_colors = [k["color"] for k in ADE20K_150_CATEGORIES]
149 |
150 | meta["thing_classes"] = thing_classes
151 | meta["thing_colors"] = thing_colors
152 | meta["stuff_classes"] = stuff_classes
153 | meta["stuff_colors"] = stuff_colors
154 |
155 | # Convert category id for training:
156 | # category id: like semantic segmentation, it is the class id for each
157 | # pixel. Since there are some classes not used in evaluation, the category
158 | # id is not always contiguous and thus we have two set of category ids:
159 | # - original category id: category id in the original dataset, mainly
160 | # used for evaluation.
161 | # - contiguous category id: [0, #classes), in order to train the linear
162 | # softmax classifier.
163 | thing_dataset_id_to_contiguous_id = {}
164 | stuff_dataset_id_to_contiguous_id = {}
165 |
166 | for i, cat in enumerate(ADE20K_150_CATEGORIES):
167 | if cat["isthing"]:
168 | thing_dataset_id_to_contiguous_id[cat["id"]] = i
169 | # else:
170 | # stuff_dataset_id_to_contiguous_id[cat["id"]] = i
171 |
172 | # in order to use sem_seg evaluator
173 | stuff_dataset_id_to_contiguous_id[cat["id"]] = i
174 |
175 | meta["thing_dataset_id_to_contiguous_id"] = thing_dataset_id_to_contiguous_id
176 | meta["stuff_dataset_id_to_contiguous_id"] = stuff_dataset_id_to_contiguous_id
177 |
178 | return meta
179 |
180 |
181 | def register_all_ade20k_panoptic(root):
182 | metadata = get_metadata()
183 | for (
184 | prefix,
185 | (image_root, panoptic_root, panoptic_json, semantic_root, instance_json),
186 | ) in _PREDEFINED_SPLITS_ADE20K_PANOPTIC.items():
187 | # The "standard" version of COCO panoptic segmentation dataset,
188 | # e.g. used by Panoptic-DeepLab
189 | register_ade20k_panoptic(
190 | prefix,
191 | metadata,
192 | os.path.join(root, image_root),
193 | os.path.join(root, panoptic_root),
194 | os.path.join(root, semantic_root),
195 | os.path.join(root, panoptic_json),
196 | os.path.join(root, instance_json),
197 | )
198 |
199 | def register_all_ade20k_semantic(root):
200 | root = os.path.join(root, "ADEChallengeData2016")
201 | for name, dirname in [("train", "training"), ("val", "validation")]:
202 | image_dir = os.path.join(root, "images", dirname)
203 | gt_dir = os.path.join(root, "annotations_detectron2", dirname)
204 | name = f"openvocab_ade20k_sem_seg_{name}"
205 | DatasetCatalog.register(
206 | name, lambda x=image_dir, y=gt_dir: load_sem_seg(y, x, gt_ext="png", image_ext="jpg")
207 | )
208 | MetadataCatalog.get(name).set(
209 | stuff_classes=[x["name"] for x in ADE20K_150_CATEGORIES],
210 | image_root=image_dir,
211 | sem_seg_root=gt_dir,
212 | evaluator_type="sem_seg",
213 | ignore_label=255,
214 | )
215 |
216 | _root = os.getenv("DETECTRON2_DATASETS", "datasets")
217 | register_all_ade20k_panoptic(_root)
218 | register_all_ade20k_semantic(_root)
--------------------------------------------------------------------------------
/osprey/eval/utils/register_cityscapes_panoptic.py:
--------------------------------------------------------------------------------
1 | """
2 | Reference: https://github.com/facebookresearch/detectron2/blob/main/detectron2/data/datasets/cityscapes_panoptic.py
3 | """
4 |
5 | import json
6 | import logging
7 | import os
8 |
9 | from detectron2.data import DatasetCatalog, MetadataCatalog
10 | from detectron2.utils.file_io import PathManager
11 |
12 | from . import openseg_classes
13 |
14 | CITYSCAPES_CATEGORIES = openseg_classes.get_cityscapes_categories_with_prompt_eng()
15 |
16 | """
17 | This file contains functions to register the Cityscapes panoptic dataset to the DatasetCatalog.
18 | """
19 |
20 |
21 | logger = logging.getLogger(__name__)
22 |
23 |
24 | def get_cityscapes_panoptic_files(image_dir, gt_dir, json_info):
25 | files = []
26 | # scan through the directory
27 | cities = PathManager.ls(image_dir)
28 | logger.info(f"{len(cities)} cities found in '{image_dir}'.")
29 | image_dict = {}
30 | for city in cities:
31 | city_img_dir = os.path.join(image_dir, city)
32 | for basename in PathManager.ls(city_img_dir):
33 | image_file = os.path.join(city_img_dir, basename)
34 |
35 | suffix = "_leftImg8bit.png"
36 | assert basename.endswith(suffix), basename
37 | basename = os.path.basename(basename)[: -len(suffix)]
38 |
39 | image_dict[basename] = image_file
40 |
41 | for ann in json_info["annotations"]:
42 | image_file = image_dict.get(ann["image_id"], None)
43 | assert image_file is not None, "No image {} found for annotation {}".format(
44 | ann["image_id"], ann["file_name"]
45 | )
46 | label_file = os.path.join(gt_dir, ann["file_name"])
47 | segments_info = ann["segments_info"]
48 |
49 | files.append((image_file, label_file, segments_info))
50 |
51 | assert len(files), "No images found in {}".format(image_dir)
52 | assert PathManager.isfile(files[0][0]), files[0][0]
53 | assert PathManager.isfile(files[0][1]), files[0][1]
54 | return files
55 |
56 |
57 | def load_cityscapes_panoptic(image_dir, gt_dir, gt_json, meta):
58 | """
59 | Args:
60 | image_dir (str): path to the raw dataset. e.g., "~/cityscapes/leftImg8bit/train".
61 | gt_dir (str): path to the raw annotations. e.g.,
62 | "~/cityscapes/gtFine/cityscapes_panoptic_train".
63 | gt_json (str): path to the json file. e.g.,
64 | "~/cityscapes/gtFine/cityscapes_panoptic_train.json".
65 | meta (dict): dictionary containing "thing_dataset_id_to_contiguous_id"
66 | and "stuff_dataset_id_to_contiguous_id" to map category ids to
67 | contiguous ids for training.
68 |
69 | Returns:
70 | list[dict]: a list of dicts in Detectron2 standard format. (See
71 | `Using Custom Datasets `_ )
72 | """
73 |
74 | def _convert_category_id(segment_info, meta):
75 | if segment_info["category_id"] in meta["thing_dataset_id_to_contiguous_id"]:
76 | segment_info["category_id"] = meta["thing_dataset_id_to_contiguous_id"][
77 | segment_info["category_id"]
78 | ]
79 | else:
80 | segment_info["category_id"] = meta["stuff_dataset_id_to_contiguous_id"][
81 | segment_info["category_id"]
82 | ]
83 | return segment_info
84 |
85 | assert os.path.exists(
86 | gt_json
87 | ), "Please run `python cityscapesscripts/preparation/createPanopticImgs.py` to generate label files." # noqa
88 | with open(gt_json) as f:
89 | json_info = json.load(f)
90 | files = get_cityscapes_panoptic_files(image_dir, gt_dir, json_info)
91 | ret = []
92 | for image_file, label_file, segments_info in files:
93 | sem_label_file = (
94 | image_file.replace("leftImg8bit", "gtFine").split(".")[0] + "_labelTrainIds.png"
95 | )
96 | segments_info = [_convert_category_id(x, meta) for x in segments_info]
97 | ret.append(
98 | {
99 | "file_name": image_file,
100 | "image_id": "_".join(
101 | os.path.splitext(os.path.basename(image_file))[0].split("_")[:3]
102 | ),
103 | "sem_seg_file_name": sem_label_file,
104 | "pan_seg_file_name": label_file,
105 | "segments_info": segments_info,
106 | }
107 | )
108 | assert len(ret), f"No images found in {image_dir}!"
109 | assert PathManager.isfile(
110 | ret[0]["sem_seg_file_name"]
111 | ), "Please generate labelTrainIds.png with cityscapesscripts/preparation/createTrainIdLabelImgs.py" # noqa
112 | assert PathManager.isfile(
113 | ret[0]["pan_seg_file_name"]
114 | ), "Please generate panoptic annotation with python cityscapesscripts/preparation/createPanopticImgs.py" # noqa
115 | return ret
116 |
117 |
118 | # rename to avoid conflict
119 | _RAW_CITYSCAPES_PANOPTIC_SPLITS = {
120 | "openvocab_cityscapes_fine_panoptic_train": (
121 | "cityscapes/leftImg8bit/train",
122 | "cityscapes/gtFine/cityscapes_panoptic_train",
123 | "cityscapes/gtFine/cityscapes_panoptic_train.json",
124 | ),
125 | "openvocab_cityscapes_fine_panoptic_val": (
126 | "cityscapes/leftImg8bit/val",
127 | "cityscapes/gtFine/cityscapes_panoptic_val",
128 | "cityscapes/gtFine/cityscapes_panoptic_val.json",
129 | ),
130 | # "cityscapes_fine_panoptic_test": not supported yet
131 | }
132 |
133 |
134 | def register_all_cityscapes_panoptic(root):
135 | meta = {}
136 | # The following metadata maps contiguous id from [0, #thing categories +
137 | # #stuff categories) to their names and colors. We have to replicate the
138 | # same name and color under "thing_*" and "stuff_*" because the current
139 | # visualization function in D2 handles thing and stuff classes differently
140 | # due to some heuristic used in Panoptic FPN. We keep the same naming to
141 | # enable reusing existing visualization functions.
142 | thing_classes = [k["name"] for k in CITYSCAPES_CATEGORIES]
143 | thing_colors = [k["color"] for k in CITYSCAPES_CATEGORIES]
144 | stuff_classes = [k["name"] for k in CITYSCAPES_CATEGORIES]
145 | stuff_colors = [k["color"] for k in CITYSCAPES_CATEGORIES]
146 |
147 | meta["thing_classes"] = thing_classes
148 | meta["thing_colors"] = thing_colors
149 | meta["stuff_classes"] = stuff_classes
150 | meta["stuff_colors"] = stuff_colors
151 |
152 | # There are three types of ids in cityscapes panoptic segmentation:
153 | # (1) category id: like semantic segmentation, it is the class id for each
154 | # pixel. Since there are some classes not used in evaluation, the category
155 | # id is not always contiguous and thus we have two set of category ids:
156 | # - original category id: category id in the original dataset, mainly
157 | # used for evaluation.
158 | # - contiguous category id: [0, #classes), in order to train the classifier
159 | # (2) instance id: this id is used to differentiate different instances from
160 | # the same category. For "stuff" classes, the instance id is always 0; for
161 | # "thing" classes, the instance id starts from 1 and 0 is reserved for
162 | # ignored instances (e.g. crowd annotation).
163 | # (3) panoptic id: this is the compact id that encode both category and
164 | # instance id by: category_id * 1000 + instance_id.
165 | thing_dataset_id_to_contiguous_id = {}
166 | stuff_dataset_id_to_contiguous_id = {}
167 |
168 | for k in CITYSCAPES_CATEGORIES:
169 | if k["isthing"] == 1:
170 | thing_dataset_id_to_contiguous_id[k["id"]] = k["trainId"]
171 | else:
172 | stuff_dataset_id_to_contiguous_id[k["id"]] = k["trainId"]
173 |
174 | meta["thing_dataset_id_to_contiguous_id"] = thing_dataset_id_to_contiguous_id
175 | meta["stuff_dataset_id_to_contiguous_id"] = stuff_dataset_id_to_contiguous_id
176 |
177 | for key, (image_dir, gt_dir, gt_json) in _RAW_CITYSCAPES_PANOPTIC_SPLITS.items():
178 | image_dir = os.path.join(root, image_dir)
179 | gt_dir = os.path.join(root, gt_dir)
180 | gt_json = os.path.join(root, gt_json)
181 |
182 | DatasetCatalog.register(
183 | key, lambda x=image_dir, y=gt_dir, z=gt_json: load_cityscapes_panoptic(x, y, z, meta)
184 | )
185 | MetadataCatalog.get(key).set(
186 | panoptic_root=gt_dir,
187 | image_root=image_dir,
188 | panoptic_json=gt_json,
189 | gt_dir=gt_dir.replace("cityscapes_panoptic_", ""),
190 | evaluator_type="cityscapes_panoptic_seg",
191 | ignore_label=255,
192 | label_divisor=1000,
193 | **meta,
194 | )
195 |
196 | _root = os.getenv("DETECTRON2_DATASETS", "datasets")
197 | register_all_cityscapes_panoptic(_root)
--------------------------------------------------------------------------------
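Note on register_cityscapes_panoptic.py above: label_divisor=1000 in the registered metadata matches the panoptic id scheme described in the comment inside register_all_cityscapes_panoptic (panoptic id = category_id * 1000 + instance_id). A one-line worked example with made-up ids:

# Encoding/decoding of a panoptic id, e.g. category 26, instance 3.
category_id, instance_id = 26, 3
panoptic_id = category_id * 1000 + instance_id          # 26003
assert (panoptic_id // 1000, panoptic_id % 1000) == (category_id, instance_id)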
/osprey/mm_utils.py:
--------------------------------------------------------------------------------
1 | from PIL import Image
2 | from io import BytesIO
3 | import base64
4 |
5 | import torch
6 | from transformers import StoppingCriteria
7 | from osprey.constants import IMAGE_TOKEN_INDEX
8 |
9 |
10 | def load_image_from_base64(image):
11 | return Image.open(BytesIO(base64.b64decode(image)))
12 |
13 | def expand2square(pil_img, background_color):
14 | width, height = pil_img.size
15 | if width == height:
16 | return pil_img
17 | elif width > height:
18 | result = Image.new(pil_img.mode, (width, width), background_color)
19 | result.paste(pil_img, (0, (width - height) // 2))
20 | return result
21 | else:
22 | result = Image.new(pil_img.mode, (height, height), background_color)
23 | result.paste(pil_img, ((height - width) // 2, 0))
24 | return result
25 |
26 |
27 | def process_images(images, image_processor, model_cfg):
28 | image_aspect_ratio = getattr(model_cfg, "image_aspect_ratio", None)
29 | new_images = []
30 | if image_aspect_ratio == 'pad':
31 | for image in images:
32 | image = expand2square(image, tuple(int(x*255) for x in image_processor.image_mean))
33 | image = image_processor.preprocess(image, return_tensors='pt')['pixel_values'][0]
34 | new_images.append(image)
35 | else:
36 | return image_processor(images, return_tensors='pt')['pixel_values']
37 | if all(x.shape == new_images[0].shape for x in new_images):
38 | new_images = torch.stack(new_images, dim=0)
39 | return new_images
40 |
41 |
42 | def tokenizer_image_token(prompt, tokenizer, image_token_index=IMAGE_TOKEN_INDEX, return_tensors=None):
43 | prompt_chunks = [tokenizer(chunk).input_ids for chunk in prompt.split('<image>')]
44 |
45 | def insert_separator(X, sep):
46 | return [ele for sublist in zip(X, [sep]*len(X)) for ele in sublist][:-1]
47 |
48 | input_ids = []
49 | offset = 0
50 | if len(prompt_chunks) > 0 and len(prompt_chunks[0]) > 0 and prompt_chunks[0][0] == tokenizer.bos_token_id:
51 | offset = 1
52 | input_ids.append(prompt_chunks[0][0])
53 |
54 | for x in insert_separator(prompt_chunks, [image_token_index] * (offset + 1)):
55 | input_ids.extend(x[offset:])
56 |
57 | if return_tensors is not None:
58 | if return_tensors == 'pt':
59 | return torch.tensor(input_ids, dtype=torch.long)
60 | raise ValueError(f'Unsupported tensor type: {return_tensors}')
61 | return input_ids
62 |
63 |
64 | def get_model_name_from_path(model_path):
65 | model_path = model_path.strip("/")
66 | model_paths = model_path.split("/")
67 | if model_paths[-1].startswith('checkpoint-'):
68 | return model_paths[-2] + "_" + model_paths[-1]
69 | else:
70 | return model_paths[-1]
71 |
72 | class KeywordsStoppingCriteria(StoppingCriteria):
73 | def __init__(self, keywords, tokenizer, input_ids):
74 | self.keywords = keywords
75 | self.keyword_ids = []
76 | for keyword in keywords:
77 | cur_keyword_ids = tokenizer(keyword).input_ids
78 | if len(cur_keyword_ids) > 1 and cur_keyword_ids[0] == tokenizer.bos_token_id:
79 | cur_keyword_ids = cur_keyword_ids[1:]
80 | self.keyword_ids.append(torch.tensor(cur_keyword_ids))
81 | self.tokenizer = tokenizer
82 | self.start_len = input_ids.shape[1]
83 |
84 | def __call__(self, output_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
85 | assert output_ids.shape[0] == 1, "Only support batch size 1 (yet)" # TODO
86 | offset = min(output_ids.shape[1] - self.start_len, 3)
87 | self.keyword_ids = [keyword_id.to(output_ids.device) for keyword_id in self.keyword_ids]
88 | for keyword_id in self.keyword_ids:
89 | if (output_ids[0, -keyword_id.shape[0]:] == keyword_id).all():
90 | return True
91 | outputs = self.tokenizer.batch_decode(output_ids[:, -offset:], skip_special_tokens=True)[0]
92 | for keyword in self.keywords:
93 | if keyword in outputs:
94 | return True
95 | return False
96 |
--------------------------------------------------------------------------------
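Note on tokenizer_image_token in mm_utils.py above: the prompt is tokenized around the '<image>' placeholder and a single IMAGE_TOKEN_INDEX sentinel (defined in osprey/constants.py) is spliced into the id sequence, which the model later replaces with visual features. A minimal usage sketch mirroring the eval scripts; the checkpoint path is a placeholder:

import torch
from transformers import AutoTokenizer
from osprey.constants import IMAGE_TOKEN_INDEX
from osprey.mm_utils import tokenizer_image_token

tokenizer = AutoTokenizer.from_pretrained('/path/to/osprey-7b', use_fast=True)

prompt = '<image>\nThis provides an overview of the picture.\nDescribe the image.'
input_ids = tokenizer_image_token(prompt, tokenizer, IMAGE_TOKEN_INDEX,
                                  return_tensors='pt').unsqueeze(0)

# Exactly one IMAGE_TOKEN_INDEX entry stands in for the placeholder.
print((input_ids == IMAGE_TOKEN_INDEX).sum())   # tensor(1)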
/osprey/model/__init__.py:
--------------------------------------------------------------------------------
1 | from .language_model.osprey_llama import OspreyLlamaForCausalLM, OspreyConfig
2 |
--------------------------------------------------------------------------------
/osprey/model/consolidate.py:
--------------------------------------------------------------------------------
1 |
2 | import argparse
3 |
4 | import torch
5 | from transformers import AutoTokenizer, AutoModelForCausalLM
6 | from osprey.model import *
7 | from osprey.model.utils import auto_upgrade
8 |
9 |
10 | def consolidate_ckpt(src_path, dst_path):
11 | print("Loading model")
12 | auto_upgrade(src_path)
13 | src_model = AutoModelForCausalLM.from_pretrained(src_path, torch_dtype=torch.float16, low_cpu_mem_usage=True)
14 | src_tokenizer = AutoTokenizer.from_pretrained(src_path, use_fast=False)
15 | src_model.save_pretrained(dst_path)
16 | src_tokenizer.save_pretrained(dst_path)
17 |
18 |
19 | if __name__ == "__main__":
20 | parser = argparse.ArgumentParser()
21 | parser.add_argument("--src", type=str, required=True)
22 | parser.add_argument("--dst", type=str, required=True)
23 |
24 | args = parser.parse_args()
25 |
26 | consolidate_ckpt(args.src, args.dst)
27 |
--------------------------------------------------------------------------------
/osprey/model/language_model/osprey_llama.py:
--------------------------------------------------------------------------------
1 | # Licensed under the Apache License, Version 2.0 (the "License");
2 | # you may not use this file except in compliance with the License.
3 | # You may obtain a copy of the License at
4 | #
5 | # http://www.apache.org/licenses/LICENSE-2.0
6 | #
7 | # Unless required by applicable law or agreed to in writing, software
8 | # distributed under the License is distributed on an "AS IS" BASIS,
9 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
10 | # See the License for the specific language governing permissions and
11 | # limitations under the License.
12 |
13 |
14 | from typing import List, Optional, Tuple, Union
15 |
16 | import torch
17 | import torch.nn as nn
18 | from torch.nn import CrossEntropyLoss
19 |
20 | from transformers import AutoConfig, AutoModelForCausalLM, \
21 | LlamaConfig, LlamaModel, LlamaForCausalLM
22 |
23 | from transformers.modeling_outputs import CausalLMOutputWithPast
24 |
25 | from ..osprey_arch import OspreyMetaModel, OspreyMetaForCausalLM
26 |
27 | from ..layer import MaskExtractor
28 |
29 | class OspreyConfig(LlamaConfig):
30 | model_type = "osprey"
31 |
32 |
33 | class OspreyLlamaModel(OspreyMetaModel, LlamaModel):
34 | config_class = OspreyConfig
35 |
36 | def __init__(self, config: LlamaConfig):
37 | super(OspreyLlamaModel, self).__init__(config)
38 |
39 |
40 | class OspreyLlamaForCausalLM(LlamaForCausalLM, OspreyMetaForCausalLM):
41 | config_class = OspreyConfig
42 |
43 | def __init__(self, config):
44 | super(LlamaForCausalLM, self).__init__(config)
45 | self.model = OspreyLlamaModel(config)
46 |
47 | self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
48 |
49 | self.mask_extractor = MaskExtractor()
50 |
51 | # Initialize weights and apply final processing
52 | self.post_init()
53 |
54 | def get_model(self):
55 | return self.model
56 |
57 | def forward(
58 | self,
59 | input_ids: torch.LongTensor = None,
60 | attention_mask: Optional[torch.Tensor] = None,
61 | img_metas = None,
62 | masks = None,
63 | past_key_values: Optional[List[torch.FloatTensor]] = None,
64 | inputs_embeds: Optional[torch.FloatTensor] = None,
65 | labels: Optional[torch.LongTensor] = None,
66 | use_cache: Optional[bool] = None,
67 | output_attentions: Optional[bool] = None,
68 | output_hidden_states: Optional[bool] = None,
69 | images: Optional[torch.FloatTensor] = None,
70 | return_dict: Optional[bool] = None,
71 | ) -> Union[Tuple, CausalLMOutputWithPast]:
72 | output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
73 | output_hidden_states = (
74 | output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
75 | )
76 | return_dict = return_dict if return_dict is not None else self.config.use_return_dict
77 |
78 | input_token_len = input_ids.shape[1]
79 | input_ids, attention_mask, past_key_values, inputs_embeds, labels = self.prepare_inputs_labels_for_multimodal(input_ids, masks, attention_mask, past_key_values, labels, images)
80 |
81 | # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
82 |
83 | if inputs_embeds is not None:
84 | inputs_embeds = inputs_embeds.bfloat16()
85 |
86 | self.model = self.model.bfloat16()
87 |
88 | outputs = self.model(
89 | input_ids=input_ids,
90 | attention_mask=attention_mask,
91 | past_key_values=past_key_values,
92 | inputs_embeds=inputs_embeds,
93 | use_cache=use_cache,
94 | output_attentions=output_attentions,
95 | output_hidden_states=output_hidden_states,
96 | return_dict=return_dict
97 | )
98 |
99 | hidden_states = outputs[0]
100 | self.lm_head = self.lm_head.to(hidden_states.dtype)
101 | logits = self.lm_head(hidden_states)
102 |
103 | loss = None
104 | if labels is not None:
105 | # Shift so that tokens < n predict n
106 | shift_logits = logits[..., :-1, :].contiguous()
107 | shift_labels = labels[..., 1:].contiguous()
108 | # Flatten the tokens
109 | loss_fct = CrossEntropyLoss()
110 | shift_logits = shift_logits.view(-1, self.config.vocab_size)
111 | shift_labels = shift_labels.view(-1)
112 | # Enable model/pipeline parallelism
113 | shift_labels = shift_labels.to(shift_logits.device)
114 | loss = loss_fct(shift_logits, shift_labels)
115 |
116 | if not return_dict:
117 | output = (logits,) + outputs[1:]
118 |
119 | return (loss,) + output if loss is not None else output
120 |
121 | return CausalLMOutputWithPast(
122 | loss=loss,
123 | logits=logits,
124 | past_key_values=outputs.past_key_values,
125 | hidden_states=outputs.hidden_states,
126 | attentions=outputs.attentions,
127 | )
128 |
129 | def prepare_inputs_for_generation(
130 | self, input_ids, past_key_values=None, attention_mask=None, inputs_embeds=None, **kwargs
131 | ):
132 | if past_key_values:
133 | input_ids = input_ids[:, -1:]
134 |
135 | # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
136 | if inputs_embeds is not None and past_key_values is None:
137 | model_inputs = {"inputs_embeds": inputs_embeds}
138 | else:
139 | model_inputs = {"input_ids": input_ids}
140 |
141 | model_inputs.update(
142 | {
143 | "past_key_values": past_key_values,
144 | "use_cache": kwargs.get("use_cache"),
145 | "attention_mask": attention_mask,
146 | "images": kwargs.get("images", None),
147 | }
148 | )
149 | return model_inputs
150 |
151 | AutoConfig.register("osprey", OspreyConfig)
152 | AutoModelForCausalLM.register(OspreyConfig, OspreyLlamaForCausalLM)
153 |
--------------------------------------------------------------------------------
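The AutoConfig/AutoModelForCausalLM registration at the end of osprey_llama.py means that a checkpoint whose config carries model_type "osprey" resolves to OspreyLlamaForCausalLM once osprey.model has been imported. A minimal loading sketch; the checkpoint path is a placeholder, and bfloat16 is chosen to match the dtype the forward pass casts to:

    import torch
    from transformers import AutoTokenizer
    from osprey.model import OspreyLlamaForCausalLM  # importing this also runs the Auto* registration

    ckpt = "path/to/osprey-7b"  # placeholder checkpoint directory
    tokenizer = AutoTokenizer.from_pretrained(ckpt, use_fast=False)
    model = OspreyLlamaForCausalLM.from_pretrained(ckpt, torch_dtype=torch.bfloat16).eval()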
/osprey/model/layer.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 |
5 |
6 | class MLP(nn.Module):
7 |
8 | def __init__(self, input_dim: int, hidden_dim: int, output_dim: int,
9 | num_layers: int) -> None:
10 | super().__init__()
11 | self.num_layers = num_layers
12 | h = [hidden_dim] * (num_layers - 1)
13 | self.layers = nn.ModuleList(
14 | nn.Linear(n, k) for n, k in zip([input_dim] + h, h + [output_dim]))
15 |
16 | def forward(self, x):
17 | for i, layer in enumerate(self.layers):
18 | x = F.relu(layer(x)) if i < self.num_layers - 1 else layer(x)
19 | return x
20 |
21 |
22 | class MaskExtractor(nn.Module):
23 | def __init__(self, mask_shape=112, embed_dim=1024, out_dim=4096):
24 | super(MaskExtractor, self).__init__()
25 | self.mask_shape = mask_shape
26 | self.mask_pooling = MaskPooling()
27 | self.feat_linear = nn.Linear(embed_dim, out_dim)
28 | self.mask_linear = MLP(mask_shape*mask_shape, embed_dim, out_dim, 3)
29 | # self.res_linear = {}
30 | self.feature_name = ['res2', 'res3', 'res4', 'res5']
31 |
32 | # for i, feat in enumerate(self.feature_name):
33 | # self.res_linear[feat] = nn.Linear(192*(2**i), embed_dim)
34 |
35 | self.res2 = nn.Linear(192, 1024)
36 | self.res3 = nn.Linear(384, 1024)
37 | self.res4 = nn.Linear(768, 1024)
38 | self.res5 = nn.Linear(1536, 1024)
39 |
40 | def forward(self, feats, masks):
41 | query_feats = []
42 | pos_feats = []
43 | if masks is None:
44 | return query_feats, pos_feats
45 |
46 | num_imgs = len(masks)
47 |
48 | for idx in range(num_imgs):
49 | mask = masks[idx].unsqueeze(0).float()
50 |
51 | num_feats = len(self.feature_name)
52 | mask_feats = mask.new_zeros(num_feats, mask.shape[1], 1024)
53 | for i, name in enumerate(self.feature_name):
54 | feat = feats[name][idx].unsqueeze(0)
55 |
56 | raw_dtype = feat.dtype
57 | feat = feat.to(mask.dtype)
58 | mask_feat_raw = self.mask_pooling(feat, mask)
59 |
60 | mask_feat_flatten = mask_feat_raw.reshape(-1, mask_feat_raw.shape[-1])
61 |
62 | # self.res_linear[name] = self.res_linear[name].to(dtype=mask_feat_flatten.dtype, device=mask_feat_flatten.device)
63 | # mask_feat = self.res_linear[name](mask_feat_flatten)
64 |
65 | if name=='res2':
66 | self.res2 = self.res2.to(device=mask_feat_flatten.device, dtype=mask_feat_flatten.dtype)
67 | mask_feat = self.res2(mask_feat_flatten)
68 | elif name=='res3':
69 | self.res3 = self.res3.to(device=mask_feat_flatten.device, dtype=mask_feat_flatten.dtype)
70 | mask_feat = self.res3(mask_feat_flatten)
71 | elif name=='res4':
72 | self.res4 = self.res4.to(device=mask_feat_flatten.device, dtype=mask_feat_flatten.dtype)
73 | mask_feat = self.res4(mask_feat_flatten)
74 | else:
75 | self.res5 = self.res5.to(device=mask_feat_flatten.device, dtype=mask_feat_flatten.dtype)
76 | mask_feat = self.res5(mask_feat_flatten)
77 |
78 | mask_feat = mask_feat.reshape(*mask_feat_raw.shape[:2], -1)
79 | mask_feat = mask_feat.to(raw_dtype)
80 |
81 | mask_feats[i] = mask_feat[0]
82 | mask_feats = mask_feats.sum(0)
83 | self.feat_linear = self.feat_linear.to(dtype=mask_feats.dtype, device=mask_feats.device)
84 | mask_feats_linear = self.feat_linear(mask_feats)
85 | query_feats.append(mask_feats_linear)
86 |
87 | # position
88 | mask = F.interpolate(mask, size=self.mask_shape, mode='bilinear', align_corners=False)
89 | self.mask_linear = self.mask_linear.to(dtype=mask.dtype, device=mask.device)
90 | pos_feat = self.mask_linear(mask.reshape(mask.shape[1], -1))
91 | pos_feats.append(pos_feat)
92 |
93 | return query_feats, pos_feats
94 |
95 |
96 | class MaskPooling(nn.Module):
97 | def __init__(self):
98 | super().__init__()
99 |
100 | def forward(self, x, mask):
101 |
102 | if not x.shape[-2:] == mask.shape[-2:]:
103 | # reshape mask to x
104 | mask = F.interpolate(mask, size=x.shape[-2:], mode='bilinear', align_corners=False)
105 |
106 | b, c, h ,w = x.shape
107 | b, q, h, w = mask.shape
108 | mask = (mask > 0).to(mask.dtype)
109 | denorm = mask.sum(dim=(-1, -2), keepdim=True) + 1e-8
110 |
111 | mask_pooled_x = torch.einsum(
112 | "bchw,bqhw->bqc",
113 | x,
114 | mask / denorm,
115 | )
116 | return mask_pooled_x
117 |
--------------------------------------------------------------------------------
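In layer.py, MaskPooling averages backbone features inside each binary region mask: features of shape (B, C, H, W) and masks of shape (B, Q, H', W') yield one C-dimensional vector per mask, which MaskExtractor then projects per pyramid level (res2-res5) and sums. A shape-checking sketch for the pooling step alone, using random tensors:

    import torch
    from osprey.model.layer import MaskPooling

    pool = MaskPooling()
    feats = torch.randn(1, 768, 32, 32)                  # (B, C, H, W) feature map, e.g. the res4 level
    masks = (torch.rand(1, 3, 512, 512) > 0.5).float()   # (B, Q, H', W') binary region masks

    pooled = pool(feats, masks)   # masks are bilinearly resized to 32x32 inside forward()
    print(pooled.shape)           # torch.Size([1, 3, 768]): one pooled vector per mask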
/osprey/model/multimodal_encoder/builder.py:
--------------------------------------------------------------------------------
1 | from .clip_encoder import CLIPVisionTower
2 | 
3 | 
4 | def build_vision_tower(vision_tower_cfg, delay_load=False):
5 |     return CLIPVisionTower(args=vision_tower_cfg)
6 | 
--------------------------------------------------------------------------------
/osprey/model/multimodal_encoder/clip.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn.functional as F
3 | import torch.nn as nn
4 |
5 | from open_clip.model import _build_vision_tower
6 |
7 |
8 | class CLIP(nn.Module):
9 | def __init__(self):
10 | super().__init__()
11 | model_name = 'convnext_large'
12 |
13 | vision_cfg = {'timm_model_name': model_name, 'timm_model_pretrained': False, 'timm_pool': '', 'timm_proj': 'mlp', 'timm_drop': 0.0, 'timm_drop_path': 0.1, 'image_size': 320}
14 | self.visual = _build_vision_tower(embed_dim=768, vision_cfg=vision_cfg, quick_gelu=False)
15 |
16 | self.eval()
17 | self.freeze_everything()
18 |
19 | def freeze_everything(self):
20 | for param in self.visual.parameters():
21 | param.requires_grad = False
22 |
23 | def extract_features(self, x):
24 | out = {}
25 | x = x.to(self.visual.trunk.stem.state_dict()['1.bias'].dtype)
26 | x = self.visual.trunk.stem(x)
27 | out['stem'] = x.contiguous()
28 | for i in range(4):
29 | x = self.visual.trunk.stages[i](x)
30 | out[f'res{i+2}'] = x.contiguous()
31 |
32 | x = self.visual.trunk.norm_pre(x)
33 | out['clip_vis_dense'] = x.contiguous()
34 | return out
35 |
36 | def forward(self, x):
37 | self.eval()
38 | with torch.no_grad():
39 | return self.extract_features(x)
40 |
--------------------------------------------------------------------------------
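clip.py wraps an open_clip ConvNeXt-Large vision tower and exposes its feature pyramid: res2-res5 carry 192/384/768/1536 channels at strides 4/8/16/32 (matching the per-level linear layers in MaskExtractor), and clip_vis_dense is the norm_pre output of the last stage. A quick inspection sketch; the tower is randomly initialized here (timm_model_pretrained=False), whereas real use loads the CLIP weights through CLIPVisionTower.load_model():

    import torch
    from osprey.model.multimodal_encoder.clip import CLIP

    model = CLIP()                              # frozen, eval-mode vision tower with random weights
    feats = model(torch.randn(1, 3, 512, 512))  # forward() runs extract_features() under no_grad
    for name, feat in feats.items():
        print(name, tuple(feat.shape))
    # expected for a 512x512 input: res2 (1, 192, 128, 128), res3 (1, 384, 64, 64),
    # res4 (1, 768, 32, 32), res5 (1, 1536, 16, 16)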
/osprey/model/multimodal_encoder/clip_encoder.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 |
4 | from transformers import CLIPImageProcessor
5 | from .clip import CLIP
6 |
7 | class CLIPVisionTower(nn.Module):
8 | def __init__(self, args, img_size=512, delay_load=False):
9 | super().__init__()
10 |
11 | # test
12 | if hasattr(args, 'mm_vision_tower'):
13 | self.clip_model = args.mm_vision_tower
14 | else: # train
15 | self.clip_model = args.vision_tower
16 | self.is_loaded = False
17 | self.img_size = img_size
18 |
19 | if not delay_load:
20 | self.load_model()
21 |
22 | def load_model(self):
23 | self.image_processor = CLIPImageProcessor(do_resize=True, size={"shortest_edge":self.img_size}, resample=3, do_center_crop=True, crop_size={"height": self.img_size, "width": self.img_size},
24 | do_rescale=True, rescale_factor=0.00392156862745098, do_normalize=True, image_mean=[0.48145466, 0.4578275, 0.40821073],
25 | image_std=[0.26862954, 0.26130258, 0.27577711], do_convert_rgb=True, )
26 |
27 | self.vision_tower = CLIP()
28 |
29 | self.vision_tower.load_state_dict(torch.load(self.clip_model),strict=False)
30 |
31 | self.is_loaded = True
32 |
33 | @torch.no_grad()
34 | def forward(self, images):
35 |
36 | if type(images) is list:
37 | image_features = []
38 | image_features_dict = []
39 | for image in images:
40 | image_feature_dict = self.vision_tower(image.unsqueeze(0))
41 | image_features_dict.append(image_feature_dict)
42 | image_feature = image_feature_dict['res4']
43 | image_feature = image_feature.reshape(*image_feature.shape[:2],-1).permute(0,2,1)
44 | image_features.append(image_feature)
45 | else:
46 | image_features_dict = self.vision_tower(images)
47 | image_features = image_features_dict['res4']
48 | image_features = image_features.reshape(*image_features.shape[:2],-1).permute(0,2,1)
49 |
50 | return image_features, image_features_dict
51 |
52 |     @property
53 |     def dtype(self):
54 |         return next(self.vision_tower.parameters()).dtype
55 | 
56 |     @property
57 |     def device(self):
58 |         return next(self.vision_tower.parameters()).device
59 |
--------------------------------------------------------------------------------
/osprey/model/multimodal_projector/builder.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import re
4 |
5 |
6 | class IdentityMap(nn.Module):
7 | def __init__(self):
8 | super().__init__()
9 |
10 | def forward(self, x, *args, **kwargs):
11 | return x
12 |
13 | @property
14 | def config(self):
15 | return {"mm_projector_type": 'identity'}
16 |
17 |
18 | class SimpleResBlock(nn.Module):
19 | def __init__(self, channels):
20 | super().__init__()
21 | self.pre_norm = nn.LayerNorm(channels)
22 |
23 | self.proj = nn.Sequential(
24 | nn.Linear(channels, channels),
25 | nn.GELU(),
26 | nn.Linear(channels, channels)
27 | )
28 | def forward(self, x):
29 | x = self.pre_norm(x)
30 | return x + self.proj(x)
31 |
32 |
33 | def build_vision_projector(config, delay_load=False, **kwargs):
34 | mm_hidden_size = getattr(config, 'mm_hidden_size', 768)
35 | projector_type = getattr(config, 'mm_projector_type', 'linear')
36 |
37 | if projector_type == 'linear':
38 | return nn.Linear(mm_hidden_size, config.hidden_size)
39 |
40 | mlp_gelu_match = re.match(r'^mlp(\d+)x_gelu$', projector_type)
41 | if mlp_gelu_match:
42 | mlp_depth = int(mlp_gelu_match.group(1))
43 | modules = [nn.Linear(mm_hidden_size, config.hidden_size)]
44 | for _ in range(1, mlp_depth):
45 | modules.append(nn.GELU())
46 | modules.append(nn.Linear(config.hidden_size, config.hidden_size))
47 | return nn.Sequential(*modules)
48 |
49 | if projector_type == 'identity':
50 | return IdentityMap()
51 |
52 | raise ValueError(f'Unknown projector type: {projector_type}')
53 |
--------------------------------------------------------------------------------
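In multimodal_projector/builder.py, build_vision_projector selects the projector from mm_projector_type: 'linear' yields a single nn.Linear, 'mlpNx_gelu' an N-layer MLP with GELU activations, and 'identity' a pass-through; the training scripts under scripts/ pass mlp2x_gelu. A minimal sketch with a stand-in config object (the attribute values are illustrative; in training they come from the model's HF config):

    from types import SimpleNamespace
    from osprey.model.multimodal_projector.builder import build_vision_projector

    cfg = SimpleNamespace(mm_hidden_size=768, hidden_size=4096, mm_projector_type="mlp2x_gelu")
    projector = build_vision_projector(cfg)
    print(projector)  # Sequential(Linear(768 -> 4096), GELU(), Linear(4096 -> 4096))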
/osprey/train/llama_flash_attn_monkey_patch.py:
--------------------------------------------------------------------------------
1 | from typing import Optional, Tuple
2 | import warnings
3 |
4 | import torch
5 |
6 | import transformers
7 | from transformers.models.llama.modeling_llama import apply_rotary_pos_emb, repeat_kv
8 |
9 | try:
10 | from flash_attn.flash_attn_interface import flash_attn_unpadded_qkvpacked_func
11 | except ImportError:
12 | from flash_attn.flash_attn_interface import flash_attn_varlen_qkvpacked_func as flash_attn_unpadded_qkvpacked_func
13 | from flash_attn.bert_padding import unpad_input, pad_input
14 |
15 |
16 | def forward(
17 | self,
18 | hidden_states: torch.Tensor,
19 | attention_mask: Optional[torch.Tensor] = None,
20 | position_ids: Optional[torch.Tensor] = None,
21 | past_key_value: Optional[Tuple[torch.Tensor]] = None,
22 | output_attentions: bool = False,
23 | use_cache: bool = False,
24 | ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
25 | if output_attentions:
26 | warnings.warn(
27 | "Output attentions is not supported for patched `LlamaAttention`, returning `None` instead."
28 | )
29 |
30 | bsz, q_len, _ = hidden_states.size()
31 | # print("begin_#")
32 | query_states = (
33 | self.q_proj(hidden_states)
34 | .view(bsz, q_len, self.num_heads, self.head_dim)
35 | .transpose(1, 2)
36 | )
37 | # print("OK_#")
38 | key_states = (
39 | self.k_proj(hidden_states)
40 | .view(bsz, q_len, self.num_key_value_heads, self.head_dim)
41 | .transpose(1, 2)
42 | )
43 | value_states = (
44 | self.v_proj(hidden_states)
45 | .view(bsz, q_len, self.num_key_value_heads, self.head_dim)
46 | .transpose(1, 2)
47 | ) # shape: (b, num_heads, s, head_dim)
48 |
49 | kv_seq_len = key_states.shape[-2]
50 | if past_key_value is not None:
51 | kv_seq_len += past_key_value[0].shape[-2]
52 |
53 | cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
54 | query_states, key_states = apply_rotary_pos_emb(
55 | query_states, key_states, cos, sin, position_ids
56 | )
57 |
58 | if past_key_value is not None:
59 | # reuse k, v
60 | key_states = torch.cat([past_key_value[0], key_states], dim=2)
61 | value_states = torch.cat([past_key_value[1], value_states], dim=2)
62 |
63 | past_key_value = (key_states, value_states) if use_cache else None
64 |
65 | # repeat k/v heads if n_kv_heads < n_heads
66 | key_states = repeat_kv(key_states, self.num_key_value_groups)
67 | value_states = repeat_kv(value_states, self.num_key_value_groups)
68 |
69 | # Transform the data into the format required by flash attention
70 | qkv = torch.stack([query_states, key_states, value_states], dim=2)
71 | qkv = qkv.transpose(1, 3) # shape: [b, s, 3, num_heads, head_dim]
72 | key_padding_mask = attention_mask
73 |
74 | if key_padding_mask is None:
75 | qkv = qkv.reshape(-1, 3, self.num_heads, self.head_dim)
76 | cu_q_lens = torch.arange(
77 | 0, (bsz + 1) * q_len, step=q_len, dtype=torch.int32, device=qkv.device
78 | )
79 | max_s = q_len
80 | output = flash_attn_unpadded_qkvpacked_func(
81 | qkv, cu_q_lens, max_s, 0.0, softmax_scale=None, causal=True
82 | )
83 | output = output.view(bsz, q_len, -1)
84 | else:
85 | qkv = qkv.reshape(bsz, q_len, -1)
86 | qkv, indices, cu_q_lens, max_s = unpad_input(qkv, key_padding_mask)
87 | qkv = qkv.view(-1, 3, self.num_heads, self.head_dim)
88 | output_unpad = flash_attn_unpadded_qkvpacked_func(
89 | qkv, cu_q_lens, max_s, 0.0, softmax_scale=None, causal=True
90 | )
91 | output_unpad = output_unpad.reshape(-1, self.num_heads * self.head_dim)
92 | output = pad_input(output_unpad, indices, bsz, q_len)
93 |
94 | return self.o_proj(output), None, past_key_value
95 |
96 |
97 | # Disable the transformation of the attention mask in LlamaModel as the flash attention
98 | # requires the attention mask to be the same as the key_padding_mask
99 | def _prepare_decoder_attention_mask(
100 | self, attention_mask, input_shape, inputs_embeds, past_key_values_length
101 | ):
102 | # [bsz, seq_len]
103 | return attention_mask
104 |
105 |
106 | def replace_llama_attn_with_flash_attn():
107 | cuda_major, cuda_minor = torch.cuda.get_device_capability()
108 | if cuda_major < 8:
109 | warnings.warn(
110 |             "Flash attention is only supported on A100 or H100 GPU during training due to head dim > 64 backward. "
111 | "ref: https://github.com/HazyResearch/flash-attention/issues/190#issuecomment-1523359593"
112 | )
113 | transformers.models.llama.modeling_llama.LlamaModel._prepare_decoder_attention_mask = (
114 | _prepare_decoder_attention_mask
115 | )
116 | transformers.models.llama.modeling_llama.LlamaAttention.forward = forward
117 |
--------------------------------------------------------------------------------
/osprey/train/osprey_trainer.py:
--------------------------------------------------------------------------------
1 | import os
2 | import torch
3 | import torch.nn as nn
4 |
5 | from torch.utils.data import Sampler
6 |
7 | from transformers import Trainer
8 | from transformers.trainer import (
9 | has_length,
10 | )
11 | from typing import List, Optional
12 |
13 |
14 | def maybe_zero_3(param, ignore_status=False, name=None):
15 | from deepspeed import zero
16 | from deepspeed.runtime.zero.partition_parameters import ZeroParamStatus
17 | if hasattr(param, "ds_id"):
18 | if param.ds_status == ZeroParamStatus.NOT_AVAILABLE:
19 | if not ignore_status:
20 | print(name, 'no ignore status')
21 | with zero.GatheredParameters([param]):
22 | param = param.data.detach().cpu().clone()
23 | else:
24 | param = param.detach().cpu().clone()
25 | return param
26 |
27 |
28 | def get_mm_adapter_state_maybe_zero_3(named_params, keys_to_match):
29 | to_return = {k: t for k, t in named_params if any(key_match in k for key_match in keys_to_match)}
30 | to_return = {k: maybe_zero_3(v, ignore_status=True, name=k).cpu() for k, v in to_return.items()}
31 | return to_return
32 |
33 |
34 | def split_to_even_chunks(indices, lengths, num_chunks):
35 | """
36 |     Split a list of indices into `num_chunks` chunks of roughly equal total length.
37 | """
38 |
39 | if len(indices) % num_chunks != 0:
40 | return [indices[i::num_chunks] for i in range(num_chunks)]
41 |
42 | num_indices_per_chunk = len(indices) // num_chunks
43 |
44 | chunks = [[] for _ in range(num_chunks)]
45 | chunks_lengths = [0 for _ in range(num_chunks)]
46 | for index in indices:
47 | shortest_chunk = chunks_lengths.index(min(chunks_lengths))
48 | chunks[shortest_chunk].append(index)
49 | chunks_lengths[shortest_chunk] += lengths[index]
50 | if len(chunks[shortest_chunk]) == num_indices_per_chunk:
51 | chunks_lengths[shortest_chunk] = float("inf")
52 |
53 | return chunks
54 |
55 |
56 | def get_modality_length_grouped_indices(lengths, batch_size, world_size, generator=None):
57 | # We need to use torch for the random part as a distributed sampler will set the random seed for torch.
58 | assert all(l != 0 for l in lengths), "Should not have zero length."
59 | mm_indices, mm_lengths = zip(*[(i, l) for i, l in enumerate(lengths) if l > 0])
60 | lang_indices, lang_lengths = zip(*[(i, -l) for i, l in enumerate(lengths) if l < 0])
61 |
62 | assert len(mm_indices) > 0, "Should have at least one multimodal sample."
63 | assert len(lang_indices) > 0, "Should have at least one language sample."
64 |
65 | mm_shuffle = [mm_indices[i] for i in get_length_grouped_indices(mm_lengths, batch_size, world_size, generator=None)]
66 | lang_shuffle = [lang_indices[i] for i in get_length_grouped_indices(lang_lengths, batch_size, world_size, generator=None)]
67 | megabatch_size = world_size * batch_size
68 | mm_megabatches = [mm_shuffle[i : i + megabatch_size] for i in range(0, len(mm_shuffle), megabatch_size)]
69 | lang_megabatches = [lang_shuffle[i : i + megabatch_size] for i in range(0, len(lang_shuffle), megabatch_size)]
70 |
71 | last_mm = mm_megabatches[-1]
72 | last_lang = lang_megabatches[-1]
73 | additional_batch = last_mm + last_lang
74 | megabatches = mm_megabatches[:-1] + lang_megabatches[:-1]
75 | megabatch_indices = torch.randperm(len(megabatches), generator=generator)
76 | megabatches = [megabatches[i] for i in megabatch_indices]
77 |
78 | if len(additional_batch) >= megabatch_size:
79 | megabatches = [additional_batch[:megabatch_size]] + megabatches
80 | additional_batch = additional_batch[megabatch_size:]
81 |
82 | if len(additional_batch) > 0:
83 | megabatches.append(additional_batch)
84 |
85 | return [i for megabatch in megabatches for i in megabatch]
86 |
87 |
88 | def get_length_grouped_indices(lengths, batch_size, world_size, generator=None, merge=True):
89 | # We need to use torch for the random part as a distributed sampler will set the random seed for torch.
90 | indices = torch.randperm(len(lengths), generator=generator)
91 | megabatch_size = world_size * batch_size
92 | megabatches = [indices[i : i + megabatch_size].tolist() for i in range(0, len(lengths), megabatch_size)]
93 | megabatches = [sorted(megabatch, key=lambda i: lengths[i], reverse=True) for megabatch in megabatches]
94 | megabatches = [split_to_even_chunks(megabatch, lengths, world_size) for megabatch in megabatches]
95 |
96 | return [i for megabatch in megabatches for batch in megabatch for i in batch]
97 |
98 |
99 | class LengthGroupedSampler(Sampler):
100 | r"""
101 | Sampler that samples indices in a way that groups together features of the dataset of roughly the same length while
102 | keeping a bit of randomness.
103 | """
104 |
105 | def __init__(
106 | self,
107 | batch_size: int,
108 | world_size: int,
109 | lengths: Optional[List[int]] = None,
110 | generator=None,
111 | group_by_modality: bool = False,
112 | ):
113 | if lengths is None:
114 | raise ValueError("Lengths must be provided.")
115 |
116 | self.batch_size = batch_size
117 | self.world_size = world_size
118 | self.lengths = lengths
119 | self.generator = generator
120 | self.group_by_modality = group_by_modality
121 |
122 | def __len__(self):
123 | return len(self.lengths)
124 |
125 | def __iter__(self):
126 | if self.group_by_modality:
127 | indices = get_modality_length_grouped_indices(self.lengths, self.batch_size, self.world_size, generator=self.generator)
128 | else:
129 | indices = get_length_grouped_indices(self.lengths, self.batch_size, self.world_size, generator=self.generator)
130 | return iter(indices)
131 |
132 |
133 | class OspreyTrainer(Trainer):
134 |
135 | def _get_train_sampler(self) -> Optional[torch.utils.data.Sampler]:
136 | if self.train_dataset is None or not has_length(self.train_dataset):
137 | return None
138 |
139 | if self.args.group_by_modality_length:
140 | lengths = self.train_dataset.modality_lengths
141 | return LengthGroupedSampler(
142 | # self.args.train_batch_size * self.args.gradient_accumulation_steps, # TODO: seems that we should not have gradient_accumulation_steps
143 | self.args.train_batch_size,
144 | world_size=self.args.world_size,
145 | lengths=lengths,
146 | group_by_modality=True,
147 | )
148 | else:
149 | return super()._get_train_sampler()
150 |
151 | def _save_checkpoint(self, model, trial, metrics=None):
152 | if getattr(self.args, 'tune_mm_mlp_adapter', False):
153 | from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR
154 | checkpoint_folder = f"{PREFIX_CHECKPOINT_DIR}-{self.state.global_step}"
155 |
156 | run_dir = self._get_output_dir(trial=trial)
157 | output_dir = os.path.join(run_dir, checkpoint_folder)
158 |
159 | # Only save Adapter
160 | keys_to_match = ['mm_projector', 'vision_resampler']
161 | if getattr(self.args, "use_im_start_end", False):
162 | keys_to_match.extend(['embed_tokens', 'embed_in'])
163 |
164 | weight_to_save = get_mm_adapter_state_maybe_zero_3(self.model.named_parameters(), keys_to_match)
165 |
166 | if self.args.local_rank == 0 or self.args.local_rank == -1:
167 | self.model.config.save_pretrained(output_dir)
168 | torch.save(weight_to_save, os.path.join(output_dir, f'mm_projector.bin'))
169 | else:
170 | super(OspreyTrainer, self)._save_checkpoint(model, trial, metrics)
171 |
172 | def _save(self, output_dir: Optional[str] = None, state_dict=None):
173 | if getattr(self.args, 'tune_mm_mlp_adapter', False):
174 | pass
175 | else:
176 | super(OspreyTrainer, self)._save(output_dir, state_dict)
177 |
--------------------------------------------------------------------------------
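In osprey_trainer.py, the grouped samplers expect signed lengths: positive values mark multimodal samples and negative values mark language-only samples; get_modality_length_grouped_indices buckets each modality by length into megabatches of world_size * batch_size and shuffles the megabatches. A toy sketch with made-up lengths:

    import torch
    from osprey.train.osprey_trainer import get_modality_length_grouped_indices

    lengths = [120, 80, -40, 200, -60, 90, -30, 150]   # positive = multimodal, negative = language-only
    order = get_modality_length_grouped_indices(
        lengths, batch_size=2, world_size=2, generator=torch.Generator().manual_seed(0)
    )
    print(order)  # a permutation of 0..7, length-sorted within megabatches and mostly modality-pure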
/osprey/train/train_mem.py:
--------------------------------------------------------------------------------
1 | # Adapted from https://github.com/lm-sys/FastChat. Below is the original copyright:
2 | # Adapted from tatsu-lab@stanford_alpaca. Below is the original copyright:
3 | # Make it more memory efficient by monkey patching the LLaMA model with FlashAttn.
4 |
5 | # Need to call this before importing transformers.
6 | from osprey.train.llama_flash_attn_monkey_patch import replace_llama_attn_with_flash_attn
7 |
8 | replace_llama_attn_with_flash_attn()
9 |
10 | from osprey.train.train import train
11 |
12 | if __name__ == "__main__":
13 | train()
14 |
--------------------------------------------------------------------------------
/osprey/utils.py:
--------------------------------------------------------------------------------
1 | import datetime
2 | import logging
3 | import logging.handlers
4 | import os
5 | import sys
6 |
7 | import requests
8 |
9 | from osprey.constants import LOGDIR
10 |
11 | server_error_msg = "**NETWORK ERROR DUE TO HIGH TRAFFIC. PLEASE REGENERATE OR REFRESH THIS PAGE.**"
12 | moderation_msg = "YOUR INPUT VIOLATES OUR CONTENT MODERATION GUIDELINES. PLEASE TRY AGAIN."
13 |
14 | handler = None
15 |
16 |
17 | def build_logger(logger_name, logger_filename):
18 | global handler
19 |
20 | formatter = logging.Formatter(
21 | fmt="%(asctime)s | %(levelname)s | %(name)s | %(message)s",
22 | datefmt="%Y-%m-%d %H:%M:%S",
23 | )
24 |
25 | # Set the format of root handlers
26 | if not logging.getLogger().handlers:
27 | logging.basicConfig(level=logging.INFO)
28 | logging.getLogger().handlers[0].setFormatter(formatter)
29 |
30 | # Redirect stdout and stderr to loggers
31 | stdout_logger = logging.getLogger("stdout")
32 | stdout_logger.setLevel(logging.INFO)
33 | sl = StreamToLogger(stdout_logger, logging.INFO)
34 | sys.stdout = sl
35 |
36 | stderr_logger = logging.getLogger("stderr")
37 | stderr_logger.setLevel(logging.ERROR)
38 | sl = StreamToLogger(stderr_logger, logging.ERROR)
39 | sys.stderr = sl
40 |
41 | # Get logger
42 | logger = logging.getLogger(logger_name)
43 | logger.setLevel(logging.INFO)
44 |
45 | # Add a file handler for all loggers
46 | if handler is None:
47 | os.makedirs(LOGDIR, exist_ok=True)
48 | filename = os.path.join(LOGDIR, logger_filename)
49 | handler = logging.handlers.TimedRotatingFileHandler(
50 | filename, when='D', utc=True)
51 | handler.setFormatter(formatter)
52 |
53 | for name, item in logging.root.manager.loggerDict.items():
54 | if isinstance(item, logging.Logger):
55 | item.addHandler(handler)
56 |
57 | return logger
58 |
59 |
60 | class StreamToLogger(object):
61 | """
62 | Fake file-like stream object that redirects writes to a logger instance.
63 | """
64 | def __init__(self, logger, log_level=logging.INFO):
65 | self.terminal = sys.stdout
66 | self.logger = logger
67 | self.log_level = log_level
68 | self.linebuf = ''
69 |
70 | def __getattr__(self, attr):
71 | return getattr(self.terminal, attr)
72 |
73 | def write(self, buf):
74 | temp_linebuf = self.linebuf + buf
75 | self.linebuf = ''
76 | for line in temp_linebuf.splitlines(True):
77 | # From the io.TextIOWrapper docs:
78 | # On output, if newline is None, any '\n' characters written
79 | # are translated to the system default line separator.
80 | # By default sys.stdout.write() expects '\n' newlines and then
81 | # translates them so this is still cross platform.
82 | if line[-1] == '\n':
83 | self.logger.log(self.log_level, line.rstrip())
84 | else:
85 | self.linebuf += line
86 |
87 | def flush(self):
88 | if self.linebuf != '':
89 | self.logger.log(self.log_level, self.linebuf.rstrip())
90 | self.linebuf = ''
91 |
92 |
93 | def disable_torch_init():
94 | """
95 | Disable the redundant torch default initialization to accelerate model creation.
96 | """
97 | import torch
98 | setattr(torch.nn.Linear, "reset_parameters", lambda self: None)
99 | setattr(torch.nn.LayerNorm, "reset_parameters", lambda self: None)
100 |
101 |
102 | def violates_moderation(text):
103 | """
104 | Check whether the text violates OpenAI moderation API.
105 | """
106 | url = "https://api.openai.com/v1/moderations"
107 | headers = {"Content-Type": "application/json",
108 | "Authorization": "Bearer " + os.environ["OPENAI_API_KEY"]}
109 | text = text.replace("\n", "")
110 |     # Serialize via requests' json= kwarg so quotes in `text` cannot break the payload.
111 |     data = {"input": text}
112 |     try:
113 |         ret = requests.post(url, headers=headers, json=data, timeout=5)
114 | flagged = ret.json()["results"][0]["flagged"]
115 | except requests.exceptions.RequestException as e:
116 | flagged = False
117 | except KeyError as e:
118 | flagged = False
119 |
120 | return flagged
121 |
122 |
123 | def pretty_print_semaphore(semaphore):
124 | if semaphore is None:
125 | return "None"
126 | return f"Semaphore(value={semaphore._value}, locked={semaphore.locked()})"
127 |
--------------------------------------------------------------------------------
/pyproject.toml:
--------------------------------------------------------------------------------
1 | [build-system]
2 | requires = ["setuptools>=61.0"]
3 | build-backend = "setuptools.build_meta"
4 |
5 | [project]
6 | name = "osprey"
7 | version = "1.0"
8 | description = "Pixel Understanding with Visual Instruction Tuning."
9 | readme = "README.md"
10 | requires-python = ">=3.8"
11 | classifiers = [
12 | "Programming Language :: Python :: 3",
13 | "License :: OSI Approved :: Apache Software License",
14 | ]
15 | dependencies = [
16 |     "torch==2.0.1", "torchvision==0.15.2",
17 |     "transformers==4.31.0",
18 |     "fastapi", "markdown2[all]", "numpy",
19 |     "requests", "tokenizers>=0.12.1",
20 |     "uvicorn", "tensorboard", "open_clip_torch",
21 |     "shortuuid", "httpx==0.24.0",
22 |     "deepspeed==0.9.5",
23 |     "peft==0.4.0",
24 |     "accelerate==0.21.0",
25 |     "bitsandbytes==0.41.0",
26 |     "scikit-learn==1.2.2",
27 |     "sentencepiece==0.1.99",
28 |     "einops==0.6.1", "einops-exts==0.0.4", "timm==0.6.13",
29 |     "gradio_client==0.2.9",
30 |     "pycocotools", "terminaltables",
31 |     "lvis",
32 | ]
33 |
34 | [tool.setuptools.packages.find]
35 | exclude = ["assets*", "benchmark*", "docs", "dist*", "playground*", "scripts*", "tests*"]
36 |
37 | [tool.wheel]
38 | exclude = ["assets*", "benchmark*", "docs", "dist*", "playground*", "scripts*", "tests*"]
39 |
--------------------------------------------------------------------------------
/scripts/stage2.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | export PYTHONPATH=`pwd`:$PYTHONPATH
3 | export TRAIN_MASK_MODULE=1
4 |
5 | deepspeed --include localhost:0,1,2,3 osprey/train/train_mem.py \
6 | --deepspeed ./scripts/zero2.json \
7 | --model_name_or_path vicuna-7b-v1.5 \
8 | --dataset_config ./osprey/configs/stage2.json \
9 | --version v1 \
10 | --vision_tower laion2b_s29b_b131k_ft_soup.bin \
11 | --pretrain_mm_mlp_adapter osprey-v1.0-7b-pretrain/mm_projector.bin \
12 | --mm_projector_type mlp2x_gelu \
13 | --mm_vision_select_layer -2 \
14 | --mm_use_im_start_end False \
15 | --mm_use_im_patch_token False \
16 | --image_aspect_ratio pad \
17 | --group_by_modality_length True \
18 | --bf16 True \
19 | --output_dir './exp/stage2' \
20 | --num_train_epochs 2 \
21 | --per_device_train_batch_size 1\
22 | --per_device_eval_batch_size 4 \
23 | --gradient_accumulation_steps 1 \
24 | --evaluation_strategy "no" \
25 | --save_strategy "steps" \
26 | --save_steps 10000 \
27 | --save_total_limit 1 \
28 | --learning_rate 2e-5 \
29 | --weight_decay 0. \
30 | --warmup_ratio 0.03 \
31 | --lr_scheduler_type "cosine" \
32 | --logging_steps 1 \
33 | --tf32 True \
34 | --model_max_length 2048 \
35 | --gradient_checkpointing True \
36 | --dataloader_num_workers 4 \
37 | --lazy_preprocess True \
38 | --report_to "none" \
39 | --group_by_modality_length False
40 |
--------------------------------------------------------------------------------
/scripts/stage3.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | export PYTHONPATH=`pwd`:$PYTHONPATH
3 |
4 | deepspeed --include localhost:0,1,2,3 osprey/train/train_mem.py \
5 | --deepspeed ./scripts/zero2.json \
6 | --model_name_or_path /exp/stage2/checkpoint-final \
7 | --dataset_config ./osprey/configs/stage3.json \
8 | --version v1 \
9 | --vision_tower laion2b_s29b_b131k_ft_soup.bin \
10 | --mm_projector_type mlp2x_gelu \
11 | --mm_vision_select_layer -2 \
12 | --mm_use_im_start_end False \
13 | --mm_use_im_patch_token False \
14 | --image_aspect_ratio pad \
15 | --group_by_modality_length True \
16 | --bf16 True \
17 | --output_dir './exp/stage3' \
18 | --num_train_epochs 2 \
19 | --per_device_train_batch_size 1\
20 | --per_device_eval_batch_size 4 \
21 | --gradient_accumulation_steps 1 \
22 | --evaluation_strategy "no" \
23 | --save_strategy "steps" \
24 | --save_steps 2000 \
25 | --save_total_limit 1 \
26 | --learning_rate 1e-5 \
27 | --weight_decay 0. \
28 | --warmup_ratio 0.03 \
29 | --lr_scheduler_type "cosine" \
30 | --logging_steps 1 \
31 | --tf32 True \
32 | --model_max_length 2048 \
33 | --gradient_checkpointing True \
34 | --dataloader_num_workers 4 \
35 | --lazy_preprocess True \
36 | --report_to "none" \
37 | --group_by_modality_length False
38 |
--------------------------------------------------------------------------------
/scripts/zero2.json:
--------------------------------------------------------------------------------
1 | {
2 | "fp16": {
3 | "enabled": "auto",
4 | "loss_scale": 0,
5 | "loss_scale_window": 1000,
6 | "initial_scale_power": 16,
7 | "hysteresis": 2,
8 | "min_loss_scale": 1
9 | },
10 | "bf16": {
11 | "enabled": "auto"
12 | },
13 | "train_micro_batch_size_per_gpu": "auto",
14 | "train_batch_size": "auto",
15 | "gradient_accumulation_steps": "auto",
16 | "zero_optimization": {
17 | "stage": 2,
18 | "overlap_comm": true,
19 | "contiguous_gradients": true,
20 | "sub_group_size": 1e9,
21 | "reduce_bucket_size": "auto"
22 | }
23 | }
--------------------------------------------------------------------------------
/scripts/zero3.json:
--------------------------------------------------------------------------------
1 | {
2 | "fp16": {
3 | "enabled": "auto",
4 | "loss_scale": 0,
5 | "loss_scale_window": 1000,
6 | "initial_scale_power": 16,
7 | "hysteresis": 2,
8 | "min_loss_scale": 1
9 | },
10 | "bf16": {
11 | "enabled": "auto"
12 | },
13 | "train_micro_batch_size_per_gpu": "auto",
14 | "train_batch_size": "auto",
15 | "gradient_accumulation_steps": "auto",
16 | "zero_optimization": {
17 | "stage": 3,
18 | "overlap_comm": true,
19 | "contiguous_gradients": true,
20 | "sub_group_size": 1e9,
21 | "reduce_bucket_size": "auto",
22 | "stage3_prefetch_bucket_size": "auto",
23 | "stage3_param_persistence_threshold": "auto",
24 | "stage3_max_live_parameters": 1e8,
25 | "stage3_max_reuse_distance": 1e8,
26 | "stage3_gather_16bit_weights_on_model_save": true
27 | }
28 | }
--------------------------------------------------------------------------------
/scripts/zero3_offload.json:
--------------------------------------------------------------------------------
1 | {
2 | "fp16": {
3 | "enabled": "auto",
4 | "loss_scale": 0,
5 | "loss_scale_window": 1000,
6 | "initial_scale_power": 16,
7 | "hysteresis": 2,
8 | "min_loss_scale": 1
9 | },
10 | "bf16": {
11 | "enabled": "auto"
12 | },
13 | "optimizer": {
14 | "type": "AdamW",
15 | "params": {
16 | "lr": "auto",
17 | "betas": "auto",
18 | "eps": "auto",
19 | "weight_decay": "auto"
20 | }
21 | },
22 | "scheduler": {
23 | "type": "WarmupLR",
24 | "params": {
25 | "warmup_min_lr": "auto",
26 | "warmup_max_lr": "auto",
27 | "warmup_num_steps": "auto"
28 | }
29 | },
30 | "zero_optimization": {
31 | "stage": 3,
32 | "offload_optimizer": {
33 | "device": "cpu",
34 | "pin_memory": true
35 | },
36 | "offload_param": {
37 | "device": "cpu",
38 | "pin_memory": true
39 | },
40 | "overlap_comm": true,
41 | "contiguous_gradients": true,
42 | "sub_group_size": 1e9,
43 | "reduce_bucket_size": "auto",
44 | "stage3_prefetch_bucket_size": "auto",
45 | "stage3_param_persistence_threshold": "auto",
46 | "stage3_max_live_parameters": 1e8,
47 | "stage3_max_reuse_distance": 1e8,
48 | "gather_16bit_weights_on_model_save": true
49 | },
50 | "gradient_accumulation_steps": "auto",
51 | "gradient_clipping": "auto",
52 | "train_batch_size": "auto",
53 | "train_micro_batch_size_per_gpu": "auto",
54 | "steps_per_print": 1e5,
55 | "wall_clock_breakdown": false
56 | }
--------------------------------------------------------------------------------