├── LICENSE
├── README.md
├── docs
├── .jekyll-cache
│ └── Jekyll
│ │ └── Cache
│ │ ├── Jekyll--Cache
│ │ └── b7
│ │ │ └── 9606fb3afea5bd1609ed40b622142f1c98125abcfe89a76a661b0e8e343910
│ │ └── Jekyll--Converters--Markdown
│ │ └── fd
│ │ └── aa4200cef2544ec568b25970b3a49cf9b4f3d7015485506ccda914caf9c5b9
├── 404.html
├── Gemfile
├── Gemfile.lock
├── LICENSE
├── _config.yml
├── _data
│ └── authors.yml
├── _layouts
│ └── default.html
├── _posts
│ └── 2021-01-16-welcome-to-jekyll.markdown
├── _sass
│ └── main.scss
├── _site
│ ├── 404.html
│ ├── LICENSE
│ ├── about
│ │ └── index.html
│ ├── assets
│ │ ├── css
│ │ │ ├── styles.css
│ │ │ └── styles.css.map
│ │ ├── favicon
│ │ │ ├── android-chrome-192x192.png
│ │ │ ├── android-chrome-512x512.png
│ │ │ ├── apple-touch-icon.png
│ │ │ ├── favicon-16x16.png
│ │ │ ├── favicon-32x32.png
│ │ │ ├── favicon.ico
│ │ │ └── site.webmanifest
│ │ ├── img
│ │ │ ├── animation.gif
│ │ │ ├── criteria.png
│ │ │ ├── datarow.png
│ │ │ ├── interface.png
│ │ │ ├── jamin.jpg
│ │ │ ├── juho.jpg
│ │ │ ├── kaist_logo.png
│ │ │ ├── kixlab_logo.png
│ │ │ ├── naver_logo.png
│ │ │ ├── taesoo.jpeg
│ │ │ ├── yoonjoo.jpeg
│ │ │ └── youngho.jpg
│ │ ├── main.css
│ │ ├── main.css.map
│ │ └── minima-social-icons.svg
│ ├── feed.xml
│ ├── index.html
│ └── jekyll
│ │ └── update
│ │ └── 2021
│ │ └── 01
│ │ └── 16
│ │ └── welcome-to-jekyll.html
├── about.markdown
├── assets
│ ├── css
│ │ └── styles.scss
│ ├── favicon
│ │ ├── android-chrome-192x192.png
│ │ ├── android-chrome-512x512.png
│ │ ├── apple-touch-icon.png
│ │ ├── favicon-16x16.png
│ │ ├── favicon-32x32.png
│ │ ├── favicon.ico
│ │ └── site.webmanifest
│ └── img
│ │ ├── blue_fire.png
│ │ ├── geewook.jpg
│ │ ├── human_corr.svg
│ │ ├── instructionfollowing_results.png
│ │ ├── kaist_logo.png
│ │ ├── lklab_logo.jpg
│ │ ├── minjoon.jpeg
│ │ ├── naver_ai_lab_logo.png
│ │ ├── naver_cloud_logo.png
│ │ ├── naver_logo.png
│ │ ├── pairwise_win_rate.svg
│ │ ├── perception_collection_stats.png
│ │ ├── prometheus_vision_components.svg
│ │ ├── seongyun.jpeg
│ │ ├── seungone.jpeg
│ │ ├── suehyun.png
│ │ └── vlm_as_a_judge.svg
└── index.markdown
├── llava
├── __init__.py
├── constants.py
├── conversation.py
├── eval
│ ├── eval_gpt_review.py
│ ├── eval_gpt_review_bench.py
│ ├── eval_gpt_review_visual.py
│ ├── eval_pope.py
│ ├── eval_science_qa.py
│ ├── eval_science_qa_gpt4.py
│ ├── eval_science_qa_gpt4_requery.py
│ ├── eval_textvqa.py
│ ├── generate_webpage_data_from_table.py
│ ├── m4c_evaluator.py
│ ├── model_qa.py
│ ├── model_vqa.py
│ ├── model_vqa_loader.py
│ ├── model_vqa_mmbench.py
│ ├── model_vqa_qbench.py
│ ├── model_vqa_science.py
│ ├── qa_baseline_gpt35.py
│ ├── run_llava.py
│ ├── summarize_gpt_review.py
│ ├── table
│ │ ├── answer
│ │ │ ├── answer_alpaca-13b.jsonl
│ │ │ ├── answer_bard.jsonl
│ │ │ ├── answer_gpt35.jsonl
│ │ │ ├── answer_llama-13b.jsonl
│ │ │ └── answer_vicuna-13b.jsonl
│ │ ├── caps_boxes_coco2014_val_80.jsonl
│ │ ├── model.jsonl
│ │ ├── prompt.jsonl
│ │ ├── question.jsonl
│ │ ├── results
│ │ │ ├── test_sqa_llava_13b_v0.json
│ │ │ └── test_sqa_llava_lcs_558k_sqa_12e_vicuna_v1_3_13b.json
│ │ ├── review
│ │ │ ├── review_alpaca-13b_vicuna-13b.jsonl
│ │ │ ├── review_bard_vicuna-13b.jsonl
│ │ │ ├── review_gpt35_vicuna-13b.jsonl
│ │ │ └── review_llama-13b_vicuna-13b.jsonl
│ │ ├── reviewer.jsonl
│ │ └── rule.json
│ └── webpage
│ │ ├── figures
│ │ ├── alpaca.png
│ │ ├── bard.jpg
│ │ ├── chatgpt.svg
│ │ ├── llama.jpg
│ │ ├── swords_FILL0_wght300_GRAD0_opsz48.svg
│ │ └── vicuna.jpeg
│ │ ├── index.html
│ │ ├── script.js
│ │ └── styles.css
├── mm_utils.py
├── model
│ ├── __init__.py
│ ├── apply_delta.py
│ ├── builder.py
│ ├── consolidate.py
│ ├── language_model
│ │ ├── llava_llama.py
│ │ ├── llava_mpt.py
│ │ └── mpt
│ │ │ ├── adapt_tokenizer.py
│ │ │ ├── attention.py
│ │ │ ├── blocks.py
│ │ │ ├── configuration_mpt.py
│ │ │ ├── custom_embedding.py
│ │ │ ├── flash_attn_triton.py
│ │ │ ├── hf_prefixlm_converter.py
│ │ │ ├── meta_init_context.py
│ │ │ ├── modeling_mpt.py
│ │ │ ├── norm.py
│ │ │ └── param_init_fns.py
│ ├── llava_arch.py
│ ├── make_delta.py
│ ├── multimodal_encoder
│ │ ├── builder.py
│ │ └── clip_encoder.py
│ ├── multimodal_projector
│ │ └── builder.py
│ └── utils.py
├── serve
│ ├── __init__.py
│ ├── cli.py
│ ├── controller.py
│ ├── examples
│ │ ├── extreme_ironing.jpg
│ │ └── waterview.jpg
│ ├── gradio_web_server.py
│ ├── model_worker.py
│ ├── register_worker.py
│ └── test_message.py
├── train
│ ├── llama_flash_attn_monkey_patch.py
│ ├── llama_xformers_attn_monkey_patch.py
│ ├── llava_trainer.py
│ ├── train.py
│ ├── train_mem.py
│ └── train_xformers.py
└── utils.py
├── mmmu_download.py
├── pyproject.toml
├── sample_eval_data.jsonl
└── sample_train_data.json
/README.md:
--------------------------------------------------------------------------------
1 | # Prometheus-Vision
2 | [ACL 2024 Findings & ICLR 2024 WS] An evaluator VLM that is open-source, offers reproducible evaluation, and is inexpensive to use. Specifically designed for fine-grained evaluation with customized score rubrics, Prometheus-Vision is a good alternative to human evaluation and GPT-4V evaluation.
3 |
4 | Our preprint, which describes our method in detail and provides full experimental results and analyses, can be found here:
5 | > [**Prometheus-Vision: Vision-Language Model as a Judge for Fine-Grained Evaluation**](https://arxiv.org/abs/2401.06591).
6 | > [Seongyun Lee](https://www.linkedin.com/in/seongyun-lee-647753233/)$^\ast$, [Seungone Kim](https://seungonekim.github.io/)$^\ast$, [Sue Hyun Park](https://suehyunpark.github.io/), [Geewook Kim](https://geewook.kim/), [Minjoon Seo](https://seominjoon.github.io/). Work in progress, arXiv preprint.
7 |
8 | ## Setup
9 | 1. Install package
10 | ```bash
11 | conda create -n prometheus-vision python=3.10 -y
12 | conda activate prometheus-vision
13 | pip install --upgrade pip # enable PEP 660 support
14 | pip install -e .
15 | pip install -e ".[train]"
16 | pip install flash-attn --no-build-isolation
17 | ```
18 | 2. Download data
19 | ```bash
20 | python mmmu_download.py
21 | wget http://images.cocodataset.org/zips/train2017.zip
22 | unzip train2017.zip
23 | ```
24 | ## Input and Output Format of Prometheus-Vision
25 | Prometheus-Vision is trained and run for inference with the following input prompt format. Note that you can fill in the instruction, response, reference answer, and score rubric with your own data.
26 | ```text
27 | ###Task Description:
28 | An instruction (might include an Input inside it), a response to evaluate, a reference answer that gets a score of 5, image and a score rubric representing an evaluation criterion is given.
29 | 1. Write a detailed feedback that assesses the quality of the response strictly based on the given score rubric, not evaluating in general.
30 | 2. After writing a feedback, write a score that is an integer between 1 and 5. You should refer to the score rubric.
31 | 3. The output format should look as follows: Feedback: (write a feedback for criteria) [RESULT] (an integer number between 1 and 5)
32 | 4. Please do not generate any other opening, closing, and explanations.
33 |
34 | ###The instruction to evaluate:
35 | {orig_instruction}
36 |
37 | ###Response to evaluate:
38 | {orig_response}
39 |
40 | ###Reference Answer (Score 5):
41 | {orig_reference_answer}
42 |
43 | ###Score Rubrics:
44 | [{orig_criteria}]
45 | Score 1: {orig_score1_description}
46 | Score 2: {orig_score2_description}
47 | Score 3: {orig_score3_description}
48 | Score 4: {orig_score4_description}
49 | Score 5: {orig_score5_description}
50 |
51 | ###Feedback:
52 | ```
53 | You can check the [Perception Collection](https://huggingface.co/datasets/kaist-ai/Perception-Collection) used for training Prometheus-Vision.
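
For convenience, here is a minimal Python sketch (not part of this repo) of how you might fill the template above with your own data. It assumes you have saved the prompt block above verbatim to a hypothetical `prompt_template.txt`, keeping the `{orig_*}` placeholders; the instance values below are made-up examples.
```python
from pathlib import Path

# Assumption: prompt_template.txt contains the prompt format shown above,
# with the {orig_*} placeholders left intact.
template = Path("prompt_template.txt").read_text()

# Hypothetical instance; replace the values with your own data.
instance = {
    "orig_instruction": "Describe the weather shown in the image.",
    "orig_response": "It looks sunny with a clear blue sky.",
    "orig_reference_answer": "The sky is clear and bright, indicating a sunny day.",
    "orig_criteria": "Does the response accurately describe the weather in the image?",
    "orig_score1_description": "The response is unrelated to the weather in the image.",
    "orig_score2_description": "The response mentions the weather but is mostly incorrect.",
    "orig_score3_description": "The response is partially correct but misses key details.",
    "orig_score4_description": "The response is mostly correct with minor omissions.",
    "orig_score5_description": "The response accurately and completely describes the weather.",
}

# Fill every placeholder in one pass.
prompt = template.format(**instance)
print(prompt)
```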
54 | ## Train
55 | We use the [LLaVA](https://github.com/haotian-liu/LLaVA) codebase to develop Prometheus-Vision, so the following training & inference scripts are tailored to it.
56 |
57 | If you plan to start from a different VLM codebase, you should adapt the format of the data to suit your custom code.
58 |
59 | Note that you can also check the data format in [sample_train_data.json](https://github.com/kaistAI/prometheus-vision/blob/main/sample_train_data.json) and [sample_eval_data.jsonl](https://github.com/kaistAI/prometheus-vision/blob/main/sample_eval_data.jsonl).
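
For orientation only, the sketch below shows what a LLaVA-style training record generally looks like; the field values are hypothetical, and `sample_train_data.json` remains the authoritative reference for the exact schema.
```python
# Hypothetical LLaVA-style training record (illustration only; see
# sample_train_data.json in this repo for the real schema).
train_record = {
    "id": "perception-collection-000001",      # made-up identifier
    "image": "train2017/000000000139.jpg",     # path to the paired image
    "conversations": [
        {
            "from": "human",
            # The human turn carries the image token plus the filled-in
            # Prometheus-Vision prompt (instruction, response, reference
            # answer, and score rubric) from the previous section.
            "value": "<image>\n###Task Description: ...",
        },
        {
            "from": "gpt",
            # The target is the feedback followed by the score decision.
            "value": "The response captures ... [RESULT] 4",
        },
    ],
}
```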
60 | ```bash
61 | deepspeed llava/train/train_mem.py \
62 | --deepspeed ./scripts/zero2.json \
63 | --model_name_or_path lmsys/vicuna-13b-v1.5 \
64 | --version plain \
65 | --data_path TRAINING_DATA_PATH \
66 | --vision_tower openai/clip-vit-large-patch14-336 \
67 | --mm_projector_type mlp2x_gelu \
68 | --tune_mm_mlp_adapter True \
69 | --mm_vision_select_layer -2 \
70 | --mm_use_im_start_end False \
71 | --mm_use_im_patch_token False \
72 | --bf16 True \
73 | --output_dir OUTPUT_DIR \
74 | --num_train_epochs 1 \
75 | --per_device_train_batch_size 32 \
76 | --per_device_eval_batch_size 4 \
77 | --gradient_accumulation_steps 1 \
78 | --evaluation_strategy "no" \
79 | --save_strategy "steps" \
80 | --save_steps 24000 \
81 | --save_total_limit 1 \
82 | --learning_rate 1e-3 \
83 | --weight_decay 0. \
84 | --warmup_ratio 0.03 \
85 | --lr_scheduler_type "cosine" \
86 | --logging_steps 1 \
87 | --tf32 True \
88 | --model_max_length 2048 \
89 | --gradient_checkpointing True \
90 | --dataloader_num_workers 4 \
91 | --lazy_preprocess True \
92 | --report_to wandb
93 | ```
94 | ## Inference
95 | Place the path of your LLaVA-style model in MODEL_PATH, or use the trained [Prometheus-Vision](https://huggingface.co/kaist-ai/prometheus-vision-13b-v1.0) checkpoint we provide.
96 | ```bash
97 | python -m llava.eval.model_vqa \
98 | --model-path MODEL_PATH \
99 | --question-file EVALUATION_DATA_PATH \
100 | --answers-file OUTPUT_PATH \
101 | --temperature 1.0 \
102 | --top_p 0.9 \
103 | --conv-mode vicuna_v1
104 | ```
105 | Additionally, you can use [Perception-Bench](https://huggingface.co/datasets/kaist-ai/Perception-Bench) when evaluating VLMs.
106 |
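Each line of the answers file produced by `llava.eval.model_vqa` is a JSON object whose `text` field holds the generated output. The sketch below (not part of the repo; the field name and the `OUTPUT_PATH` placeholder are assumptions) shows one way to pull out the `[RESULT]` scores from such a file.
```python
import json
import re

# Assumption: OUTPUT_PATH is the --answers-file produced above, one JSON object
# per line with the generated feedback in its "text" field.
scores = []
with open("OUTPUT_PATH") as f:
    for line in f:
        record = json.loads(line)
        feedback = record["text"]
        # The model is prompted to end its output with "[RESULT] <integer 1-5>".
        match = re.search(r"\[RESULT\]\s*([1-5])", feedback)
        scores.append(int(match.group(1)) if match else None)

print(f"Parsed {sum(s is not None for s in scores)} / {len(scores)} scores")
```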
107 | ## Citation
108 | If you find our work useful, please consider citing our paper:
109 |
110 | ```
111 | @misc{lee2024prometheusvision,
112 | title={Prometheus-Vision: Vision-Language Model as a Judge for Fine-Grained Evaluation},
113 | author={Seongyun Lee and Seungone Kim and Sue Hyun Park and Geewook Kim and Minjoon Seo},
114 | year={2024},
115 | eprint={2401.06591},
116 | archivePrefix={arXiv},
117 | primaryClass={cs.CL}
118 | }
119 | ```
120 |
--------------------------------------------------------------------------------
/docs/.jekyll-cache/Jekyll/Cache/Jekyll--Cache/b7/9606fb3afea5bd1609ed40b622142f1c98125abcfe89a76a661b0e8e343910:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/prometheus-eval/prometheus-vision/fe668b5e7e120e5bea8a3821eeb468ddf719c23c/docs/.jekyll-cache/Jekyll/Cache/Jekyll--Cache/b7/9606fb3afea5bd1609ed40b622142f1c98125abcfe89a76a661b0e8e343910
--------------------------------------------------------------------------------
/docs/.jekyll-cache/Jekyll/Cache/Jekyll--Converters--Markdown/fd/aa4200cef2544ec568b25970b3a49cf9b4f3d7015485506ccda914caf9c5b9:
--------------------------------------------------------------------------------
1 | I"9
You’ll find this post in your _posts directory. Go ahead and edit it and re-build the site to see your changes. You can rebuild the site in many different ways, but the most common way is to run jekyll serve, which launches a web server and auto-regenerates your site when a file is updated.
2 |
3 |
Jekyll requires blog post files to be named according to the following format:
4 |
5 |
YEAR-MONTH-DAY-title.MARKUP
6 |
7 |
Where YEAR is a four-digit number, MONTH and DAY are both two-digit numbers, and MARKUP is the file extension representing the format used in the file. After that, include the necessary front matter. Take a look at the source for this post to get an idea about how it works.
8 |
9 |
Jekyll also offers powerful support for code snippets:
10 |
11 |
defprint_hi(name)
12 | puts"Hi, #{name}"
13 | end
14 | print_hi('Tom')
15 | #=> prints 'Hi, Tom' to STDOUT.
16 |
17 |
Check out the Jekyll docs for more info on how to get the most out of Jekyll. File all bugs/feature requests at Jekyll’s GitHub repo. If you have questions, you can ask them on Jekyll Talk.
26 |
--------------------------------------------------------------------------------
/docs/Gemfile:
--------------------------------------------------------------------------------
1 | source "https://rubygems.org"
2 | # Hello! This is where you manage which Jekyll version is used to run.
3 | # When you want to use a different version, change it below, save the
4 | # file and run `bundle install`. Run Jekyll with `bundle exec`, like so:
5 | #
6 | # bundle exec jekyll serve
7 | #
8 | # This will help ensure the proper Jekyll version is running.
9 | # Happy Jekylling!
10 | gem "jekyll", "~> 4.2.0"
11 | # This is the default theme for new Jekyll sites. You may change this to anything you like.
12 | gem "minima", "~> 2.5"
13 | # If you want to use GitHub Pages, remove the "gem "jekyll"" above and
14 | # uncomment the line below. To upgrade, run `bundle update github-pages`.
15 | # gem "github-pages", group: :jekyll_plugins
16 | # If you have any plugins, put them here!
17 | group :jekyll_plugins do
18 | gem "jekyll-feed", "~> 0.12"
19 | end
20 |
21 | # Windows and JRuby does not include zoneinfo files, so bundle the tzinfo-data gem
22 | # and associated library.
23 | platforms :mingw, :x64_mingw, :mswin, :jruby do
24 | gem "tzinfo", "~> 1.2"
25 | gem "tzinfo-data"
26 | end
27 |
28 | # Performance-booster for watching directories on Windows
29 | gem "wdm", "~> 0.1.1", :platforms => [:mingw, :x64_mingw, :mswin]
30 |
31 |
32 | gem "webrick", "~> 1.8"
33 |
--------------------------------------------------------------------------------
/docs/Gemfile.lock:
--------------------------------------------------------------------------------
1 | GEM
2 | remote: https://rubygems.org/
3 | specs:
4 | addressable (2.7.0)
5 | public_suffix (>= 2.0.2, < 5.0)
6 | colorator (1.1.0)
7 | concurrent-ruby (1.1.7)
8 | em-websocket (0.5.2)
9 | eventmachine (>= 0.12.9)
10 | http_parser.rb (~> 0.6.0)
11 | eventmachine (1.2.7)
12 | ffi (1.14.2)
13 | forwardable-extended (2.6.0)
14 | http_parser.rb (0.6.0)
15 | i18n (1.8.7)
16 | concurrent-ruby (~> 1.0)
17 | jekyll (4.2.0)
18 | addressable (~> 2.4)
19 | colorator (~> 1.0)
20 | em-websocket (~> 0.5)
21 | i18n (~> 1.0)
22 | jekyll-sass-converter (~> 2.0)
23 | jekyll-watch (~> 2.0)
24 | kramdown (~> 2.3)
25 | kramdown-parser-gfm (~> 1.0)
26 | liquid (~> 4.0)
27 | mercenary (~> 0.4.0)
28 | pathutil (~> 0.9)
29 | rouge (~> 3.0)
30 | safe_yaml (~> 1.0)
31 | terminal-table (~> 2.0)
32 | jekyll-feed (0.15.1)
33 | jekyll (>= 3.7, < 5.0)
34 | jekyll-sass-converter (2.1.0)
35 | sassc (> 2.0.1, < 3.0)
36 | jekyll-seo-tag (2.7.1)
37 | jekyll (>= 3.8, < 5.0)
38 | jekyll-watch (2.2.1)
39 | listen (~> 3.0)
40 | kramdown (2.3.0)
41 | rexml
42 | kramdown-parser-gfm (1.1.0)
43 | kramdown (~> 2.0)
44 | liquid (4.0.3)
45 | listen (3.4.1)
46 | rb-fsevent (~> 0.10, >= 0.10.3)
47 | rb-inotify (~> 0.9, >= 0.9.10)
48 | mercenary (0.4.0)
49 | minima (2.5.1)
50 | jekyll (>= 3.5, < 5.0)
51 | jekyll-feed (~> 0.9)
52 | jekyll-seo-tag (~> 2.1)
53 | pathutil (0.16.2)
54 | forwardable-extended (~> 2.6)
55 | public_suffix (4.0.6)
56 | rb-fsevent (0.10.4)
57 | rb-inotify (0.10.1)
58 | ffi (~> 1.0)
59 | rexml (3.2.4)
60 | rouge (3.26.0)
61 | safe_yaml (1.0.5)
62 | sassc (2.4.0)
63 | ffi (~> 1.9)
64 | terminal-table (2.0.0)
65 | unicode-display_width (~> 1.1, >= 1.1.1)
66 | unicode-display_width (1.7.0)
67 | webrick (1.8.1)
68 |
69 | PLATFORMS
70 | arm64-darwin-22
71 | universal-darwin-20
72 |
73 | DEPENDENCIES
74 | jekyll (~> 4.2.0)
75 | jekyll-feed (~> 0.12)
76 | minima (~> 2.5)
77 | tzinfo (~> 1.2)
78 | tzinfo-data
79 | wdm (~> 0.1.1)
80 | webrick (~> 1.8)
81 |
82 | BUNDLED WITH
83 | 2.2.5
84 |
--------------------------------------------------------------------------------
/docs/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2021 tsook
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/docs/_config.yml:
--------------------------------------------------------------------------------
1 | # Welcome to Jekyll!
2 | #
3 | # This config file is meant for settings that affect your whole blog, values
4 | # which you are expected to set up once and rarely edit after that. If you find
5 | # yourself editing this file very often, consider using Jekyll's data files
6 | # feature for the data you need to update frequently.
7 | #
8 | # For technical reasons, this file is *NOT* reloaded automatically when you use
9 | # 'bundle exec jekyll serve'. If you change this file, please restart the server process.
10 | #
11 | # If you need help with YAML syntax, here are some quick references for you:
12 | # https://learn-the-web.algonquindesign.ca/topics/markdown-yaml-cheat-sheet/#yaml
13 | # https://learnxinyminutes.com/docs/yaml/
14 | #
15 | # Site settings
16 | # These are used to personalize your new site. If you look in the HTML files,
17 | # you will see them accessed via {{ site.title }}, {{ site.email }}, and so on.
18 | # You can create any custom variable you would like, and they will be accessible
19 | # in the templates via {{ site.myvariable }}.
20 |
21 | title: Prometheus-Vision
22 | theme: minima
23 | description: "Vision-Language Model as a Judge for Fine-Grained Evaluation"
24 |
25 | syscolor: "#555555"
26 | sysfont: Inter
27 | paper: https://arxiv.org/abs/2401.06591
28 | code: https://github.com/kaistAI/prometheus-vision
29 | hf_train_data: https://huggingface.co/datasets/kaist-ai/Perception-Collection
30 | hf_test_data: https://huggingface.co/datasets/kaist-ai/Perception-Bench
31 | hf_data: https://huggingface.co/datasets/kaist-ai/Perception-Collection
32 | hf_model_7b: https://huggingface.co/kaist-ai/prometheus-vision-7b-v1.0
33 | hf_model_13b: https://huggingface.co/kaist-ai/prometheus-vision-13b-v1.0
34 |
35 | plugins:
36 | - jekyll-feed
37 |
38 | # Exclude from processing.
39 | # The following items will not be processed, by default.
40 | # Any item listed under the `exclude:` key here will be automatically added to
41 | # the internal "default list".
42 | #
43 | # Excluded items can be processed by explicitly listing the directories or
44 | # their entries' file path in the `include:` list.
45 | #
46 | # exclude:
47 | # - .sass-cache/
48 | # - .jekyll-cache/
49 | # - gemfiles/
50 | # - Gemfile
51 | # - Gemfile.lock
52 | # - node_modules/
53 | # - vendor/bundle/
54 | # - vendor/cache/
55 | # - vendor/gems/
56 | # - vendor/ruby/
57 |
--------------------------------------------------------------------------------
/docs/_data/authors.yml:
--------------------------------------------------------------------------------
1 | - name: Seongyun Lee*
2 | affiliation: KAIST AI
3 | website: https://github.com/sylee0520
4 | img: seongyun.jpeg
5 | - name: Seungone Kim*
6 | affiliation: KAIST AI, NAVER AI Lab
7 | website: https://seungonekim.github.io/
8 | img: seungone.jpeg
9 | - name: Sue Hyun Park
10 | affiliation: KAIST AI
11 | website: https://suehyunpark.github.io/
12 | img: suehyun.png
13 | - name: Geewook Kim
14 | affiliation: KAIST AI, NAVER Cloud AI
15 | website: https://geewook.kim/
16 | img: geewook.jpg
17 | - name: Minjoon Seo
18 | affiliation: KAIST AI
19 | website: https://seominjoon.github.io/
20 | img: minjoon.jpeg
21 |
--------------------------------------------------------------------------------
/docs/_posts/2021-01-16-welcome-to-jekyll.markdown:
--------------------------------------------------------------------------------
1 | ---
2 | layout: post
3 | title: "Welcome to Jekyll!"
4 | date: 2021-01-16 20:44:40 +0900
5 | categories: jekyll update
6 | ---
7 | You’ll find this post in your `_posts` directory. Go ahead and edit it and re-build the site to see your changes. You can rebuild the site in many different ways, but the most common way is to run `jekyll serve`, which launches a web server and auto-regenerates your site when a file is updated.
8 |
9 | Jekyll requires blog post files to be named according to the following format:
10 |
11 | `YEAR-MONTH-DAY-title.MARKUP`
12 |
13 | Where `YEAR` is a four-digit number, `MONTH` and `DAY` are both two-digit numbers, and `MARKUP` is the file extension representing the format used in the file. After that, include the necessary front matter. Take a look at the source for this post to get an idea about how it works.
14 |
15 | Jekyll also offers powerful support for code snippets:
16 |
17 | {% highlight ruby %}
18 | def print_hi(name)
19 | puts "Hi, #{name}"
20 | end
21 | print_hi('Tom')
22 | #=> prints 'Hi, Tom' to STDOUT.
23 | {% endhighlight %}
24 |
25 | Check out the [Jekyll docs][jekyll-docs] for more info on how to get the most out of Jekyll. File all bugs/feature requests at [Jekyll’s GitHub repo][jekyll-gh]. If you have questions, you can ask them on [Jekyll Talk][jekyll-talk].
26 |
27 | [jekyll-docs]: https://jekyllrb.com/docs/home
28 | [jekyll-gh]: https://github.com/jekyll/jekyll
29 | [jekyll-talk]: https://talk.jekyllrb.com/
30 |
--------------------------------------------------------------------------------
/docs/_sass/main.scss:
--------------------------------------------------------------------------------
1 | body {
2 | display: flex;
3 | flex-direction: row;
4 | justify-content: center;
5 | }
6 |
7 | .whitespace {
8 | flex: 1;
9 | }
10 |
11 | .wrapper {
12 | flex: 2 2 50%;
13 | }
14 |
--------------------------------------------------------------------------------
/docs/_site/404.html:
--------------------------------------------------------------------------------
[generated HTML stripped in this dump; the built page only carries the site title "EvalLM" and the tagline "Interactive Evaluation of Large Language Model Prompts on User-Defined Criteria"]
--------------------------------------------------------------------------------
/docs/_site/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2021 tsook
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/docs/_site/assets/favicon/android-chrome-192x192.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/prometheus-eval/prometheus-vision/fe668b5e7e120e5bea8a3821eeb468ddf719c23c/docs/_site/assets/favicon/android-chrome-192x192.png
--------------------------------------------------------------------------------
/docs/_site/assets/favicon/android-chrome-512x512.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/prometheus-eval/prometheus-vision/fe668b5e7e120e5bea8a3821eeb468ddf719c23c/docs/_site/assets/favicon/android-chrome-512x512.png
--------------------------------------------------------------------------------
/docs/_site/assets/favicon/apple-touch-icon.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/prometheus-eval/prometheus-vision/fe668b5e7e120e5bea8a3821eeb468ddf719c23c/docs/_site/assets/favicon/apple-touch-icon.png
--------------------------------------------------------------------------------
/docs/_site/assets/favicon/favicon-16x16.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/prometheus-eval/prometheus-vision/fe668b5e7e120e5bea8a3821eeb468ddf719c23c/docs/_site/assets/favicon/favicon-16x16.png
--------------------------------------------------------------------------------
/docs/_site/assets/favicon/favicon-32x32.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/prometheus-eval/prometheus-vision/fe668b5e7e120e5bea8a3821eeb468ddf719c23c/docs/_site/assets/favicon/favicon-32x32.png
--------------------------------------------------------------------------------
/docs/_site/assets/favicon/favicon.ico:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/prometheus-eval/prometheus-vision/fe668b5e7e120e5bea8a3821eeb468ddf719c23c/docs/_site/assets/favicon/favicon.ico
--------------------------------------------------------------------------------
/docs/_site/assets/favicon/site.webmanifest:
--------------------------------------------------------------------------------
1 | {"name":"","short_name":"","icons":[{"src":"/android-chrome-192x192.png","sizes":"192x192","type":"image/png"},{"src":"/android-chrome-512x512.png","sizes":"512x512","type":"image/png"}],"theme_color":"#ffffff","background_color":"#ffffff","display":"standalone"}
--------------------------------------------------------------------------------
/docs/_site/assets/img/animation.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/prometheus-eval/prometheus-vision/fe668b5e7e120e5bea8a3821eeb468ddf719c23c/docs/_site/assets/img/animation.gif
--------------------------------------------------------------------------------
/docs/_site/assets/img/criteria.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/prometheus-eval/prometheus-vision/fe668b5e7e120e5bea8a3821eeb468ddf719c23c/docs/_site/assets/img/criteria.png
--------------------------------------------------------------------------------
/docs/_site/assets/img/datarow.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/prometheus-eval/prometheus-vision/fe668b5e7e120e5bea8a3821eeb468ddf719c23c/docs/_site/assets/img/datarow.png
--------------------------------------------------------------------------------
/docs/_site/assets/img/interface.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/prometheus-eval/prometheus-vision/fe668b5e7e120e5bea8a3821eeb468ddf719c23c/docs/_site/assets/img/interface.png
--------------------------------------------------------------------------------
/docs/_site/assets/img/jamin.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/prometheus-eval/prometheus-vision/fe668b5e7e120e5bea8a3821eeb468ddf719c23c/docs/_site/assets/img/jamin.jpg
--------------------------------------------------------------------------------
/docs/_site/assets/img/juho.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/prometheus-eval/prometheus-vision/fe668b5e7e120e5bea8a3821eeb468ddf719c23c/docs/_site/assets/img/juho.jpg
--------------------------------------------------------------------------------
/docs/_site/assets/img/kaist_logo.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/prometheus-eval/prometheus-vision/fe668b5e7e120e5bea8a3821eeb468ddf719c23c/docs/_site/assets/img/kaist_logo.png
--------------------------------------------------------------------------------
/docs/_site/assets/img/kixlab_logo.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/prometheus-eval/prometheus-vision/fe668b5e7e120e5bea8a3821eeb468ddf719c23c/docs/_site/assets/img/kixlab_logo.png
--------------------------------------------------------------------------------
/docs/_site/assets/img/naver_logo.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/prometheus-eval/prometheus-vision/fe668b5e7e120e5bea8a3821eeb468ddf719c23c/docs/_site/assets/img/naver_logo.png
--------------------------------------------------------------------------------
/docs/_site/assets/img/taesoo.jpeg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/prometheus-eval/prometheus-vision/fe668b5e7e120e5bea8a3821eeb468ddf719c23c/docs/_site/assets/img/taesoo.jpeg
--------------------------------------------------------------------------------
/docs/_site/assets/img/yoonjoo.jpeg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/prometheus-eval/prometheus-vision/fe668b5e7e120e5bea8a3821eeb468ddf719c23c/docs/_site/assets/img/yoonjoo.jpeg
--------------------------------------------------------------------------------
/docs/_site/assets/img/youngho.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/prometheus-eval/prometheus-vision/fe668b5e7e120e5bea8a3821eeb468ddf719c23c/docs/_site/assets/img/youngho.jpg
--------------------------------------------------------------------------------
/docs/_site/feed.xml:
--------------------------------------------------------------------------------
1 | Jekyll2023-10-11T15:51:38+09:00/feed.xmlEvalLMInteractive Evaluation of Large Language Model Prompts on User-Defined CriteriaWelcome to Jekyll!2021-01-16T20:44:40+09:002021-01-16T20:44:40+09:00/jekyll/update/2021/01/16/welcome-to-jekyll<p>You’ll find this post in your <code class="language-plaintext highlighter-rouge">_posts</code> directory. Go ahead and edit it and re-build the site to see your changes. You can rebuild the site in many different ways, but the most common way is to run <code class="language-plaintext highlighter-rouge">jekyll serve</code>, which launches a web server and auto-regenerates your site when a file is updated.</p>
2 |
3 | <p>Jekyll requires blog post files to be named according to the following format:</p>
4 |
5 | <p><code class="language-plaintext highlighter-rouge">YEAR-MONTH-DAY-title.MARKUP</code></p>
6 |
7 | <p>Where <code class="language-plaintext highlighter-rouge">YEAR</code> is a four-digit number, <code class="language-plaintext highlighter-rouge">MONTH</code> and <code class="language-plaintext highlighter-rouge">DAY</code> are both two-digit numbers, and <code class="language-plaintext highlighter-rouge">MARKUP</code> is the file extension representing the format used in the file. After that, include the necessary front matter. Take a look at the source for this post to get an idea about how it works.</p>
8 |
9 | <p>Jekyll also offers powerful support for code snippets:</p>
10 |
11 | <figure class="highlight"><pre><code class="language-ruby" data-lang="ruby"><span class="k">def</span> <span class="nf">print_hi</span><span class="p">(</span><span class="nb">name</span><span class="p">)</span>
12 | <span class="nb">puts</span> <span class="s2">"Hi, </span><span class="si">#{</span><span class="nb">name</span><span class="si">}</span><span class="s2">"</span>
13 | <span class="k">end</span>
14 | <span class="n">print_hi</span><span class="p">(</span><span class="s1">'Tom'</span><span class="p">)</span>
15 | <span class="c1">#=> prints 'Hi, Tom' to STDOUT.</span></code></pre></figure>
16 |
17 | <p>Check out the <a href="https://jekyllrb.com/docs/home">Jekyll docs</a> for more info on how to get the most out of Jekyll. File all bugs/feature requests at <a href="https://github.com/jekyll/jekyll">Jekyll’s GitHub repo</a>. If you have questions, you can ask them on <a href="https://talk.jekyllrb.com/">Jekyll Talk</a>.</p>You’ll find this post in your _posts directory. Go ahead and edit it and re-build the site to see your changes. You can rebuild the site in many different ways, but the most common way is to run jekyll serve, which launches a web server and auto-regenerates your site when a file is updated.
--------------------------------------------------------------------------------
/docs/about.markdown:
--------------------------------------------------------------------------------
1 | ---
2 | layout: page
3 | title: About
4 | permalink: /about/
5 | ---
6 |
7 | This is the base Jekyll theme. You can find out more info about customizing your Jekyll theme, as well as basic Jekyll usage documentation at [jekyllrb.com](https://jekyllrb.com/)
8 |
9 | You can find the source code for Minima at GitHub:
10 | [jekyll][jekyll-organization] /
11 | [minima](https://github.com/jekyll/minima)
12 |
13 | You can find the source code for Jekyll at GitHub:
14 | [jekyll][jekyll-organization] /
15 | [jekyll](https://github.com/jekyll/jekyll)
16 |
17 |
18 | [jekyll-organization]: https://github.com/jekyll
19 |
--------------------------------------------------------------------------------
/docs/assets/css/styles.scss:
--------------------------------------------------------------------------------
1 | ---
2 | ---
3 |
4 | @import "{{ site.theme }}";
5 |
6 | html {
7 | scroll-behavior:smooth
8 | }
9 |
10 | body {
11 | font-family: 'Noto Sans', sans-serif;
12 | }
13 |
14 | .wrapper {
15 | max-width: calc(940px - (30px * 2));
16 | }
17 |
18 | h1 {
19 | font-weight: bold;
20 | font-size: 40px;
21 | margin-bottom: 0px;
22 | margin-top: 20px;
23 | }
24 |
25 | hr {
26 | margin-top: 32px;
27 | margin-bottom: 32px;
28 | background-color: #ccc;
29 | border: 0px;
30 | height: 1px;
31 | }
32 |
33 | h4 {
34 | color: #666;
35 | font-size: 28px;
36 | margin-bottom: 20px;
37 | }
38 |
39 | h2 {
40 | font-size: 24px;
41 | margin-bottom: 20px;
42 | color: #333;
43 | }
44 |
45 | h3 {
46 | font-size: 20px;
47 | padding-top: 12px;
48 | color: #333;
49 | }
50 |
51 | p {
52 | font-size: 18px;
53 | margin-bottom: 20px;
54 | color: #555;
55 | font-weight: 300;
56 | & > b {
57 | font-weight: 500;
58 | }
59 | }
60 |
61 | .video-wrapper {
62 | text-align: center;
63 | margin-bottom: 20px;
64 | position: relative;
65 | width: 100%;
66 | padding-bottom: 56.25%;
67 | }
68 |
69 | iframe {
70 | position: absolute;
71 | width: 100%;
72 | height: 100%;
73 | left: 0;
74 | top: 0;
75 | border-radius: 16px;
76 | box-shadow: rgba(0, 0, 0, 0.16) 0px 3px 6px, rgba(0, 0, 0, 0.23) 0px 3px 6px;
77 | }
78 |
79 | .authors-wrapper {
80 | text-align: center;
81 | }
82 |
83 | .author-container {
84 | display: inline-block;
85 | width: 12%;
86 | margin: 0 2% 0 2%;
87 | }
88 |
89 | .author-container p {
90 | margin-bottom: 0px;
91 | font-size: 14px;
92 | }
93 |
94 | .author-image {
95 | position: relative;
96 | width: 100%;
97 | padding-bottom: 100%;
98 | }
99 |
100 | .author-image img {
101 | position: absolute;
102 | width: 100%;
103 | height: 100%;
104 | top: 0;
105 | left: 0;
106 | border-radius: 50%;
107 | }
108 |
109 | .center {
110 | text-align: center;
111 | }
112 |
113 | .acknowledgement {
114 | color: #999;
115 | font-size: 16px;
116 | margin-bottom: 0px;
117 | }
118 |
119 | .logos {
120 | text-align: center;
121 | }
122 |
123 | .logos img {
124 | padding: 0 4px 0 4px;
125 | height: 24px;
126 | }
127 |
128 | .img-left {
129 | display: inline-block;
130 | width: 45%;
131 | }
132 |
133 | .text-right {
134 | display: inline-block;
135 | width: calc(54% - 16px);
136 | margin-left: 16px;
137 | vertical-align: middle;
138 | }
139 |
140 | .img-right {
141 | display: inline-block;
142 | width: 49%;
143 | }
144 |
145 | .text-left {
146 | display: inline-block;
147 | width: calc(50% - 16px);
148 | margin-right: 16px;
149 | vertical-align: middle;
150 | }
151 |
152 | .quote {
153 | width: 70%;
154 | margin-left: 15%
155 | }
156 |
157 | pre {
158 | display: block;
159 | font-size: 14px;
160 | color: #333;
161 | word-break: break-all;
162 | word-wrap: break-word;
163 | background-color: #f5f5f5;
164 | border: 1px solid #ccc;
165 | border-radius: 4px;
166 | white-space: pre-wrap; /* Since CSS 2.1 */
167 | white-space: -moz-pre-wrap; /* Mozilla, since 1999 */
168 | white-space: -pre-wrap; /* Opera 4-6 */
169 | white-space: -o-pre-wrap; /* Opera 7 */
170 | word-wrap: break-word; /* Internet Explorer 5.5+ */
171 | }
172 |
173 | .sys-img {
174 | width: 90%;
175 | margin-left: 5%;
176 | }
177 |
178 |
179 | .button-container {
180 | display: flex;
181 | justify-content: center;
182 | align-items: center;
183 | gap: 16px;
184 | margin-top: 16px;
185 | }
186 |
187 | .button {
188 | display: flex;
189 | justify-content: center;
190 | align-items: center;
191 | gap: 8px;
192 | padding: 8px 16px;
193 | border-radius: 100px;
194 | background-color: #363636;
195 | color: #fff;
196 | fill: #fff;
197 | font-size: 16px;
198 | // width: 80px;
199 | cursor: pointer;
200 |
201 | &:visited {
202 | color: #fff;
203 | }
204 | &:active {
205 | color: #fff;
206 | }
207 | &:hover {
208 | background-color: #222222;
209 | color: #fff;
210 | text-decoration: none;
211 | }
212 | }
213 |
214 | .button-disabled {
215 | display: flex;
216 | justify-content: center;
217 | align-items: center;
218 | gap: 8px;
219 | padding: 8px 16px;
220 | border-radius: 100px;
221 | background-color: #787878;
222 | color: #fff;
223 | fill: #fff;
224 | font-size: 16px;
225 | width: 120px;
226 | cursor: default;
227 |
228 | &:visited {
229 | color: #fff;
230 | }
231 | &:active {
232 | color: #fff;
233 | }
234 |
235 | &:hover {
236 | text-decoration: none;
237 | color: #fff;
238 | }
239 | }
240 |
241 | .sys-name {
242 | font-variant: small-caps;
243 | // font-weight: bold;
244 | color: {{site.syscolor}};
245 | }
246 |
247 | strong {
248 | font-weight: 500;
249 | }
250 |
251 | .footer {
252 | width: 100%;
253 | background-color: #f5f5f5;
254 | margin-top: 32px;
255 |
256 | & > div {
257 | max-width: calc(940px - (30px * 2));
258 | margin-left: auto;
259 | margin-right: auto;
260 | padding: 16px 0;
261 | }
262 | }
263 |
264 | .credits {
265 | color: #777;
266 | font-size: 14px;
267 | margin-bottom: 0px;
268 | }
269 |
270 | @media (max-width: 600px) {
271 | h1 {
272 | font-size: 36px;
273 | }
274 |
275 | h4 {
276 | font-size: 20px;
277 | }
278 |
279 | h2 {
280 | font-size: 20px;
281 | }
282 |
283 | h3 {
284 | font-size: 16px;
285 | }
286 |
287 | p {
288 | font-size: 14px;
289 | }
290 |
291 | .author-container p {
292 | font-size: 10px;
293 | }
294 |
295 | pre {
296 | font-size: 10px;
297 | }
298 |
299 | .wrapper {
300 | max-width: calc(100vw - 30px);
301 | }
302 |
303 | .author-container {
304 | display: inline-block;
305 | width: 22%;
306 | margin: 0 1% 0 1%;
307 | }
308 |
309 | .img-left {
310 | display: block;
311 | text-align: center;
312 | width: 70%;
313 | margin-left: 15%;
314 | }
315 |
316 | .text-right {
317 | width: 100%;
318 | margin-left: 0px;
319 | }
320 |
321 | .sys-img {
322 | width: 100%;
323 | margin-left: 0;
324 | }
325 |
326 | .acknowledgement {
327 | font-size: 12px;
328 | }
329 |
330 | .credits {
331 | font-size: 10px;
332 | }
333 |
334 | .footer {
335 | & > div {
336 | max-width: calc(100vw - 30px);
337 | }
338 | }
339 | }
340 |
--------------------------------------------------------------------------------
/docs/assets/favicon/android-chrome-192x192.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/prometheus-eval/prometheus-vision/fe668b5e7e120e5bea8a3821eeb468ddf719c23c/docs/assets/favicon/android-chrome-192x192.png
--------------------------------------------------------------------------------
/docs/assets/favicon/android-chrome-512x512.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/prometheus-eval/prometheus-vision/fe668b5e7e120e5bea8a3821eeb468ddf719c23c/docs/assets/favicon/android-chrome-512x512.png
--------------------------------------------------------------------------------
/docs/assets/favicon/apple-touch-icon.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/prometheus-eval/prometheus-vision/fe668b5e7e120e5bea8a3821eeb468ddf719c23c/docs/assets/favicon/apple-touch-icon.png
--------------------------------------------------------------------------------
/docs/assets/favicon/favicon-16x16.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/prometheus-eval/prometheus-vision/fe668b5e7e120e5bea8a3821eeb468ddf719c23c/docs/assets/favicon/favicon-16x16.png
--------------------------------------------------------------------------------
/docs/assets/favicon/favicon-32x32.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/prometheus-eval/prometheus-vision/fe668b5e7e120e5bea8a3821eeb468ddf719c23c/docs/assets/favicon/favicon-32x32.png
--------------------------------------------------------------------------------
/docs/assets/favicon/favicon.ico:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/prometheus-eval/prometheus-vision/fe668b5e7e120e5bea8a3821eeb468ddf719c23c/docs/assets/favicon/favicon.ico
--------------------------------------------------------------------------------
/docs/assets/favicon/site.webmanifest:
--------------------------------------------------------------------------------
1 | {"name":"","short_name":"","icons":[{"src":"/android-chrome-192x192.png","sizes":"192x192","type":"image/png"},{"src":"/android-chrome-512x512.png","sizes":"512x512","type":"image/png"}],"theme_color":"#ffffff","background_color":"#ffffff","display":"standalone"}
--------------------------------------------------------------------------------
/docs/assets/img/blue_fire.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/prometheus-eval/prometheus-vision/fe668b5e7e120e5bea8a3821eeb468ddf719c23c/docs/assets/img/blue_fire.png
--------------------------------------------------------------------------------
/docs/assets/img/geewook.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/prometheus-eval/prometheus-vision/fe668b5e7e120e5bea8a3821eeb468ddf719c23c/docs/assets/img/geewook.jpg
--------------------------------------------------------------------------------
/docs/assets/img/instructionfollowing_results.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/prometheus-eval/prometheus-vision/fe668b5e7e120e5bea8a3821eeb468ddf719c23c/docs/assets/img/instructionfollowing_results.png
--------------------------------------------------------------------------------
/docs/assets/img/kaist_logo.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/prometheus-eval/prometheus-vision/fe668b5e7e120e5bea8a3821eeb468ddf719c23c/docs/assets/img/kaist_logo.png
--------------------------------------------------------------------------------
/docs/assets/img/lklab_logo.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/prometheus-eval/prometheus-vision/fe668b5e7e120e5bea8a3821eeb468ddf719c23c/docs/assets/img/lklab_logo.jpg
--------------------------------------------------------------------------------
/docs/assets/img/minjoon.jpeg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/prometheus-eval/prometheus-vision/fe668b5e7e120e5bea8a3821eeb468ddf719c23c/docs/assets/img/minjoon.jpeg
--------------------------------------------------------------------------------
/docs/assets/img/naver_ai_lab_logo.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/prometheus-eval/prometheus-vision/fe668b5e7e120e5bea8a3821eeb468ddf719c23c/docs/assets/img/naver_ai_lab_logo.png
--------------------------------------------------------------------------------
/docs/assets/img/naver_cloud_logo.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/prometheus-eval/prometheus-vision/fe668b5e7e120e5bea8a3821eeb468ddf719c23c/docs/assets/img/naver_cloud_logo.png
--------------------------------------------------------------------------------
/docs/assets/img/naver_logo.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/prometheus-eval/prometheus-vision/fe668b5e7e120e5bea8a3821eeb468ddf719c23c/docs/assets/img/naver_logo.png
--------------------------------------------------------------------------------
/docs/assets/img/perception_collection_stats.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/prometheus-eval/prometheus-vision/fe668b5e7e120e5bea8a3821eeb468ddf719c23c/docs/assets/img/perception_collection_stats.png
--------------------------------------------------------------------------------
/docs/assets/img/seongyun.jpeg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/prometheus-eval/prometheus-vision/fe668b5e7e120e5bea8a3821eeb468ddf719c23c/docs/assets/img/seongyun.jpeg
--------------------------------------------------------------------------------
/docs/assets/img/seungone.jpeg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/prometheus-eval/prometheus-vision/fe668b5e7e120e5bea8a3821eeb468ddf719c23c/docs/assets/img/seungone.jpeg
--------------------------------------------------------------------------------
/docs/assets/img/suehyun.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/prometheus-eval/prometheus-vision/fe668b5e7e120e5bea8a3821eeb468ddf719c23c/docs/assets/img/suehyun.png
--------------------------------------------------------------------------------
/docs/index.markdown:
--------------------------------------------------------------------------------
1 | ---
2 | layout: default
3 | ---
4 |
5 | We introduce Prometheus-Vision, an evaluator Vision-Language Model (VLM) that is **open-source**, offers **reproducible** evaluation, and is **inexpensive** to use. We construct Perception-Collection, the first multimodal feedback dataset for training an evaluator VLM, which includes 15K **fine-grained scoring criteria** defined for each instance. Prometheus-Vision trained on Perception-Collection shows high correlation with human evaluators and GPT-4V, paving the way for accessible and transparent evaluation of VLMs.
6 |
7 | ------
8 |
9 |
10 |
11 | ## VLM-as-a-Judge for Fine-Grained Evaluation
12 |
13 | Recent VLMs exhibit impressive visual instruction-following capabilities. To assess and compare the quality of VLM-generated outputs, we utilize VLMs as evaluators of VLMs, naming this approach 'VLM-as-a-Judge'.
14 |
15 | {: .sys-img}
16 | 
17 |
18 | Traditional metrics for VLM evaluation measure the similarity between a response and the ground-truth answer. However, such automatic metrics fail to capture the rich context within the output. Also, these metrics do not explain what is missing or present in the response.
19 |
20 | An evaluator VLM can adhere to specific criteria of interest to focus on nuanced details in the visual context and instruction. Moreover, it can provide detailed language feedback that helps the user understand the reasoning behind the scoring.
21 |
22 |
23 |
24 | ## Multimodal Feedback Data
25 |
26 | The [Perception-Collection](https://huggingface.co/datasets/kaist-ai/Perception-Collection) dataset targets fine-grained multimodal feedback generation. Each instance consists of 5 input components: an instruction, a real-world image, a response to evaluate, a customized score rubric, and a reference answer. Based on these, an evaluator VLM is trained to generate language feedback and a score decision on a scale of 1 to 5.
27 |
28 | {: .sys-img}
29 | 
30 |
31 | We collect 5K real-world images sampled from the [COCO dataset](https://cocodataset.org/#home) and the [MMMU benchmark](https://arxiv.org/abs/2311.16502). Then, we augment the data in a 4-stage process: (1) hand-craft 50 seed score rubrics, (2) brainstorm and refine 15K fine-grained score rubrics, (3) augment 30K instructions and reference answers related to the score rubrics, and (4) augment 150K responses and language feedback for training. From stages 2 to 4, we prompt GPT-4V to generate the data. We ensure that each generated score rubric aligns with the image and that there is no length bias in responses across the score range.
32 |
33 | {: .sys-img}
34 | 
35 |
36 | We also release a held-out test set of the Perception-Collection called [Perception-Bench](https://huggingface.co/datasets/kaist-ai/Perception-Bench), which contains 500 instances and a single score rubric for each instance.
37 |
38 |
39 |
40 | ## Performance of Prometheus-Vision
41 |
42 | Using the Perception-Collection, we train Prometheus-Vision ([7B](https://huggingface.co/kaist-ai/prometheus-vision-7b-v1.0) & [13B](https://huggingface.co/kaist-ai/prometheus-vision-13b-v1.0)) with [LLaVA-1.5](https://arxiv.org/abs/2310.03744) (7B & 13B) as the backbone model. Through experiments, we demonstrate that Prometheus-Vision can be an effective open-source alternative to human or GPT-4V evaluation of VLMs.
43 |
44 |
45 | ### Simulating Human Evaluators
46 |
47 | Prometheus-Vision shows high correlation with human evaluators on benchmarks with real-world images, LLaVA-Bench and Perception-Bench. Also, Prometheus-Vision 13B's feedback is as good as or better than GPT-4V's feedback 57.78% of the time.
48 |
49 | {: .img-left}
50 | 
51 |
52 | {: .img-right}
53 | 
54 |
55 |
56 | ### Simulating GPT-4V
57 |
58 | Prometheus-Vision shows the highest correlation with GPT-4V among open-source VLMs and outperforms GPT-3.5-Turbo and GPT-4 ('LM-as-a-Judge') on LLaVA-Bench and Perception-Bench.
59 |
60 | {: .sys-img}
61 | 
62 |
63 | ------
64 |
65 | ## Bibtex
66 | If you find our work useful, please consider citing our paper:
67 |
68 |
69 | @misc{lee2024prometheusvision,
70 | title={Prometheus-Vision: Vision-Language Model as a Judge for Fine-Grained Evaluation},
71 | author={Seongyun Lee and Seungone Kim and Sue Hyun Park and Geewook Kim and Minjoon Seo},
72 | year={2024},
73 | eprint={2401.06591},
74 | archivePrefix={arXiv},
75 | primaryClass={cs.CL}
76 | }
77 |
78 |
79 | ------
80 |
81 | {: .logos}
82 | [](https://kaist.ac.kr)
83 | [](https://lklab.kaist.ac.kr/)
84 | [](https://www.facebook.com/NAVERAILAB)
85 | [](https://www.navercloudcorp.com/lang/en/)
86 |
87 |
89 |
--------------------------------------------------------------------------------
/llava/__init__.py:
--------------------------------------------------------------------------------
1 | from .model import LlavaLlamaForCausalLM
2 |
--------------------------------------------------------------------------------
/llava/constants.py:
--------------------------------------------------------------------------------
1 | CONTROLLER_HEART_BEAT_EXPIRATION = 30
2 | WORKER_HEART_BEAT_INTERVAL = 15
3 |
4 | LOGDIR = "."
5 |
6 | # Model Constants
7 | IGNORE_INDEX = -100
8 | IMAGE_TOKEN_INDEX = -200
9 | DEFAULT_IMAGE_TOKEN = "<image>"
10 | DEFAULT_IMAGE_PATCH_TOKEN = "<im_patch>"
11 | DEFAULT_IM_START_TOKEN = "<im_start>"
12 | DEFAULT_IM_END_TOKEN = "<im_end>"
13 | IMAGE_PLACEHOLDER = "<image-placeholder>"
14 |
--------------------------------------------------------------------------------
/llava/eval/eval_gpt_review.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import json
3 | import os
4 |
5 | import openai
6 | import tqdm
7 | import ray
8 | import time
9 |
10 | NUM_SECONDS_TO_SLEEP = 3
11 |
12 | @ray.remote(num_cpus=4)
13 | def get_eval(content: str, max_tokens: int):
14 | while True:
15 | try:
16 | response = openai.ChatCompletion.create(
17 | model='gpt-4',
18 | messages=[{
19 | 'role': 'system',
20 | 'content': 'You are a helpful and precise assistant for checking the quality of the answer.'
21 | }, {
22 | 'role': 'user',
23 | 'content': content,
24 | }],
25 | temperature=0.2, # TODO: figure out which temperature is best for evaluation
26 | max_tokens=max_tokens,
27 | )
28 | break
29 | except openai.error.RateLimitError:
30 | pass
31 | except Exception as e:
32 | print(e)
33 | time.sleep(NUM_SECONDS_TO_SLEEP)
34 |
35 | print('success!')
36 | return response['choices'][0]['message']['content']
37 |
38 |
39 | def parse_score(review):
40 | try:
41 | score_pair = review.split('\n')[0]
42 | score_pair = score_pair.replace(',', ' ')
43 | sp = score_pair.split(' ')
44 | if len(sp) == 2:
45 | return [float(sp[0]), float(sp[1])]
46 | else:
47 | print('error', review)
48 | return [-1, -1]
49 | except Exception as e:
50 | print(e)
51 | print('error', review)
52 | return [-1, -1]
53 |
54 |
55 | if __name__ == '__main__':
56 | parser = argparse.ArgumentParser(description='ChatGPT-based QA evaluation.')
57 | parser.add_argument('-q', '--question')
58 | # parser.add_argument('-a', '--answer')
59 | parser.add_argument('-a', '--answer-list', nargs='+', default=[])
60 | parser.add_argument('-r', '--rule')
61 | parser.add_argument('-o', '--output')
62 | parser.add_argument('--max-tokens', type=int, default=1024, help='maximum number of tokens produced in the output')
63 | args = parser.parse_args()
64 |
65 | ray.init()
66 |
67 | f_q = open(os.path.expanduser(args.question))
68 | f_ans1 = open(os.path.expanduser(args.answer_list[0]))
69 | f_ans2 = open(os.path.expanduser(args.answer_list[1]))
70 | rule_dict = json.load(open(os.path.expanduser(args.rule), 'r'))
71 |
72 | review_file = open(f'{args.output}', 'w')
73 |
74 | js_list = []
75 | handles = []
76 | idx = 0
77 | for ques_js, ans1_js, ans2_js in zip(f_q, f_ans1, f_ans2):
78 | # if idx == 1:
79 | # break
80 |
81 | ques = json.loads(ques_js)
82 | ans1 = json.loads(ans1_js)
83 | ans2 = json.loads(ans2_js)
84 |
85 | category = json.loads(ques_js)['category']
86 | if category in rule_dict:
87 | rule = rule_dict[category]
88 | else:
89 | rule = rule_dict['default']
90 | prompt = rule['prompt']
91 | role = rule['role']
92 | content = (f'[Question]\n{ques["text"]}\n\n'
93 | f'[{role} 1]\n{ans1["text"]}\n\n[End of {role} 1]\n\n'
94 | f'[{role} 2]\n{ans2["text"]}\n\n[End of {role} 2]\n\n'
95 | f'[System]\n{prompt}\n\n')
96 | js_list.append({
97 | 'id': idx+1,
98 | 'question_id': ques['question_id'],
99 | 'answer1_id': ans1['answer_id'],
100 | 'answer2_id': ans2['answer_id'],
101 | 'category': category})
102 | idx += 1
103 | handles.append(get_eval.remote(content, args.max_tokens))
104 | # To avoid the rate limit set by OpenAI
105 | time.sleep(NUM_SECONDS_TO_SLEEP)
106 |
107 | reviews = ray.get(handles)
108 | for idx, review in enumerate(reviews):
109 | scores = parse_score(review)
110 | js_list[idx]['content'] = review
111 | js_list[idx]['tuple'] = scores
112 | review_file.write(json.dumps(js_list[idx]) + '\n')
113 | review_file.close()
114 |
--------------------------------------------------------------------------------
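Illustrative sketch (not part of the repository) of the score-parsing convention used by this reviewer script and the two variants that follow: GPT-4 is expected to put the two numeric scores on the first line of its review, and parse_score reads exactly that line. The sample review text below is invented for illustration.

sample_review = "8 6\nAssistant 1 provided a more detailed and accurate answer..."
first_line = sample_review.split('\n')[0].replace(',', ' ')
parts = first_line.split(' ')
print([float(parts[0]), float(parts[1])])  # -> [8.0, 6.0]

Note that a reply whose first line is "8, 6" becomes "8  6" after the comma replacement and splits into three tokens, so parse_score as written falls into its error branch and records [-1, -1] for that review.
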
/llava/eval/eval_gpt_review_bench.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import json
3 | import os
4 |
5 | import openai
6 | import time
7 |
8 | NUM_SECONDS_TO_SLEEP = 0.5
9 |
10 |
11 | def get_eval(content: str, max_tokens: int):
12 | while True:
13 | try:
14 | response = openai.ChatCompletion.create(
15 | model='gpt-4-0314',
16 | messages=[{
17 | 'role': 'system',
18 | 'content': 'You are a helpful and precise assistant for checking the quality of the answer.'
19 | }, {
20 | 'role': 'user',
21 | 'content': content,
22 | }],
23 | temperature=0.2, # TODO: figure out which temperature is best for evaluation
24 | max_tokens=max_tokens,
25 | )
26 | break
27 | except openai.error.RateLimitError:
28 | pass
29 | except Exception as e:
30 | print(e)
31 | time.sleep(NUM_SECONDS_TO_SLEEP)
32 |
33 | return response['choices'][0]['message']['content']
34 |
35 |
36 | def parse_score(review):
37 | try:
38 | score_pair = review.split('\n')[0]
39 | score_pair = score_pair.replace(',', ' ')
40 | sp = score_pair.split(' ')
41 | if len(sp) == 2:
42 | return [float(sp[0]), float(sp[1])]
43 | else:
44 | print('error', review)
45 | return [-1, -1]
46 | except Exception as e:
47 | print(e)
48 | print('error', review)
49 | return [-1, -1]
50 |
51 |
52 | if __name__ == '__main__':
53 | parser = argparse.ArgumentParser(description='ChatGPT-based QA evaluation.')
54 | parser.add_argument('-q', '--question')
55 | parser.add_argument('-c', '--context')
56 | parser.add_argument('-a', '--answer-list', nargs='+', default=[])
57 | parser.add_argument('-r', '--rule')
58 | parser.add_argument('-o', '--output')
59 | parser.add_argument('--max-tokens', type=int, default=1024, help='maximum number of tokens produced in the output')
60 | args = parser.parse_args()
61 |
62 | f_q = open(os.path.expanduser(args.question))
63 | f_ans1 = open(os.path.expanduser(args.answer_list[0]))
64 | f_ans2 = open(os.path.expanduser(args.answer_list[1]))
65 | rule_dict = json.load(open(os.path.expanduser(args.rule), 'r'))
66 |
67 | if os.path.isfile(os.path.expanduser(args.output)):
68 | cur_reviews = [json.loads(line) for line in open(os.path.expanduser(args.output))]
69 | else:
70 | cur_reviews = []
71 |
72 | review_file = open(f'{args.output}', 'a')
73 |
74 | context_list = [json.loads(line) for line in open(os.path.expanduser(args.context))]
75 | image_to_context = {context['image']: context for context in context_list}
76 |
77 | handles = []
78 | idx = 0
79 | for ques_js, ans1_js, ans2_js in zip(f_q, f_ans1, f_ans2):
80 | ques = json.loads(ques_js)
81 | ans1 = json.loads(ans1_js)
82 | ans2 = json.loads(ans2_js)
83 |
84 | inst = image_to_context[ques['image']]
85 |
86 | if isinstance(inst['caption'], list):
87 | cap_str = '\n'.join(inst['caption'])
88 | else:
89 | cap_str = inst['caption']
90 |
91 | category = 'llava_bench_' + ques['category']
92 | if category in rule_dict:
93 | rule = rule_dict[category]
94 | else:
95 | assert False, f"Visual QA category not found in rule file: {category}."
96 | prompt = rule['prompt']
97 | role = rule['role']
98 | content = (f'[Context]\n{cap_str}\n\n'
99 | f'[Question]\n{ques["text"]}\n\n'
100 | f'[{role} 1]\n{ans1["text"]}\n\n[End of {role} 1]\n\n'
101 | f'[{role} 2]\n{ans2["text"]}\n\n[End of {role} 2]\n\n'
102 | f'[System]\n{prompt}\n\n')
103 | cur_js = {
104 | 'id': idx+1,
105 | 'question_id': ques['question_id'],
106 | 'answer1_id': ans1.get('answer_id', ans1['question_id']),
107 | 'answer2_id': ans2.get('answer_id', ans2['question_id']),
108 | 'category': category
109 | }
110 | if idx >= len(cur_reviews):
111 | review = get_eval(content, args.max_tokens)
112 | scores = parse_score(review)
113 | cur_js['content'] = review
114 | cur_js['tuple'] = scores
115 | review_file.write(json.dumps(cur_js) + '\n')
116 | review_file.flush()
117 | else:
118 | print(f'Skipping {idx} as we already have it.')
119 | idx += 1
120 | print(idx)
121 | review_file.close()
122 |
--------------------------------------------------------------------------------
/llava/eval/eval_gpt_review_visual.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import json
3 | import os
4 |
5 | import openai
6 | import time
7 |
8 | NUM_SECONDS_TO_SLEEP = 0.5
9 |
10 |
11 | def get_eval(content: str, max_tokens: int):
12 | while True:
13 | try:
14 | response = openai.ChatCompletion.create(
15 | model='gpt-4-0314',
16 | messages=[{
17 | 'role': 'system',
18 | 'content': 'You are a helpful and precise assistant for checking the quality of the answer.'
19 | }, {
20 | 'role': 'user',
21 | 'content': content,
22 | }],
23 | temperature=0.2, # TODO: figure out which temperature is best for evaluation
24 | max_tokens=max_tokens,
25 | )
26 | break
27 | except openai.error.RateLimitError:
28 | pass
29 | except Exception as e:
30 | print(e)
31 | time.sleep(NUM_SECONDS_TO_SLEEP)
32 |
33 | return response['choices'][0]['message']['content']
34 |
35 |
36 | def parse_score(review):
37 | try:
38 | score_pair = review.split('\n')[0]
39 | score_pair = score_pair.replace(',', ' ')
40 | sp = score_pair.split(' ')
41 | if len(sp) == 2:
42 | return [float(sp[0]), float(sp[1])]
43 | else:
44 | print('error', review)
45 | return [-1, -1]
46 | except Exception as e:
47 | print(e)
48 | print('error', review)
49 | return [-1, -1]
50 |
51 |
52 | if __name__ == '__main__':
53 | parser = argparse.ArgumentParser(description='ChatGPT-based QA evaluation.')
54 | parser.add_argument('-q', '--question')
55 | parser.add_argument('-c', '--context')
56 | parser.add_argument('-a', '--answer-list', nargs='+', default=[])
57 | parser.add_argument('-r', '--rule')
58 | parser.add_argument('-o', '--output')
59 | parser.add_argument('--max-tokens', type=int, default=1024, help='maximum number of tokens produced in the output')
60 | args = parser.parse_args()
61 |
62 | f_q = open(os.path.expanduser(args.question))
63 | f_ans1 = open(os.path.expanduser(args.answer_list[0]))
64 | f_ans2 = open(os.path.expanduser(args.answer_list[1]))
65 | rule_dict = json.load(open(os.path.expanduser(args.rule), 'r'))
66 |
67 | if os.path.isfile(os.path.expanduser(args.output)):
68 | cur_reviews = [json.loads(line) for line in open(os.path.expanduser(args.output))]
69 | else:
70 | cur_reviews = []
71 |
72 | review_file = open(f'{args.output}', 'a')
73 |
74 | context_list = [json.loads(line) for line in open(os.path.expanduser(args.context))]
75 | image_to_context = {context['image']: context for context in context_list}
76 |
77 | handles = []
78 | idx = 0
79 | for ques_js, ans1_js, ans2_js in zip(f_q, f_ans1, f_ans2):
80 | ques = json.loads(ques_js)
81 | ans1 = json.loads(ans1_js)
82 | ans2 = json.loads(ans2_js)
83 |
84 | inst = image_to_context[ques['image']]
85 | cap_str = '\n'.join(inst['captions'])
86 | box_str = '\n'.join([f'{instance["category"]}: {instance["bbox"]}' for instance in inst['instances']])
87 |
88 | category = ques['category']
89 | if category in rule_dict:
90 | rule = rule_dict[category]
91 | else:
92 | assert False, f"Visual QA category not found in rule file: {category}."
93 | prompt = rule['prompt']
94 | role = rule['role']
95 | content = (f'[Context]\n{cap_str}\n\n{box_str}\n\n'
96 | f'[Question]\n{ques["text"]}\n\n'
97 | f'[{role} 1]\n{ans1["text"]}\n\n[End of {role} 1]\n\n'
98 | f'[{role} 2]\n{ans2["text"]}\n\n[End of {role} 2]\n\n'
99 | f'[System]\n{prompt}\n\n')
100 | cur_js = {
101 | 'id': idx+1,
102 | 'question_id': ques['question_id'],
103 | 'answer1_id': ans1.get('answer_id', ans1['question_id']),
104 | 'answer2_id': ans2.get('answer_id', ans2['question_id']),
105 | 'category': category
106 | }
107 | if idx >= len(cur_reviews):
108 | review = get_eval(content, args.max_tokens)
109 | scores = parse_score(review)
110 | cur_js['content'] = review
111 | cur_js['tuple'] = scores
112 | review_file.write(json.dumps(cur_js) + '\n')
113 | review_file.flush()
114 | else:
115 | print(f'Skipping {idx} as we already have it.')
116 | idx += 1
117 | print(idx)
118 | review_file.close()
119 |
--------------------------------------------------------------------------------
/llava/eval/eval_pope.py:
--------------------------------------------------------------------------------
1 | import os
2 | import json
3 | import argparse
4 |
5 | def eval_pope(answers, label_file):
6 | label_list = [json.loads(q)['label'] for q in open(label_file, 'r')]
7 |
8 | for answer in answers:
9 | text = answer['text']
10 |
11 | # Only keep the first sentence
12 | if text.find('.') != -1:
13 | text = text.split('.')[0]
14 |
15 | text = text.replace(',', '')
16 | words = text.split(' ')
17 | if 'No' in words or 'not' in words or 'no' in words:
18 | answer['text'] = 'no'
19 | else:
20 | answer['text'] = 'yes'
21 |
22 | for i in range(len(label_list)):
23 | if label_list[i] == 'no':
24 | label_list[i] = 0
25 | else:
26 | label_list[i] = 1
27 |
28 | pred_list = []
29 | for answer in answers:
30 | if answer['text'] == 'no':
31 | pred_list.append(0)
32 | else:
33 | pred_list.append(1)
34 |
35 | pos = 1
36 | neg = 0
37 | yes_ratio = pred_list.count(1) / len(pred_list)
38 |
39 | TP, TN, FP, FN = 0, 0, 0, 0
40 | for pred, label in zip(pred_list, label_list):
41 | if pred == pos and label == pos:
42 | TP += 1
43 | elif pred == pos and label == neg:
44 | FP += 1
45 | elif pred == neg and label == neg:
46 | TN += 1
47 | elif pred == neg and label == pos:
48 | FN += 1
49 |
50 | print('TP\tFP\tTN\tFN\t')
51 | print('{}\t{}\t{}\t{}'.format(TP, FP, TN, FN))
52 |
53 | precision = float(TP) / float(TP + FP)
54 | recall = float(TP) / float(TP + FN)
55 | f1 = 2*precision*recall / (precision + recall)
56 | acc = (TP + TN) / (TP + TN + FP + FN)
57 | print('Accuracy: {}'.format(acc))
58 | print('Precision: {}'.format(precision))
59 | print('Recall: {}'.format(recall))
60 | print('F1 score: {}'.format(f1))
61 | print('Yes ratio: {}'.format(yes_ratio))
62 | print('%.3f, %.3f, %.3f, %.3f, %.3f' % (f1, acc, precision, recall, yes_ratio) )
63 |
64 | if __name__ == "__main__":
65 | parser = argparse.ArgumentParser()
66 | parser.add_argument("--annotation-dir", type=str)
67 | parser.add_argument("--question-file", type=str)
68 | parser.add_argument("--result-file", type=str)
69 | args = parser.parse_args()
70 |
71 | questions = [json.loads(line) for line in open(args.question_file)]
72 | questions = {question['question_id']: question for question in questions}
73 | answers = [json.loads(q) for q in open(args.result_file)]
74 | for file in os.listdir(args.annotation_dir):
75 | assert file.startswith('coco_pope_')
76 | assert file.endswith('.json')
77 | category = file[10:-5]
78 | cur_answers = [x for x in answers if questions[x['question_id']]['category'] == category]
79 | print('Category: {}, # samples: {}'.format(category, len(cur_answers)))
80 | eval_pope(cur_answers, os.path.join(args.annotation_dir, file))
81 | print("====================================")
82 |
--------------------------------------------------------------------------------
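As a toy illustration of the normalization in eval_pope above (the answers and labels are made up, not POPE data): free-form answers are collapsed to a binary yes/no by scanning the first sentence for 'No'/'no'/'not', and the confusion counts then follow with 'yes' as the positive class.

answers = [
    {'text': 'No, there is no dog in the image.'},    # -> 'no'
    {'text': 'Yes, a red car is visible.'},           # -> 'yes'
    {'text': 'The picture does not contain a cat.'},  # -> 'no' ('not' is matched)
]
labels = ['no', 'yes', 'yes']

preds = []
for a in answers:
    first = a['text'].split('.')[0].replace(',', '')
    words = first.split(' ')
    preds.append('no' if ('No' in words or 'not' in words or 'no' in words) else 'yes')

print(preds)  # ['no', 'yes', 'no'] -> TP=1, TN=1, FN=1, FP=0
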
/llava/eval/eval_science_qa.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import json
3 | import os
4 | import re
5 | import random
6 |
7 |
8 | def get_args():
9 | parser = argparse.ArgumentParser()
10 | parser.add_argument('--base-dir', type=str)
11 | parser.add_argument('--result-file', type=str)
12 | parser.add_argument('--output-file', type=str)
13 | parser.add_argument('--output-result', type=str)
14 | parser.add_argument('--split', type=str, default='test')
15 | parser.add_argument('--options', nargs='+', default=["A", "B", "C", "D", "E"])
16 | return parser.parse_args()
17 |
18 |
19 | def convert_caps(results):
20 | fakecaps = []
21 | for result in results:
22 | image_id = result['question_id']
23 | caption = result['text']
24 | fakecaps.append({"image_id": int(image_id), "caption": caption})
25 | return fakecaps
26 |
27 |
28 | def get_pred_idx(prediction, choices, options):
29 | """
30 | Get the index (e.g. 2) from the prediction (e.g. 'C')
31 | """
32 | if prediction in options[:len(choices)]:
33 | return options.index(prediction)
34 | else:
35 | return -1
36 | # failed predictions are scored as incorrect (-1) rather than falling back to a random choice
37 |
38 |
39 | if __name__ == "__main__":
40 | args = get_args()
41 |
42 | base_dir = args.base_dir
43 | split_indices = json.load(open(os.path.join(base_dir, "pid_splits.json")))[args.split]
44 | problems = json.load(open(os.path.join(base_dir, "problems.json")))
45 | predictions = [json.loads(line) for line in open(args.result_file)]
46 | predictions = {pred['question_id']: pred for pred in predictions}
47 | split_problems = {idx: problems[idx] for idx in split_indices}
48 |
49 | results = {'correct': [], 'incorrect': []}
50 | sqa_results = {}
51 | sqa_results['acc'] = None
52 | sqa_results['correct'] = None
53 | sqa_results['count'] = None
54 | sqa_results['results'] = {}
55 | sqa_results['outputs'] = {}
56 |
57 | for prob_id, prob in split_problems.items():
58 | if prob_id not in predictions:
59 | pred = {'text': 'FAILED', 'prompt': 'Unknown'}
60 | pred_text = 'FAILED'
61 | else:
62 | pred = predictions[prob_id]
63 | pred_text = pred['text']
64 |
65 | if pred_text in args.options:
66 | answer = pred_text
67 | elif len(pred_text) >= 3 and pred_text[0] in args.options and pred_text[1:3] == ". ":
68 | answer = pred_text[0]
69 | else:
70 | pattern = re.compile(r'The answer is ([A-Z]).')
71 | res = pattern.findall(pred_text)
72 | if len(res) == 1:
73 | answer = res[0] # 'A', 'B', ...
74 | else:
75 | answer = "FAILED"
76 |
77 | pred_idx = get_pred_idx(answer, prob['choices'], args.options)
78 |
79 | analysis = {
80 | 'question_id': prob_id,
81 | 'parsed_ans': answer,
82 | 'ground_truth': args.options[prob['answer']],
83 | 'question': pred['prompt'],
84 | 'pred': pred_text,
85 | 'is_multimodal': '<image>' in pred['prompt'],
86 | }
87 |
88 | sqa_results['results'][prob_id] = get_pred_idx(answer, prob['choices'], args.options)
89 | sqa_results['outputs'][prob_id] = pred_text
90 |
91 | if pred_idx == prob['answer']:
92 | results['correct'].append(analysis)
93 | else:
94 | results['incorrect'].append(analysis)
95 |
96 | correct = len(results['correct'])
97 | total = len(results['correct']) + len(results['incorrect'])
98 |
99 | ###### IMG ######
100 | multimodal_correct = len([x for x in results['correct'] if x['is_multimodal']])
101 | multimodal_incorrect = len([x for x in results['incorrect'] if x['is_multimodal']])
102 | multimodal_total = multimodal_correct + multimodal_incorrect
103 | ###### IMG ######
104 |
105 | print(f'Total: {total}, Correct: {correct}, Accuracy: {correct / total * 100:.2f}%, IMG-Accuracy: {multimodal_correct / multimodal_total * 100:.2f}%')
106 |
107 | sqa_results['acc'] = correct / total * 100
108 | sqa_results['correct'] = correct
109 | sqa_results['count'] = total
110 |
111 | with open(args.output_file, 'w') as f:
112 | json.dump(results, f, indent=2)
113 | with open(args.output_result, 'w') as f:
114 | json.dump(sqa_results, f, indent=2)
115 |
--------------------------------------------------------------------------------
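For reference, a small sketch of the answer-letter extraction performed above; the prediction strings are invented. A bare option letter, a "<letter>. " prefix, and the phrase "The answer is X." are all accepted, and anything else is counted as FAILED (which get_pred_idx then maps to -1, i.e. incorrect).

import re

options = ["A", "B", "C", "D", "E"]
pattern = re.compile(r'The answer is ([A-Z]).')

for pred_text in ["B", "C. The rock is igneous.", "The answer is D.", "I am not sure."]:
    if pred_text in options:
        answer = pred_text
    elif len(pred_text) >= 3 and pred_text[0] in options and pred_text[1:3] == ". ":
        answer = pred_text[0]
    else:
        res = pattern.findall(pred_text)
        answer = res[0] if len(res) == 1 else "FAILED"
    print(pred_text, '->', answer)
# B -> B, C. The rock is igneous. -> C, The answer is D. -> D, I am not sure. -> FAILED
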
/llava/eval/eval_science_qa_gpt4.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import json
3 | import os
4 | import re
5 | import random
6 | from collections import defaultdict
7 |
8 |
9 | def get_args():
10 | parser = argparse.ArgumentParser()
11 | parser.add_argument('--base-dir', type=str)
12 | parser.add_argument('--gpt4-result', type=str)
13 | parser.add_argument('--our-result', type=str)
14 | parser.add_argument('--split', type=str, default='test')
15 | parser.add_argument('--options', nargs='+', default=["A", "B", "C", "D", "E"])
16 | return parser.parse_args()
17 |
18 |
19 | def convert_caps(results):
20 | fakecaps = []
21 | for result in results:
22 | image_id = result['question_id']
23 | caption = result['text']
24 | fakecaps.append({"image_id": int(image_id), "caption": caption})
25 | return fakecaps
26 |
27 |
28 | def get_pred_idx(prediction, choices, options):
29 | """
30 | Get the index (e.g. 2) from the prediction (e.g. 'C')
31 | """
32 | if prediction in options[:len(choices)]:
33 | return options.index(prediction)
34 | else:
35 | return random.choice(range(len(choices)))
36 |
37 |
38 | if __name__ == "__main__":
39 | args = get_args()
40 |
41 | base_dir = args.base_dir
42 | split_indices = json.load(open(os.path.join(base_dir, "pid_splits.json")))[args.split]
43 | problems = json.load(open(os.path.join(base_dir, "problems.json")))
44 | our_predictions = [json.loads(line) for line in open(args.our_result)]
45 | our_predictions = {pred['question_id']: pred for pred in our_predictions}
46 | split_problems = {idx: problems[idx] for idx in split_indices}
47 |
48 | gpt4_predictions = json.load(open(args.gpt4_result))['outputs']
49 |
50 | results = defaultdict(lambda: 0)
51 |
52 | for prob_id, prob in split_problems.items():
53 | if prob_id not in our_predictions:
54 | continue
55 | if prob_id not in gpt4_predictions:
56 | continue
57 | our_pred = our_predictions[prob_id]['text']
58 | gpt4_pred = gpt4_predictions[prob_id]
59 |
60 | pattern = re.compile(r'The answer is ([A-Z]).')
61 | our_res = pattern.findall(our_pred)
62 | if len(our_res) == 1:
63 | our_answer = our_res[0] # 'A', 'B', ...
64 | else:
65 | our_answer = "FAILED"
66 | gpt4_res = pattern.findall(gpt4_pred)
67 | if len(gpt4_res) == 1:
68 | gpt4_answer = gpt4_res[0] # 'A', 'B', ...
69 | else:
70 | gpt4_answer = "FAILED"
71 |
72 | our_pred_idx = get_pred_idx(our_answer, prob['choices'], args.options)
73 | gpt4_pred_idx = get_pred_idx(gpt4_answer, prob['choices'], args.options)
74 |
75 | if gpt4_answer == 'FAILED':
76 | results['gpt4_failed'] += 1
77 | # continue
78 | gpt4_pred_idx = our_pred_idx
79 | # if our_pred_idx != prob['answer']:
80 | # print(our_predictions[prob_id]['prompt'])
81 | # print('-----------------')
82 | # print(f'LECTURE: {prob["lecture"]}')
83 | # print(f'SOLUTION: {prob["solution"]}')
84 | # print('=====================')
85 | else:
86 | # continue
87 | pass
88 | # gpt4_pred_idx = our_pred_idx
89 |
90 | if gpt4_pred_idx == prob['answer']:
91 | results['correct'] += 1
92 | else:
93 | results['incorrect'] += 1
94 |
95 |
96 | if gpt4_pred_idx == prob['answer'] or our_pred_idx == prob['answer']:
97 | results['correct_upperbound'] += 1
98 |
99 | correct = results['correct']
100 | total = results['correct'] + results['incorrect']
101 | print(f'Total: {total}, Correct: {correct}, Accuracy: {correct / total * 100:.2f}%')
102 | print(f'Total: {total}, Correct (upper): {results["correct_upperbound"]}, Accuracy: {results["correct_upperbound"] / total * 100:.2f}%')
103 | print(f'Total: {total}, GPT-4 NO-ANS (RANDOM): {results["gpt4_failed"]}, Percentage: {results["gpt4_failed"] / total * 100:.2f}%')
104 |
105 |
--------------------------------------------------------------------------------
/llava/eval/eval_science_qa_gpt4_requery.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import json
3 | import os
4 | import re
5 | import random
6 | from collections import defaultdict
7 |
8 |
9 | def get_args():
10 | parser = argparse.ArgumentParser()
11 | parser.add_argument('--base-dir', type=str)
12 | parser.add_argument('--gpt4-result', type=str)
13 | parser.add_argument('--requery-result', type=str)
14 | parser.add_argument('--our-result', type=str)
15 | parser.add_argument('--output-result', type=str)
16 | parser.add_argument('--split', type=str, default='test')
17 | parser.add_argument('--options', nargs='+', default=["A", "B", "C", "D", "E"])
18 | return parser.parse_args()
19 |
20 |
21 | def convert_caps(results):
22 | fakecaps = []
23 | for result in results:
24 | image_id = result['question_id']
25 | caption = result['text']
26 | fakecaps.append({"image_id": int(image_id), "caption": caption})
27 | return fakecaps
28 |
29 |
30 | def get_pred_idx(prediction, choices, options):
31 | """
32 | Get the index (e.g. 2) from the prediction (e.g. 'C')
33 | """
34 | if prediction in options[:len(choices)]:
35 | return options.index(prediction)
36 | else:
37 | return random.choice(range(len(choices)))
38 |
39 |
40 | if __name__ == "__main__":
41 | args = get_args()
42 |
43 | base_dir = args.base_dir
44 | split_indices = json.load(open(os.path.join(base_dir, "pid_splits.json")))[args.split]
45 | problems = json.load(open(os.path.join(base_dir, "problems.json")))
46 | our_predictions = [json.loads(line) for line in open(args.our_result)]
47 | our_predictions = {pred['question_id']: pred for pred in our_predictions}
48 | split_problems = {idx: problems[idx] for idx in split_indices}
49 |
50 | requery_predictions = [json.loads(line) for line in open(args.requery_result)]
51 | requery_predictions = {pred['question_id']: pred for pred in requery_predictions}
52 |
53 | gpt4_predictions = json.load(open(args.gpt4_result))['outputs']
54 |
55 | results = defaultdict(lambda: 0)
56 |
57 | sqa_results = {}
58 | sqa_results['acc'] = None
59 | sqa_results['correct'] = None
60 | sqa_results['count'] = None
61 | sqa_results['results'] = {}
62 | sqa_results['outputs'] = {}
63 |
64 | for prob_id, prob in split_problems.items():
65 | if prob_id not in our_predictions:
66 | assert False
67 | if prob_id not in gpt4_predictions:
68 | assert False
69 | our_pred = our_predictions[prob_id]['text']
70 | gpt4_pred = gpt4_predictions[prob_id]
71 | if prob_id not in requery_predictions:
72 | results['missing_requery'] += 1
73 | requery_pred = "MISSING"
74 | else:
75 | requery_pred = requery_predictions[prob_id]['text']
76 |
77 | pattern = re.compile(r'The answer is ([A-Z]).')
78 | our_res = pattern.findall(our_pred)
79 | if len(our_res) == 1:
80 | our_answer = our_res[0] # 'A', 'B', ...
81 | else:
82 | our_answer = "FAILED"
83 |
84 | requery_res = pattern.findall(requery_pred)
85 | if len(requery_res) == 1:
86 | requery_answer = requery_res[0] # 'A', 'B', ...
87 | else:
88 | requery_answer = "FAILED"
89 |
90 | gpt4_res = pattern.findall(gpt4_pred)
91 | if len(gpt4_res) == 1:
92 | gpt4_answer = gpt4_res[0] # 'A', 'B', ...
93 | else:
94 | gpt4_answer = "FAILED"
95 |
96 | our_pred_idx = get_pred_idx(our_answer, prob['choices'], args.options)
97 | gpt4_pred_idx = get_pred_idx(gpt4_answer, prob['choices'], args.options)
98 | requery_pred_idx = get_pred_idx(requery_answer, prob['choices'], args.options)
99 |
100 | results['total'] += 1
101 |
102 | if gpt4_answer == 'FAILED':
103 | results['gpt4_failed'] += 1
104 | if gpt4_pred_idx == prob['answer']:
105 | results['gpt4_correct'] += 1
106 | if our_pred_idx == prob['answer']:
107 | results['gpt4_ourvisual_correct'] += 1
108 | elif gpt4_pred_idx == prob['answer']:
109 | results['gpt4_correct'] += 1
110 | results['gpt4_ourvisual_correct'] += 1
111 |
112 | if our_pred_idx == prob['answer']:
113 | results['our_correct'] += 1
114 |
115 | if requery_answer == 'FAILED':
116 | sqa_results['results'][prob_id] = our_pred_idx
117 | if our_pred_idx == prob['answer']:
118 | results['requery_correct'] += 1
119 | else:
120 | sqa_results['results'][prob_id] = requery_pred_idx
121 | if requery_pred_idx == prob['answer']:
122 | results['requery_correct'] += 1
123 | else:
124 | print(f"""
125 | Question ({args.options[prob['answer']]}): {our_predictions[prob_id]['prompt']}
126 | Our ({our_answer}): {our_pred}
127 | GPT-4 ({gpt4_answer}): {gpt4_pred}
128 | Requery ({requery_answer}): {requery_pred}
129 | print("=====================================")
130 | """)
131 |
132 | if gpt4_pred_idx == prob['answer'] or our_pred_idx == prob['answer']:
133 | results['correct_upperbound'] += 1
134 |
135 | total = results['total']
136 | print(f'Total: {total}, Our-Correct: {results["our_correct"]}, Accuracy: {results["our_correct"] / total * 100:.2f}%')
137 | print(f'Total: {total}, GPT-4-Correct: {results["gpt4_correct"]}, Accuracy: {results["gpt4_correct"] / total * 100:.2f}%')
138 | print(f'Total: {total}, GPT-4 NO-ANS (RANDOM): {results["gpt4_failed"]}, Percentage: {results["gpt4_failed"] / total * 100:.2f}%')
139 | print(f'Total: {total}, GPT-4-OursVisual-Correct: {results["gpt4_ourvisual_correct"]}, Accuracy: {results["gpt4_ourvisual_correct"] / total * 100:.2f}%')
140 | print(f'Total: {total}, Requery-Correct: {results["requery_correct"]}, Accuracy: {results["requery_correct"] / total * 100:.2f}%')
141 | print(f'Total: {total}, Correct upper: {results["correct_upperbound"]}, Accuracy: {results["correct_upperbound"] / total * 100:.2f}%')
142 |
143 | sqa_results['acc'] = results["requery_correct"] / total * 100
144 | sqa_results['correct'] = results["requery_correct"]
145 | sqa_results['count'] = total
146 |
147 | with open(args.output_result, 'w') as f:
148 | json.dump(sqa_results, f, indent=2)
149 |
150 |
--------------------------------------------------------------------------------
/llava/eval/eval_textvqa.py:
--------------------------------------------------------------------------------
1 | import os
2 | import argparse
3 | import json
4 | import re
5 |
6 | from llava.eval.m4c_evaluator import TextVQAAccuracyEvaluator
7 |
8 |
9 | def get_args():
10 | parser = argparse.ArgumentParser()
11 | parser.add_argument('--annotation-file', type=str)
12 | parser.add_argument('--result-file', type=str)
13 | parser.add_argument('--result-dir', type=str)
14 | return parser.parse_args()
15 |
16 |
17 | def prompt_processor(prompt):
18 | if prompt.startswith('OCR tokens: '):
19 | pattern = r"Question: (.*?) Short answer:"
20 | match = re.search(pattern, prompt, re.DOTALL)
21 | question = match.group(1)
22 | elif 'Reference OCR token: ' in prompt and len(prompt.split('\n')) == 3:
23 | if prompt.startswith('Reference OCR token:'):
24 | question = prompt.split('\n')[1]
25 | else:
26 | question = prompt.split('\n')[0]
27 | elif len(prompt.split('\n')) == 2:
28 | question = prompt.split('\n')[0]
29 | else:
30 | assert False
31 |
32 | return question.lower()
33 |
34 |
35 | def eval_single(annotation_file, result_file):
36 | experiment_name = os.path.splitext(os.path.basename(result_file))[0]
37 | print(experiment_name)
38 | annotations = json.load(open(annotation_file))['data']
39 | annotations = {(annotation['image_id'], annotation['question'].lower()): annotation for annotation in annotations}
40 | results = [json.loads(line) for line in open(result_file)]
41 |
42 | pred_list = []
43 | for result in results:
44 | annotation = annotations[(result['question_id'], prompt_processor(result['prompt']))]
45 | pred_list.append({
46 | "pred_answer": result['text'],
47 | "gt_answers": annotation['answers'],
48 | })
49 |
50 | evaluator = TextVQAAccuracyEvaluator()
51 | print('Samples: {}\nAccuracy: {:.2f}%\n'.format(len(pred_list), 100. * evaluator.eval_pred_list(pred_list)))
52 |
53 |
54 | if __name__ == "__main__":
55 | args = get_args()
56 |
57 | if args.result_file is not None:
58 | eval_single(args.annotation_file, args.result_file)
59 |
60 | if args.result_dir is not None:
61 | for result_file in sorted(os.listdir(args.result_dir)):
62 | if not result_file.endswith('.jsonl'):
63 | print(f'Skipping {result_file}')
64 | continue
65 | eval_single(args.annotation_file, os.path.join(args.result_dir, result_file))
66 |
--------------------------------------------------------------------------------
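A quick sketch of what prompt_processor above extracts from the three prompt layouts it accepts; the prompt strings below are illustrative, not taken from the dataset.

import re

p1 = "OCR tokens: stop, 25 Question: What does the sign say? Short answer:"
p2 = "Reference OCR token: stop, 25\nWhat does the sign say?\nAnswer the question using a single word or phrase."
p3 = "What does the sign say?\nAnswer the question using a single word or phrase."

print(re.search(r"Question: (.*?) Short answer:", p1, re.DOTALL).group(1).lower())
print(p2.split('\n')[1].lower())  # second line when the prompt starts with the OCR reference
print(p3.split('\n')[0].lower())  # first line for the plain two-line format
# all three print: what does the sign say?
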
/llava/eval/generate_webpage_data_from_table.py:
--------------------------------------------------------------------------------
1 | """Generate json file for webpage."""
2 | import json
3 | import os
4 | import re
5 |
6 | # models = ['llama', 'alpaca', 'gpt35', 'bard']
7 | models = ['vicuna']
8 |
9 |
10 | def read_jsonl(path: str, key: str=None):
11 | data = []
12 | with open(os.path.expanduser(path)) as f:
13 | for line in f:
14 | if not line:
15 | continue
16 | data.append(json.loads(line))
17 | if key is not None:
18 | data.sort(key=lambda x: x[key])
19 | data = {item[key]: item for item in data}
20 | return data
21 |
22 |
23 | def trim_hanging_lines(s: str, n: int) -> str:
24 | s = s.strip()
25 | for _ in range(n):
26 | s = s.split('\n', 1)[1].strip()
27 | return s
28 |
29 |
30 | if __name__ == '__main__':
31 | questions = read_jsonl('table/question.jsonl', key='question_id')
32 |
33 | # alpaca_answers = read_jsonl('table/answer/answer_alpaca-13b.jsonl', key='question_id')
34 | # bard_answers = read_jsonl('table/answer/answer_bard.jsonl', key='question_id')
35 | # gpt35_answers = read_jsonl('table/answer/answer_gpt35.jsonl', key='question_id')
36 | # llama_answers = read_jsonl('table/answer/answer_llama-13b.jsonl', key='question_id')
37 | vicuna_answers = read_jsonl('table/answer/answer_vicuna-13b.jsonl', key='question_id')
38 | ours_answers = read_jsonl('table/results/llama-13b-hf-alpaca.jsonl', key='question_id')
39 |
40 | review_vicuna = read_jsonl('table/review/review_vicuna-13b_llama-13b-hf-alpaca.jsonl', key='question_id')
41 | # review_alpaca = read_jsonl('table/review/review_alpaca-13b_vicuna-13b.jsonl', key='question_id')
42 | # review_bard = read_jsonl('table/review/review_bard_vicuna-13b.jsonl', key='question_id')
43 | # review_gpt35 = read_jsonl('table/review/review_gpt35_vicuna-13b.jsonl', key='question_id')
44 | # review_llama = read_jsonl('table/review/review_llama-13b_vicuna-13b.jsonl', key='question_id')
45 |
46 | records = []
47 | for qid in questions.keys():
48 | r = {
49 | 'id': qid,
50 | 'category': questions[qid]['category'],
51 | 'question': questions[qid]['text'],
52 | 'answers': {
53 | # 'alpaca': alpaca_answers[qid]['text'],
54 | # 'llama': llama_answers[qid]['text'],
55 | # 'bard': bard_answers[qid]['text'],
56 | # 'gpt35': gpt35_answers[qid]['text'],
57 | 'vicuna': vicuna_answers[qid]['text'],
58 | 'ours': ours_answers[qid]['text'],
59 | },
60 | 'evaluations': {
61 | # 'alpaca': review_alpaca[qid]['text'],
62 | # 'llama': review_llama[qid]['text'],
63 | # 'bard': review_bard[qid]['text'],
64 | 'vicuna': review_vicuna[qid]['content'],
65 | # 'gpt35': review_gpt35[qid]['text'],
66 | },
67 | 'scores': {
68 | 'vicuna': review_vicuna[qid]['tuple'],
69 | # 'alpaca': review_alpaca[qid]['score'],
70 | # 'llama': review_llama[qid]['score'],
71 | # 'bard': review_bard[qid]['score'],
72 | # 'gpt35': review_gpt35[qid]['score'],
73 | },
74 | }
75 |
76 | # cleanup data
77 | cleaned_evals = {}
78 | for k, v in r['evaluations'].items():
79 | v = v.strip()
80 | lines = v.split('\n')
81 | # trim the first line if it's a pair of numbers
82 | if re.match(r'\d+[, ]+\d+', lines[0]):
83 | lines = lines[1:]
84 | v = '\n'.join(lines)
85 | cleaned_evals[k] = v.replace('Assistant 1', "**Assistant 1**").replace('Assistant 2', '**Assistant 2**')
86 |
87 | r['evaluations'] = cleaned_evals
88 | records.append(r)
89 |
90 | # Reorder the records, this is optional
91 | for r in records:
92 | if r['id'] <= 20:
93 | r['id'] += 60
94 | else:
95 | r['id'] -= 20
96 | for r in records:
97 | if r['id'] <= 50:
98 | r['id'] += 10
99 | elif 50 < r['id'] <= 60:
100 | r['id'] -= 50
101 | for r in records:
102 | if r['id'] == 7:
103 | r['id'] = 1
104 | elif r['id'] < 7:
105 | r['id'] += 1
106 |
107 | records.sort(key=lambda x: x['id'])
108 |
109 | # Write to file
110 | with open('webpage/data.json', 'w') as f:
111 | json.dump({'questions': records, 'models': models}, f, indent=2)
112 |
--------------------------------------------------------------------------------
/llava/eval/model_qa.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | from transformers import AutoTokenizer, AutoModelForCausalLM, StoppingCriteria
3 | import torch
4 | import os
5 | import json
6 | from tqdm import tqdm
7 | import shortuuid
8 |
9 | from llava.conversation import default_conversation
10 | from llava.utils import disable_torch_init
11 |
12 |
13 | # new stopping implementation
14 | class KeywordsStoppingCriteria(StoppingCriteria):
15 | def __init__(self, keywords, tokenizer, input_ids):
16 | self.keywords = keywords
17 | self.tokenizer = tokenizer
18 | self.start_len = None
19 | self.input_ids = input_ids
20 |
21 | def __call__(self, output_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
22 | if self.start_len is None:
23 | self.start_len = self.input_ids.shape[1]
24 | else:
25 | outputs = self.tokenizer.batch_decode(output_ids[:, self.start_len:], skip_special_tokens=True)[0]
26 | for keyword in self.keywords:
27 | if keyword in outputs:
28 | return True
29 | return False
30 |
31 |
32 | @torch.inference_mode()
33 | def eval_model(model_name, questions_file, answers_file):
34 | # Model
35 | disable_torch_init()
36 | model_name = os.path.expanduser(model_name)
37 | tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=False)
38 | model = AutoModelForCausalLM.from_pretrained(model_name,
39 | torch_dtype=torch.float16).cuda()
40 |
41 |
42 | ques_file = open(os.path.expanduser(questions_file), "r")
43 | ans_file = open(os.path.expanduser(answers_file), "w")
44 | for i, line in enumerate(tqdm(ques_file)):
45 | idx = json.loads(line)["question_id"]
46 | qs = json.loads(line)["text"]
47 | cat = json.loads(line)["category"]
48 | conv = default_conversation.copy()
49 | conv.append_message(conv.roles[0], qs)
50 | prompt = conv.get_prompt()
51 | inputs = tokenizer([prompt])
52 | input_ids = torch.as_tensor(inputs.input_ids).cuda()
53 | stopping_criteria = KeywordsStoppingCriteria([conv.sep], tokenizer, input_ids)
54 | output_ids = model.generate(
55 | input_ids,
56 | do_sample=True,
57 | use_cache=True,
58 | temperature=0.7,
59 | max_new_tokens=1024,
60 | stopping_criteria=[stopping_criteria])
61 | outputs = tokenizer.batch_decode(output_ids, skip_special_tokens=True)[0]
62 | try:
63 | index = outputs.index(conv.sep, len(prompt))
64 | except ValueError:
65 | outputs += conv.sep
66 | index = outputs.index(conv.sep, len(prompt))
67 |
68 | outputs = outputs[len(prompt) + len(conv.roles[1]) + 2:index].strip()
69 | ans_id = shortuuid.uuid()
70 | ans_file.write(json.dumps({"question_id": idx,
71 | "text": outputs,
72 | "answer_id": ans_id,
73 | "model_id": model_name,
74 | "metadata": {}}) + "\n")
75 | ans_file.flush()
76 | ans_file.close()
77 |
78 | if __name__ == "__main__":
79 | parser = argparse.ArgumentParser()
80 | parser.add_argument("--model-name", type=str, default="facebook/opt-350m")
81 | parser.add_argument("--question-file", type=str, default="tables/question.jsonl")
82 | parser.add_argument("--answers-file", type=str, default="answer.jsonl")
83 | args = parser.parse_args()
84 |
85 | eval_model(args.model_name, args.question_file, args.answers_file)
86 |
--------------------------------------------------------------------------------
/llava/eval/model_vqa.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import torch
3 | import os
4 | import json
5 | from tqdm import tqdm
6 | import shortuuid
7 |
8 | from llava.constants import IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN
9 | from llava.conversation import conv_templates, SeparatorStyle
10 | from llava.model.builder import load_pretrained_model
11 | from llava.utils import disable_torch_init
12 | from llava.mm_utils import tokenizer_image_token, get_model_name_from_path, KeywordsStoppingCriteria
13 |
14 | from PIL import Image
15 | import math
16 |
17 |
18 | def split_list(lst, n):
19 | """Split a list into n (roughly) equal-sized chunks"""
20 | chunk_size = math.ceil(len(lst) / n)  # ceiling division, so at most n chunks are produced
21 | return [lst[i:i+chunk_size] for i in range(0, len(lst), chunk_size)]
22 |
23 |
24 | def get_chunk(lst, n, k):
25 | chunks = split_list(lst, n)
26 | return chunks[k]
27 |
28 |
29 | def eval_model(args):
30 | # Model
31 | disable_torch_init()
32 | model_path = os.path.expanduser(args.model_path)
33 | model_name = get_model_name_from_path(model_path)
34 | tokenizer, model, image_processor, context_len = load_pretrained_model(model_path, args.model_base, model_name)
35 |
36 | questions = [json.loads(q) for q in open(os.path.expanduser(args.question_file), "r")]
37 | questions = get_chunk(questions, args.num_chunks, args.chunk_idx)
38 | answers_file = os.path.expanduser(args.answers_file)
39 | os.makedirs(os.path.dirname(answers_file), exist_ok=True)
40 | ans_file = open(answers_file, "w")
41 | for line in tqdm(questions):
42 | idx = line["question_id"]
43 | image_file = line["image"]
44 | qs = line["text"]
45 | cur_prompt = qs
46 | if model.config.mm_use_im_start_end:
47 | qs = DEFAULT_IM_START_TOKEN + DEFAULT_IMAGE_TOKEN + DEFAULT_IM_END_TOKEN + '\n' + qs
48 | else:
49 | qs = DEFAULT_IMAGE_TOKEN + '\n' + qs
50 |
51 | conv = conv_templates[args.conv_mode].copy()
52 | conv.append_message(conv.roles[0], qs)
53 | conv.append_message(conv.roles[1], None)
54 | prompt = conv.get_prompt()
55 |
56 | input_ids = tokenizer_image_token(prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors='pt').unsqueeze(0).cuda()
57 |
58 | image = Image.open(os.path.join(args.image_folder, image_file))
59 | image_tensor = image_processor.preprocess(image, return_tensors='pt')['pixel_values'][0]
60 |
61 | stop_str = conv.sep if conv.sep_style != SeparatorStyle.TWO else conv.sep2
62 | keywords = [stop_str]
63 | stopping_criteria = KeywordsStoppingCriteria(keywords, tokenizer, input_ids)
64 |
65 | with torch.inference_mode():
66 | output_ids = model.generate(
67 | input_ids,
68 | images=image_tensor.unsqueeze(0).half().cuda(),
69 | do_sample=True if args.temperature > 0 else False,
70 | temperature=args.temperature,
71 | top_p=args.top_p,
72 | num_beams=args.num_beams,
73 | # no_repeat_ngram_size=3,
74 | max_new_tokens=1024,
75 | use_cache=True)
76 |
77 | input_token_len = input_ids.shape[1]
78 | n_diff_input_output = (input_ids != output_ids[:, :input_token_len]).sum().item()
79 | if n_diff_input_output > 0:
80 | print(f'[Warning] {n_diff_input_output} output_ids are not the same as the input_ids')
81 | outputs = tokenizer.batch_decode(output_ids[:, input_token_len:], skip_special_tokens=True)[0]
82 | outputs = outputs.strip()
83 | if outputs.endswith(stop_str):
84 | outputs = outputs[:-len(stop_str)]
85 | outputs = outputs.strip()
86 |
87 | ans_id = shortuuid.uuid()
88 | ans_file.write(json.dumps({"question_id": idx,
89 | "prompt": cur_prompt,
90 | "text": outputs,
91 | "answer_id": ans_id,
92 | "model_id": model_name,
93 | "metadata": {}}) + "\n")
94 | ans_file.flush()
95 | ans_file.close()
96 |
97 | if __name__ == "__main__":
98 | parser = argparse.ArgumentParser()
99 | parser.add_argument("--model-path", type=str, default="facebook/opt-350m")
100 | parser.add_argument("--model-base", type=str, default=None)
101 | parser.add_argument("--image-folder", type=str, default="")
102 | parser.add_argument("--question-file", type=str, default="tables/question.jsonl")
103 | parser.add_argument("--answers-file", type=str, default="answer.jsonl")
104 | parser.add_argument("--conv-mode", type=str, default="llava_v1")
105 | parser.add_argument("--num-chunks", type=int, default=1)
106 | parser.add_argument("--chunk-idx", type=int, default=0)
107 | parser.add_argument("--temperature", type=float, default=0.2)
108 | parser.add_argument("--top_p", type=float, default=None)
109 | parser.add_argument("--num_beams", type=int, default=1)
110 | args = parser.parse_args()
111 |
112 | eval_model(args)
113 |
--------------------------------------------------------------------------------
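split_list and get_chunk appear in several of the model_vqa_* scripts and simply shard the question list so that parallel jobs (--num-chunks / --chunk-idx) can each process one slice. A small self-contained sketch of the behavior on a toy list, with split_list copied from the script above:

import math

def split_list(lst, n):
    """Split a list into n (roughly) equal-sized chunks."""
    chunk_size = math.ceil(len(lst) / n)  # ceiling division, so at most n chunks are produced
    return [lst[i:i + chunk_size] for i in range(0, len(lst), chunk_size)]

questions = list(range(10))
print(split_list(questions, 3))     # [[0, 1, 2, 3], [4, 5, 6, 7], [8, 9]]
print(split_list(questions, 3)[1])  # the slice a worker launched with --num-chunks 3 --chunk-idx 1 would see
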
/llava/eval/model_vqa_loader.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import torch
3 | import os
4 | import json
5 | from tqdm import tqdm
6 | import shortuuid
7 |
8 | from llava.constants import IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN
9 | from llava.conversation import conv_templates, SeparatorStyle
10 | from llava.model.builder import load_pretrained_model
11 | from llava.utils import disable_torch_init
12 | from llava.mm_utils import tokenizer_image_token, process_images, get_model_name_from_path
13 | from torch.utils.data import Dataset, DataLoader
14 |
15 | from PIL import Image
16 | import math
17 |
18 |
19 | def split_list(lst, n):
20 | """Split a list into n (roughly) equal-sized chunks"""
21 | chunk_size = math.ceil(len(lst) / n)  # ceiling division, so at most n chunks are produced
22 | return [lst[i:i+chunk_size] for i in range(0, len(lst), chunk_size)]
23 |
24 |
25 | def get_chunk(lst, n, k):
26 | chunks = split_list(lst, n)
27 | return chunks[k]
28 |
29 |
30 | # Custom dataset class
31 | class CustomDataset(Dataset):
32 | def __init__(self, questions, image_folder, tokenizer, image_processor, model_config):
33 | self.questions = questions
34 | self.image_folder = image_folder
35 | self.tokenizer = tokenizer
36 | self.image_processor = image_processor
37 | self.model_config = model_config
38 |
39 | def __getitem__(self, index):
40 | line = self.questions[index]
41 | image_file = line["image"]
42 | qs = line["text"]
43 | if self.model_config.mm_use_im_start_end:
44 | qs = DEFAULT_IM_START_TOKEN + DEFAULT_IMAGE_TOKEN + DEFAULT_IM_END_TOKEN + '\n' + qs
45 | else:
46 | qs = DEFAULT_IMAGE_TOKEN + '\n' + qs
47 |
48 | conv = conv_templates[args.conv_mode].copy()
49 | conv.append_message(conv.roles[0], qs)
50 | conv.append_message(conv.roles[1], None)
51 | prompt = conv.get_prompt()
52 |
53 | image = Image.open(os.path.join(self.image_folder, image_file)).convert('RGB')
54 | image_tensor = process_images([image], self.image_processor, self.model_config)[0]
55 |
56 | input_ids = tokenizer_image_token(prompt, self.tokenizer, IMAGE_TOKEN_INDEX, return_tensors='pt')
57 |
58 | return input_ids, image_tensor
59 |
60 | def __len__(self):
61 | return len(self.questions)
62 |
63 |
64 | # DataLoader
65 | def create_data_loader(questions, image_folder, tokenizer, image_processor, model_config, batch_size=1, num_workers=4):
66 | assert batch_size == 1, "batch_size must be 1"
67 | dataset = CustomDataset(questions, image_folder, tokenizer, image_processor, model_config)
68 | data_loader = DataLoader(dataset, batch_size=batch_size, num_workers=num_workers, shuffle=False)
69 | return data_loader
70 |
71 |
72 | def eval_model(args):
73 | # Model
74 | disable_torch_init()
75 | model_path = os.path.expanduser(args.model_path)
76 | model_name = get_model_name_from_path(model_path)
77 | tokenizer, model, image_processor, context_len = load_pretrained_model(model_path, args.model_base, model_name)
78 |
79 | questions = [json.loads(q) for q in open(os.path.expanduser(args.question_file), "r")]
80 | questions = get_chunk(questions, args.num_chunks, args.chunk_idx)
81 | answers_file = os.path.expanduser(args.answers_file)
82 | os.makedirs(os.path.dirname(answers_file), exist_ok=True)
83 | ans_file = open(answers_file, "w")
84 |
85 | if 'plain' in model_name and 'finetune' not in model_name.lower() and 'mmtag' not in args.conv_mode:
86 | args.conv_mode = args.conv_mode + '_mmtag'
87 | print(f'It seems that this is a plain model, but it is not using a mmtag prompt, auto switching to {args.conv_mode}.')
88 |
89 | data_loader = create_data_loader(questions, args.image_folder, tokenizer, image_processor, model.config)
90 |
91 | for (input_ids, image_tensor), line in tqdm(zip(data_loader, questions), total=len(questions)):
92 | idx = line["question_id"]
93 | cur_prompt = line["text"]
94 |
95 | input_ids = input_ids.to(device='cuda', non_blocking=True)
96 |
97 | with torch.inference_mode():
98 | output_ids = model.generate(
99 | input_ids,
100 | images=image_tensor.to(dtype=torch.float16, device='cuda', non_blocking=True),
101 | do_sample=True if args.temperature > 0 else False,
102 | temperature=args.temperature,
103 | top_p=args.top_p,
104 | num_beams=args.num_beams,
105 | max_new_tokens=args.max_new_tokens,
106 | use_cache=True)
107 |
108 | input_token_len = input_ids.shape[1]
109 | n_diff_input_output = (input_ids != output_ids[:, :input_token_len]).sum().item()
110 | if n_diff_input_output > 0:
111 | print(f'[Warning] {n_diff_input_output} output_ids are not the same as the input_ids')
112 | outputs = tokenizer.batch_decode(output_ids[:, input_token_len:], skip_special_tokens=True)[0]
113 | outputs = outputs.strip()
114 |
115 | ans_id = shortuuid.uuid()
116 | ans_file.write(json.dumps({"question_id": idx,
117 | "prompt": cur_prompt,
118 | "text": outputs,
119 | "answer_id": ans_id,
120 | "model_id": model_name,
121 | "metadata": {}}) + "\n")
122 | # ans_file.flush()
123 | ans_file.close()
124 |
125 | if __name__ == "__main__":
126 | parser = argparse.ArgumentParser()
127 | parser.add_argument("--model-path", type=str, default="facebook/opt-350m")
128 | parser.add_argument("--model-base", type=str, default=None)
129 | parser.add_argument("--image-folder", type=str, default="")
130 | parser.add_argument("--question-file", type=str, default="tables/question.jsonl")
131 | parser.add_argument("--answers-file", type=str, default="answer.jsonl")
132 | parser.add_argument("--conv-mode", type=str, default="llava_v1")
133 | parser.add_argument("--num-chunks", type=int, default=1)
134 | parser.add_argument("--chunk-idx", type=int, default=0)
135 | parser.add_argument("--temperature", type=float, default=0.2)
136 | parser.add_argument("--top_p", type=float, default=None)
137 | parser.add_argument("--num_beams", type=int, default=1)
138 | parser.add_argument("--max_new_tokens", type=int, default=128)
139 | args = parser.parse_args()
140 |
141 | eval_model(args)
142 |
--------------------------------------------------------------------------------
/llava/eval/model_vqa_mmbench.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import torch
3 | import os
4 | import json
5 | import pandas as pd
6 | from tqdm import tqdm
7 | import shortuuid
8 |
9 | from llava.constants import IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN
10 | from llava.conversation import conv_templates, SeparatorStyle
11 | from llava.model.builder import load_pretrained_model
12 | from llava.utils import disable_torch_init
13 | from llava.mm_utils import tokenizer_image_token, process_images, load_image_from_base64, get_model_name_from_path
14 |
15 | from PIL import Image
16 | import math
17 |
18 |
19 | all_options = ['A', 'B', 'C', 'D']
20 |
21 |
22 | def split_list(lst, n):
23 | """Split a list into n (roughly) equal-sized chunks"""
24 | chunk_size = math.ceil(len(lst) / n)  # ceiling division, so at most n chunks are produced
25 | return [lst[i:i+chunk_size] for i in range(0, len(lst), chunk_size)]
26 |
27 |
28 | def get_chunk(lst, n, k):
29 | chunks = split_list(lst, n)
30 | return chunks[k]
31 |
32 |
33 | def is_none(value):
34 | if value is None:
35 | return True
36 | if type(value) is float and math.isnan(value):
37 | return True
38 | if type(value) is str and value.lower() == 'nan':
39 | return True
40 | if type(value) is str and value.lower() == 'none':
41 | return True
42 | return False
43 |
44 | def get_options(row, options):
45 | parsed_options = []
46 | for option in options:
47 | option_value = row[option]
48 | if is_none(option_value):
49 | break
50 | parsed_options.append(option_value)
51 | return parsed_options
52 |
53 |
54 | def eval_model(args):
55 | # Model
56 | disable_torch_init()
57 | model_path = os.path.expanduser(args.model_path)
58 | model_name = get_model_name_from_path(model_path)
59 | tokenizer, model, image_processor, context_len = load_pretrained_model(model_path, args.model_base, model_name)
60 |
61 | questions = pd.read_table(os.path.expanduser(args.question_file))
62 | questions = get_chunk(questions, args.num_chunks, args.chunk_idx)
63 | answers_file = os.path.expanduser(args.answers_file)
64 | os.makedirs(os.path.dirname(answers_file), exist_ok=True)
65 | ans_file = open(answers_file, "w")
66 |
67 | if 'plain' in model_name and 'finetune' not in model_name.lower() and 'mmtag' not in args.conv_mode:
68 | args.conv_mode = args.conv_mode + '_mmtag'
69 | print(f'It seems that this is a plain model, but it is not using a mmtag prompt, auto switching to {args.conv_mode}.')
70 |
71 | for index, row in tqdm(questions.iterrows(), total=len(questions)):
72 | options = get_options(row, all_options)
73 | cur_option_char = all_options[:len(options)]
74 |
75 | if args.all_rounds:
76 | num_rounds = len(options)
77 | else:
78 | num_rounds = 1
79 |
80 | for round_idx in range(num_rounds):
81 | idx = row['index']
82 | question = row['question']
83 | hint = row['hint']
84 | image = load_image_from_base64(row['image'])
85 | if not is_none(hint):
86 | question = hint + '\n' + question
87 | for option_char, option in zip(all_options[:len(options)], options):
88 | question = question + '\n' + option_char + '. ' + option
89 | qs = cur_prompt = question
90 | if model.config.mm_use_im_start_end:
91 | qs = DEFAULT_IM_START_TOKEN + DEFAULT_IMAGE_TOKEN + DEFAULT_IM_END_TOKEN + '\n' + qs
92 | else:
93 | qs = DEFAULT_IMAGE_TOKEN + '\n' + qs
94 |
95 | if args.single_pred_prompt:
96 | if args.lang == 'cn':
97 | qs = qs + '\n' + "请直接回答选项字母。"
98 | else:
99 | qs = qs + '\n' + "Answer with the option's letter from the given choices directly."
100 |
101 | conv = conv_templates[args.conv_mode].copy()
102 | conv.append_message(conv.roles[0], qs)
103 | conv.append_message(conv.roles[1], None)
104 | prompt = conv.get_prompt()
105 |
106 | input_ids = tokenizer_image_token(prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors='pt').unsqueeze(0).cuda()
107 |
108 | image_tensor = process_images([image], image_processor, model.config)[0]
109 | # image_tensor = image_processor.preprocess(image, return_tensors='pt')['pixel_values'][0]
110 |
111 | stop_str = conv.sep if conv.sep_style != SeparatorStyle.TWO else conv.sep2
112 |
113 | with torch.inference_mode():
114 | output_ids = model.generate(
115 | input_ids,
116 | images=image_tensor.unsqueeze(0).half().cuda(),
117 | do_sample=True if args.temperature > 0 else False,
118 | temperature=args.temperature,
119 | top_p=args.top_p,
120 | num_beams=args.num_beams,
121 | # no_repeat_ngram_size=3,
122 | max_new_tokens=1024,
123 | use_cache=True)
124 |
125 | input_token_len = input_ids.shape[1]
126 | n_diff_input_output = (input_ids != output_ids[:, :input_token_len]).sum().item()
127 | if n_diff_input_output > 0:
128 | print(f'[Warning] {n_diff_input_output} output_ids are not the same as the input_ids')
129 | outputs = tokenizer.batch_decode(output_ids[:, input_token_len:], skip_special_tokens=True)[0]
130 | outputs = outputs.strip()
131 | if outputs.endswith(stop_str):
132 | outputs = outputs[:-len(stop_str)]
133 | outputs = outputs.strip()
134 |
135 | ans_id = shortuuid.uuid()
136 | ans_file.write(json.dumps({"question_id": idx,
137 | "round_id": round_idx,
138 | "prompt": cur_prompt,
139 | "text": outputs,
140 | "options": options,
141 | "option_char": cur_option_char,
142 | "answer_id": ans_id,
143 | "model_id": model_name,
144 | "metadata": {}}) + "\n")
145 | ans_file.flush()
146 |
147 | # rotate options
148 | options = options[1:] + options[:1]
149 | cur_option_char = cur_option_char[1:] + cur_option_char[:1]
150 | ans_file.close()
151 |
152 | if __name__ == "__main__":
153 | parser = argparse.ArgumentParser()
154 | parser.add_argument("--model-path", type=str, default="facebook/opt-350m")
155 | parser.add_argument("--model-base", type=str, default=None)
156 | parser.add_argument("--image-folder", type=str, default="")
157 | parser.add_argument("--question-file", type=str, default="tables/question.jsonl")
158 | parser.add_argument("--answers-file", type=str, default="answer.jsonl")
159 | parser.add_argument("--conv-mode", type=str, default="llava_v1")
160 | parser.add_argument("--num-chunks", type=int, default=1)
161 | parser.add_argument("--chunk-idx", type=int, default=0)
162 | parser.add_argument("--temperature", type=float, default=0.2)
163 | parser.add_argument("--top_p", type=float, default=None)
164 | parser.add_argument("--num_beams", type=int, default=1)
165 | parser.add_argument("--all-rounds", action="store_true")
166 | parser.add_argument("--single-pred-prompt", action="store_true")
167 | parser.add_argument("--lang", type=str, default="en")
168 | args = parser.parse_args()
169 |
170 | eval_model(args)
171 |
--------------------------------------------------------------------------------
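The --all-rounds flag above asks the same MMBench question once per option ordering; after each round both the option texts and the bookkeeping letters are rotated by one position. A toy illustration (option texts invented): the prompt always labels the displayed choices A., B., C., while cur_option_char records which original letter currently sits in each slot.

all_options = ['A', 'B', 'C', 'D']
options = ['cat', 'dog', 'bird']
cur_option_char = all_options[:len(options)]

for round_idx in range(len(options)):
    shown = [f'{c}. {o}' for c, o in zip(all_options[:len(options)], options)]
    print(round_idx, shown, cur_option_char)
    # rotate options exactly as the script does at the end of each round
    options = options[1:] + options[:1]
    cur_option_char = cur_option_char[1:] + cur_option_char[:1]
# 0 ['A. cat', 'B. dog', 'C. bird'] ['A', 'B', 'C']
# 1 ['A. dog', 'B. bird', 'C. cat'] ['B', 'C', 'A']
# 2 ['A. bird', 'B. cat', 'C. dog'] ['C', 'A', 'B']
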
/llava/eval/model_vqa_qbench.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import torch
3 | from tqdm import tqdm
4 | import json
5 |
6 | from llava.constants import IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN
7 | from llava.conversation import conv_templates, SeparatorStyle
8 | from llava.model.builder import load_pretrained_model
9 | from llava.utils import disable_torch_init
10 | from llava.mm_utils import tokenizer_image_token, get_model_name_from_path, KeywordsStoppingCriteria
11 |
12 | from PIL import Image
13 |
14 | import requests
15 | from PIL import Image
16 | from io import BytesIO
17 |
18 |
19 | def load_image(image_file):
20 | if image_file.startswith('http') or image_file.startswith('https'):
21 | response = requests.get(image_file)
22 | image = Image.open(BytesIO(response.content)).convert('RGB')
23 | else:
24 | image = Image.open(image_file).convert('RGB')
25 | return image
26 |
27 |
28 | def eval_model(args):
29 | # Model
30 | disable_torch_init()
31 |
32 | model_name = get_model_name_from_path(args.model_path)
33 | tokenizer, model, image_processor, context_len = load_pretrained_model(args.model_path, args.model_base, model_name, True)
34 |
35 |
36 |
37 |
38 | with open(args.questions_file) as f:
39 | llvqa_data = json.load(f)
40 |
41 | for i, llddata in enumerate(tqdm(llvqa_data)):
42 | filename = llddata["img_path"]
43 | if args.lang == "en":
44 | message = llddata["question"] + "\nChoose between one of the options as follows:\n"
45 | elif args.lang == "zh":
46 | message = llddata["question"] + "\在下列选项中选择一个:\n"
47 | else:
48 | raise NotImplementedError("Q-Bench does not support languages other than English (en) and Chinese (zh) yet. Contact us (https://github.com/VQAssessment/Q-Bench/) to convert Q-Bench into more languages.")
49 | for choice, ans in zip(["A.", "B.", "C.", "D."], llddata["candidates"]):
50 | message += f"{choice} {ans}\n"
51 | qs = message
52 |
53 | if model.config.mm_use_im_start_end:
54 | qs = DEFAULT_IM_START_TOKEN + DEFAULT_IMAGE_TOKEN + DEFAULT_IM_END_TOKEN + '\n' + qs
55 | else:
56 | qs = DEFAULT_IMAGE_TOKEN + '\n' + qs
57 |
58 | if 'llama-2' in model_name.lower():
59 | conv_mode = "llava_llama_2"
60 | elif "v1" in model_name.lower():
61 | conv_mode = "llava_v1"
62 | elif "mpt" in model_name.lower():
63 | conv_mode = "mpt"
64 | else:
65 | conv_mode = "llava_v0"
66 |
67 | if args.conv_mode is not None and conv_mode != args.conv_mode:
68 | print('[WARNING] the auto inferred conversation mode is {}, while `--conv-mode` is {}, using {}'.format(conv_mode, args.conv_mode, args.conv_mode))
69 | else:
70 | args.conv_mode = conv_mode
71 |
72 | conv = conv_templates[args.conv_mode].copy()
73 | conv.append_message(conv.roles[0], qs)
74 | conv.append_message(conv.roles[1], None)
75 | prompt = conv.get_prompt()
76 |
77 | image = load_image(args.image_folder + filename)
78 | image_tensor = image_processor.preprocess(image, return_tensors='pt')['pixel_values'].half().cuda()
79 |
80 | input_ids = tokenizer_image_token(prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors='pt').unsqueeze(0).cuda()
81 |
82 | stop_str = conv.sep if conv.sep_style != SeparatorStyle.TWO else conv.sep2
83 | keywords = [stop_str]
84 | stopping_criteria = KeywordsStoppingCriteria(keywords, tokenizer, input_ids)
85 |
86 |
87 | with torch.inference_mode():
88 | output_ids = model.generate(
89 | input_ids,
90 | images=image_tensor,
91 | num_beams=1,
92 | do_sample=False,
93 | temperature=0,
94 | max_new_tokens=1024,
95 | use_cache=True,
96 | stopping_criteria=[stopping_criteria])
97 |
98 | input_token_len = input_ids.shape[1]
99 | n_diff_input_output = (input_ids != output_ids[:, :input_token_len]).sum().item()
100 | if n_diff_input_output > 0:
101 | print(f'[Warning] {n_diff_input_output} output_ids are not the same as the input_ids')
102 | outputs = tokenizer.batch_decode(output_ids[:, input_token_len:], skip_special_tokens=True)[0]
103 | outputs = outputs.strip()
104 | if outputs.endswith(stop_str):
105 | outputs = outputs[:-len(stop_str)]
106 | outputs = outputs.strip()
107 | llddata["response"] = outputs
108 | with open(args.answers_file, "a") as wf:
109 |             wf.write(json.dumps(llddata) + "\n")  # one JSON record per line, matching the .jsonl default
110 |
111 | if __name__ == "__main__":
112 | parser = argparse.ArgumentParser()
113 | parser.add_argument("--model-path", type=str, default="llava-v1.5")
114 | parser.add_argument("--model-base", type=str, default=None)
115 | parser.add_argument("--image-folder", type=str, default="./playground/data/qbench/images_llvisionqa")
116 | parser.add_argument("--questions-file", type=str, default="./playground/data/qbench/llvisionqa_dev.json")
117 | parser.add_argument("--answers-file", type=str, default="answer.jsonl")
118 | parser.add_argument("--conv-mode", type=str, default="llava_v1")
119 | parser.add_argument("--lang", type=str, default="en")
120 | args = parser.parse_args()
121 |
122 | eval_model(args)
123 |
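Note: each record in --questions-file is expected to carry "img_path", "question", and "candidates"; the loop above letters the candidates as options A.-D., queries the model, and appends a "response" field before dumping the record to --answers-file. A minimal sketch of a compatible record and invocation (file names, question text, and checkpoint are illustrative, not taken from this repo):

import json

record = {
    "img_path": "example.jpg",                               # resolved against --image-folder
    "question": "How is the overall clarity of this image?",
    "candidates": ["High", "Medium", "Low"],                  # lettered A., B., C. by the loop above
}
with open("llvisionqa_dev.json", "w") as f:
    json.dump([record], f)                                    # the script expects a JSON list of records

# Illustrative run, assuming the llava package and a checkpoint are available:
# python -m llava.eval.model_vqa_qbench --model-path <checkpoint> \
#     --image-folder ./ --questions-file llvisionqa_dev.json \
#     --answers-file answer.jsonl --lang en
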
--------------------------------------------------------------------------------
/llava/eval/model_vqa_science.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import torch
3 | import os
4 | import json
5 | from tqdm import tqdm
6 | import shortuuid
7 |
8 | from llava.constants import IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN
9 | from llava.conversation import conv_templates, SeparatorStyle
10 | from llava.model.builder import load_pretrained_model
11 | from llava.utils import disable_torch_init
12 | from llava.mm_utils import tokenizer_image_token, get_model_name_from_path, KeywordsStoppingCriteria
13 |
14 | from PIL import Image
15 | import math
16 |
17 |
18 | def split_list(lst, n):
19 | """Split a list into n (roughly) equal-sized chunks"""
20 |     chunk_size = math.ceil(len(lst) / n)  # ceiling division, so the last chunk may be smaller
21 | return [lst[i:i+chunk_size] for i in range(0, len(lst), chunk_size)]
22 |
23 |
24 | def get_chunk(lst, n, k):
25 | chunks = split_list(lst, n)
26 | return chunks[k]
27 |
28 |
29 | def eval_model(args):
30 | # Model
31 | disable_torch_init()
32 | model_path = os.path.expanduser(args.model_path)
33 | model_name = get_model_name_from_path(model_path)
34 | tokenizer, model, image_processor, context_len = load_pretrained_model(model_path, args.model_base, model_name)
35 |
36 | questions = json.load(open(os.path.expanduser(args.question_file), "r"))
37 | questions = get_chunk(questions, args.num_chunks, args.chunk_idx)
38 | answers_file = os.path.expanduser(args.answers_file)
39 | os.makedirs(os.path.dirname(answers_file), exist_ok=True)
40 | ans_file = open(answers_file, "w")
41 | for i, line in enumerate(tqdm(questions)):
42 | idx = line["id"]
43 | question = line['conversations'][0]
44 |         qs = question['value'].replace('<image>', '').strip()
45 | cur_prompt = qs
46 |
47 | if 'image' in line:
48 | image_file = line["image"]
49 | image = Image.open(os.path.join(args.image_folder, image_file))
50 | image_tensor = image_processor.preprocess(image, return_tensors='pt')['pixel_values'][0]
51 | images = image_tensor.unsqueeze(0).half().cuda()
52 | if getattr(model.config, 'mm_use_im_start_end', False):
53 | qs = DEFAULT_IM_START_TOKEN + DEFAULT_IMAGE_TOKEN + DEFAULT_IM_END_TOKEN + '\n' + qs
54 | else:
55 | qs = DEFAULT_IMAGE_TOKEN + '\n' + qs
56 |             cur_prompt = '<image>' + '\n' + cur_prompt
57 | else:
58 | images = None
59 |
60 | if args.single_pred_prompt:
61 | qs = qs + '\n' + "Answer with the option's letter from the given choices directly."
62 | cur_prompt = cur_prompt + '\n' + "Answer with the option's letter from the given choices directly."
63 |
64 | conv = conv_templates[args.conv_mode].copy()
65 | conv.append_message(conv.roles[0], qs)
66 | conv.append_message(conv.roles[1], None)
67 | prompt = conv.get_prompt()
68 |
69 | input_ids = tokenizer_image_token(prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors='pt').unsqueeze(0).cuda()
70 |
71 | stop_str = conv.sep if conv.sep_style != SeparatorStyle.TWO else conv.sep2
72 | keywords = [stop_str]
73 | stopping_criteria = [KeywordsStoppingCriteria(keywords, tokenizer, input_ids)] if conv.version == "v0" else None
74 |
75 | with torch.inference_mode():
76 | output_ids = model.generate(
77 | input_ids,
78 | images=images,
79 | do_sample=True if args.temperature > 0 else False,
80 | temperature=args.temperature,
81 | max_new_tokens=1024,
82 | use_cache=True,
83 | stopping_criteria=stopping_criteria,
84 | )
85 |
86 | input_token_len = input_ids.shape[1]
87 | n_diff_input_output = (input_ids != output_ids[:, :input_token_len]).sum().item()
88 | if n_diff_input_output > 0:
89 | print(f'[Warning] {n_diff_input_output} output_ids are not the same as the input_ids')
90 | outputs = tokenizer.batch_decode(output_ids[:, input_token_len:], skip_special_tokens=True)[0]
91 | outputs = outputs.strip()
92 | if outputs.endswith(stop_str):
93 | outputs = outputs[:-len(stop_str)]
94 | outputs = outputs.strip()
95 |
96 | # prompt for answer
97 | if args.answer_prompter:
98 | outputs_reasoning = outputs
99 | input_ids = tokenizer_image_token(prompt + outputs_reasoning + ' ###\nANSWER:', tokenizer, IMAGE_TOKEN_INDEX, return_tensors='pt').unsqueeze(0).cuda()
100 |
101 | with torch.inference_mode():
102 | output_ids = model.generate(
103 | input_ids,
104 | images=images,
105 | do_sample=True if args.temperature > 0 else False,
106 | temperature=args.temperature,
107 | max_new_tokens=64,
108 | use_cache=True,
109 |                     stopping_criteria=stopping_criteria)  # already a list (or None) from above
110 |
111 | input_token_len = input_ids.shape[1]
112 | n_diff_input_output = (input_ids != output_ids[:, :input_token_len]).sum().item()
113 | if n_diff_input_output > 0:
114 | print(f'[Warning] {n_diff_input_output} output_ids are not the same as the input_ids')
115 | outputs = tokenizer.batch_decode(output_ids[:, input_token_len:], skip_special_tokens=True)[0]
116 | outputs = outputs.strip()
117 | if outputs.endswith(stop_str):
118 | outputs = outputs[:-len(stop_str)]
119 | outputs = outputs.strip()
120 | outputs = outputs_reasoning + '\n The answer is ' + outputs
121 |
122 | ans_id = shortuuid.uuid()
123 | ans_file.write(json.dumps({"question_id": idx,
124 | "prompt": cur_prompt,
125 | "text": outputs,
126 | "answer_id": ans_id,
127 | "model_id": model_name,
128 | "metadata": {}}) + "\n")
129 | ans_file.flush()
130 | ans_file.close()
131 |
132 | if __name__ == "__main__":
133 | parser = argparse.ArgumentParser()
134 | parser.add_argument("--model-path", type=str, default="facebook/opt-350m")
135 | parser.add_argument("--model-base", type=str, default=None)
136 | parser.add_argument("--image-folder", type=str, default="")
137 | parser.add_argument("--question-file", type=str, default="tables/question.json")
138 | parser.add_argument("--answers-file", type=str, default="answer.jsonl")
139 | parser.add_argument("--conv-mode", type=str, default="llava_v0")
140 | parser.add_argument("--num-chunks", type=int, default=1)
141 | parser.add_argument("--chunk-idx", type=int, default=0)
142 | parser.add_argument("--temperature", type=float, default=0.2)
143 | parser.add_argument("--answer-prompter", action="store_true")
144 | parser.add_argument("--single-pred-prompt", action="store_true")
145 | args = parser.parse_args()
146 |
147 | eval_model(args)
148 |
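Note: --num-chunks and --chunk-idx shard the question list across parallel jobs via split_list/get_chunk above. A self-contained illustration of the chunking (with one caveat: because chunk sizes are rounded up, some list lengths yield fewer than n chunks, so a --chunk-idx close to n-1 can be out of range):

import math

def split_list(lst, n):
    chunk_size = math.ceil(len(lst) / n)       # same ceiling-division logic as above
    return [lst[i:i + chunk_size] for i in range(0, len(lst), chunk_size)]

print(split_list(list(range(10)), 3))          # [[0, 1, 2, 3], [4, 5, 6, 7], [8, 9]]
print(split_list(list(range(5)), 4))           # only 3 chunks: [[0, 1], [2, 3], [4]]
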
--------------------------------------------------------------------------------
/llava/eval/qa_baseline_gpt35.py:
--------------------------------------------------------------------------------
1 | """Generate answers with GPT-3.5"""
2 | # Note: you need to be using OpenAI Python v0.27.0 for the code below to work
3 | import argparse
4 | import json
5 | import os
6 | import time
7 | import concurrent.futures
8 |
9 | import openai
10 | import tqdm
11 | import shortuuid
12 |
13 | MODEL = 'gpt-3.5-turbo'
14 | MODEL_ID = 'gpt-3.5-turbo:20230327'
15 |
16 | def get_answer(question_id: int, question: str, max_tokens: int):
17 | ans = {
18 | 'answer_id': shortuuid.uuid(),
19 | 'question_id': question_id,
20 | 'model_id': MODEL_ID,
21 | }
22 | for _ in range(3):
23 | try:
24 | response = openai.ChatCompletion.create(
25 | model=MODEL,
26 | messages=[{
27 | 'role': 'system',
28 | 'content': 'You are a helpful assistant.'
29 | }, {
30 | 'role': 'user',
31 | 'content': question,
32 | }],
33 | max_tokens=max_tokens,
34 | )
35 | ans['text'] = response['choices'][0]['message']['content']
36 | return ans
37 | except Exception as e:
38 | print('[ERROR]', e)
39 | ans['text'] = '#ERROR#'
40 | time.sleep(1)
41 | return ans
42 |
43 |
44 | if __name__ == '__main__':
45 | parser = argparse.ArgumentParser(description='ChatGPT answer generation.')
46 | parser.add_argument('-q', '--question')
47 | parser.add_argument('-o', '--output')
48 | parser.add_argument('--max-tokens', type=int, default=1024, help='maximum number of tokens produced in the output')
49 | args = parser.parse_args()
50 |
51 | questions_dict = {}
52 | with open(os.path.expanduser(args.question)) as f:
53 | for line in f:
54 | if not line:
55 | continue
56 | q = json.loads(line)
57 | questions_dict[q['question_id']] = q['text']
58 |
59 | answers = []
60 |
61 | with concurrent.futures.ThreadPoolExecutor(max_workers=32) as executor:
62 | futures = []
63 | for qid, question in questions_dict.items():
64 | future = executor.submit(get_answer, qid, question, args.max_tokens)
65 | futures.append(future)
66 |
67 | for future in tqdm.tqdm(concurrent.futures.as_completed(futures), total=len(futures)):
68 | answers.append(future.result())
69 |
70 | answers.sort(key=lambda x: x['question_id'])
71 |
72 | with open(os.path.expanduser(args.output), 'w') as f:
73 | table = [json.dumps(ans) for ans in answers]
74 | f.write('\n'.join(table))
75 |
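Note: the script above targets the legacy openai 0.27 interface mentioned in its header. As a hedged sketch only (based on the current openai-python 1.x client, not on anything in this repo), the equivalent request would look roughly like this:

# Hedged sketch for openai>=1.0; the question text is illustrative.
from openai import OpenAI

client = OpenAI()                              # reads OPENAI_API_KEY from the environment
response = client.chat.completions.create(
    model="gpt-3.5-turbo",
    messages=[
        {"role": "system", "content": "You are a helpful assistant."},
        {"role": "user", "content": "What is the capital of France?"},
    ],
    max_tokens=1024,
)
print(response.choices[0].message.content)
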
--------------------------------------------------------------------------------
/llava/eval/run_llava.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import torch
3 |
4 | from llava.constants import (
5 | IMAGE_TOKEN_INDEX,
6 | DEFAULT_IMAGE_TOKEN,
7 | DEFAULT_IM_START_TOKEN,
8 | DEFAULT_IM_END_TOKEN,
9 | IMAGE_PLACEHOLDER,
10 | )
11 | from llava.conversation import conv_templates, SeparatorStyle
12 | from llava.model.builder import load_pretrained_model
13 | from llava.utils import disable_torch_init
14 | from llava.mm_utils import (
15 | process_images,
16 | tokenizer_image_token,
17 | get_model_name_from_path,
18 | KeywordsStoppingCriteria,
19 | )
20 |
21 | from PIL import Image
22 |
23 | import requests
24 | from PIL import Image
25 | from io import BytesIO
26 | import re
27 |
28 |
29 | def image_parser(args):
30 | out = args.image_file.split(args.sep)
31 | return out
32 |
33 |
34 | def load_image(image_file):
35 | if image_file.startswith("http") or image_file.startswith("https"):
36 | response = requests.get(image_file)
37 | image = Image.open(BytesIO(response.content)).convert("RGB")
38 | else:
39 | image = Image.open(image_file).convert("RGB")
40 | return image
41 |
42 |
43 | def load_images(image_files):
44 | out = []
45 | for image_file in image_files:
46 | image = load_image(image_file)
47 | out.append(image)
48 | return out
49 |
50 |
51 | def eval_model(args):
52 | # Model
53 | disable_torch_init()
54 |
55 | model_name = get_model_name_from_path(args.model_path)
56 | tokenizer, model, image_processor, context_len = load_pretrained_model(
57 | args.model_path, args.model_base, model_name
58 | )
59 |
60 | qs = args.query
61 | image_token_se = DEFAULT_IM_START_TOKEN + DEFAULT_IMAGE_TOKEN + DEFAULT_IM_END_TOKEN
62 | if IMAGE_PLACEHOLDER in qs:
63 | if model.config.mm_use_im_start_end:
64 | qs = re.sub(IMAGE_PLACEHOLDER, image_token_se, qs)
65 | else:
66 | qs = re.sub(IMAGE_PLACEHOLDER, DEFAULT_IMAGE_TOKEN, qs)
67 | else:
68 | if model.config.mm_use_im_start_end:
69 | qs = image_token_se + "\n" + qs
70 | else:
71 | qs = DEFAULT_IMAGE_TOKEN + "\n" + qs
72 |
73 | if "llama-2" in model_name.lower():
74 | conv_mode = "llava_llama_2"
75 | elif "v1" in model_name.lower():
76 | conv_mode = "llava_v1"
77 | elif "mpt" in model_name.lower():
78 | conv_mode = "mpt"
79 | else:
80 | conv_mode = "llava_v0"
81 |
82 | if args.conv_mode is not None and conv_mode != args.conv_mode:
83 | print(
84 | "[WARNING] the auto inferred conversation mode is {}, while `--conv-mode` is {}, using {}".format(
85 | conv_mode, args.conv_mode, args.conv_mode
86 | )
87 | )
88 | else:
89 | args.conv_mode = conv_mode
90 |
91 | conv = conv_templates[args.conv_mode].copy()
92 | conv.append_message(conv.roles[0], qs)
93 | conv.append_message(conv.roles[1], None)
94 | prompt = conv.get_prompt()
95 |
96 | image_files = image_parser(args)
97 | images = load_images(image_files)
98 | images_tensor = process_images(
99 | images,
100 | image_processor,
101 | model.config
102 | ).to(model.device, dtype=torch.float16)
103 |
104 | input_ids = (
105 | tokenizer_image_token(prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors="pt")
106 | .unsqueeze(0)
107 | .cuda()
108 | )
109 |
110 | stop_str = conv.sep if conv.sep_style != SeparatorStyle.TWO else conv.sep2
111 | keywords = [stop_str]
112 | stopping_criteria = KeywordsStoppingCriteria(keywords, tokenizer, input_ids)
113 |
114 | with torch.inference_mode():
115 | output_ids = model.generate(
116 | input_ids,
117 | images=images_tensor,
118 | do_sample=True if args.temperature > 0 else False,
119 | temperature=args.temperature,
120 | top_p=args.top_p,
121 | num_beams=args.num_beams,
122 | max_new_tokens=args.max_new_tokens,
123 | use_cache=True,
124 | stopping_criteria=[stopping_criteria],
125 | )
126 |
127 | input_token_len = input_ids.shape[1]
128 | n_diff_input_output = (input_ids != output_ids[:, :input_token_len]).sum().item()
129 | if n_diff_input_output > 0:
130 | print(
131 | f"[Warning] {n_diff_input_output} output_ids are not the same as the input_ids"
132 | )
133 | outputs = tokenizer.batch_decode(
134 | output_ids[:, input_token_len:], skip_special_tokens=True
135 | )[0]
136 | outputs = outputs.strip()
137 | if outputs.endswith(stop_str):
138 | outputs = outputs[: -len(stop_str)]
139 | outputs = outputs.strip()
140 | print(outputs)
141 |
142 |
143 | if __name__ == "__main__":
144 | parser = argparse.ArgumentParser()
145 | parser.add_argument("--model-path", type=str, default="facebook/opt-350m")
146 | parser.add_argument("--model-base", type=str, default=None)
147 | parser.add_argument("--image-file", type=str, required=True)
148 | parser.add_argument("--query", type=str, required=True)
149 | parser.add_argument("--conv-mode", type=str, default=None)
150 | parser.add_argument("--sep", type=str, default=",")
151 | parser.add_argument("--temperature", type=float, default=0.2)
152 | parser.add_argument("--top_p", type=float, default=None)
153 | parser.add_argument("--num_beams", type=int, default=1)
154 | parser.add_argument("--max_new_tokens", type=int, default=512)
155 | args = parser.parse_args()
156 |
157 | eval_model(args)
158 |
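Note: --image-file may list several images joined by --sep, and --query may either embed the placeholder string defined as IMAGE_PLACEHOLDER in llava/constants.py or omit it and let the script prepend the image token. A small sketch of the image parsing plus an illustrative invocation (paths and checkpoint are placeholders):

import argparse

args = argparse.Namespace(image_file="view1.jpg,view2.jpg", sep=",")
print(args.image_file.split(args.sep))         # same as image_parser(args): ['view1.jpg', 'view2.jpg']

# Illustrative run, assuming the llava package and a checkpoint are available:
# python -m llava.eval.run_llava --model-path <checkpoint> \
#     --image-file "view1.jpg,view2.jpg" --query "What differs between the two images?"
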
--------------------------------------------------------------------------------
/llava/eval/summarize_gpt_review.py:
--------------------------------------------------------------------------------
1 | import json
2 | import os
3 | from collections import defaultdict
4 |
5 | import numpy as np
6 |
7 | import argparse
8 |
9 | def parse_args():
10 | parser = argparse.ArgumentParser(description='ChatGPT-based QA evaluation.')
11 | parser.add_argument('-d', '--dir', default=None)
12 | parser.add_argument('-v', '--version', default=None)
13 | parser.add_argument('-s', '--select', nargs='*', default=None)
14 | parser.add_argument('-f', '--files', nargs='*', default=[])
15 | parser.add_argument('-i', '--ignore', nargs='*', default=[])
16 | return parser.parse_args()
17 |
18 |
19 | if __name__ == '__main__':
20 | args = parse_args()
21 |
22 | if args.ignore is not None:
23 | args.ignore = [int(x) for x in args.ignore]
24 |
25 | if len(args.files) > 0:
26 | review_files = args.files
27 | else:
28 | review_files = [x for x in os.listdir(args.dir) if x.endswith('.jsonl') and (x.startswith('gpt4_text') or x.startswith('reviews_') or x.startswith('review_') or 'review' in args.dir)]
29 |
30 | for review_file in sorted(review_files):
31 | config = os.path.basename(review_file).replace('gpt4_text_', '').replace('.jsonl', '')
32 | if args.select is not None and any(x not in config for x in args.select):
33 | continue
34 | if '0613' in config:
35 | version = '0613'
36 | else:
37 | version = '0314'
38 | if args.version is not None and args.version != version:
39 | continue
40 | scores = defaultdict(list)
41 | print(config)
42 | with open(os.path.join(args.dir, review_file) if args.dir is not None else review_file) as f:
43 | for review_str in f:
44 | review = json.loads(review_str)
45 | if review['question_id'] in args.ignore:
46 | continue
47 | if 'category' in review:
48 | scores[review['category']].append(review['tuple'])
49 | scores['all'].append(review['tuple'])
50 | else:
51 | if 'tuple' in review:
52 | scores['all'].append(review['tuple'])
53 | else:
54 | scores['all'].append(review['score'])
55 | for k, v in sorted(scores.items()):
56 | stats = np.asarray(v).mean(0).tolist()
57 | stats = [round(x, 3) for x in stats]
58 | # print(k, stats, round(stats[1]/stats[0]*100, 1))
59 | print(k, round(stats[1]/stats[0]*100, 1), round(stats[0] * 10, 1), round(stats[1] * 10, 1))
60 | print('=================================')
61 |
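Note: each review's "tuple" field holds the two judge scores (Assistant 1 first, Assistant 2 second, per the prompt format in table/prompt.jsonl), and the summary prints Assistant 2's mean as a percentage of Assistant 1's, followed by both means rescaled to 0-100. A worked example of that arithmetic with made-up tuples:

import numpy as np

tuples = [[8.0, 7.0], [9.0, 8.5], [7.0, 7.5]]             # made-up [assistant_1, assistant_2] scores
stats = np.asarray(tuples).mean(0)                         # per-assistant means: [8.0, ~7.67]
print(round(stats[1] / stats[0] * 100, 1))                 # relative score of assistant 2: 95.8
print(round(stats[0] * 10, 1), round(stats[1] * 10, 1))    # means rescaled to 0-100: 80.0 76.7
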
--------------------------------------------------------------------------------
/llava/eval/table/model.jsonl:
--------------------------------------------------------------------------------
1 | {"model_id": "vicuna-13b:20230322-clean-lang", "model_name": "vicuna-13b", "model_version": "20230322-clean-lang", "model_metadata": "vicuna-13b-20230322-clean-lang"}
2 | {"model_id": "alpaca-13b:v1", "model_name": "alpaca-13b", "model_version": "v1", "model_metadata": "alpaca-13b"}
3 | {"model_id": "llama-13b:v1", "model_name": "llama-13b", "model_version": "v1", "model_metadata": "hf-llama-13b"}
4 | {"model_id": "bard:20230327", "model_name": "bard", "model_version": "20230327", "model_metadata": "Google Bard 20230327"}
5 | {"model_id": "gpt-3.5-turbo:20230327", "model_name": "gpt-3.5-turbo", "model_version": "20230327", "model_metadata": "OpenAI ChatGPT gpt-3.5-turbo Chat Completion"}
6 |
--------------------------------------------------------------------------------
/llava/eval/table/prompt.jsonl:
--------------------------------------------------------------------------------
1 | {"prompt_id": 1, "system_prompt": "You are a helpful and precise assistant for checking the quality of the answer.", "prompt_template": "[Question]\n{question}\n\n[Assistant 1]\n{answer_1}\n\n[End of Assistant 1]\n\n[Assistant 2]\n{answer_2}\n\n[End of Assistant 2]\n\n[System]\n{prompt}\n\n", "defaults": {"prompt": "We would like to request your feedback on the performance of two AI assistants in response to the user question displayed above.\nPlease rate the helpfulness, relevance, accuracy, level of details of their responses. Each assistant receives an overall score on a scale of 1 to 10, where a higher score indicates better overall performance.\nPlease first output a single line containing only two values indicating the scores for Assistant 1 and 2, respectively. The two scores are separated by a space.\nIn the subsequent line, please provide a comprehensive explanation of your evaluation, avoiding any potential bias and ensuring that the order in which the responses were presented does not affect your judgment."}, "description": "Prompt for general questions"}
2 | {"prompt_id": 2, "system_prompt": "You are a helpful and precise assistant for checking the quality of the answer.", "prompt_template": "[Question]\n{question}\n\n[Assistant 1]\n{answer_1}\n\n[End of Assistant 1]\n\n[Assistant 2]\n{answer_2}\n\n[End of Assistant 2]\n\n[System]\n{prompt}\n\n", "defaults": {"prompt": "Your task is to evaluate the coding abilities of the above two assistants. They have been asked to implement a program to solve a given problem. Please review their code submissions, paying close attention to their problem-solving approach, code structure, readability, and the inclusion of helpful comments.\n\nPlease ensure that the assistants' submissions:\n\n1. Correctly implement the given problem statement.\n2. Contain accurate and efficient code.\n3. Include clear and concise comments that explain the code's logic and functionality.\n4. Adhere to proper coding standards and best practices.\n\nOnce you have carefully reviewed both submissions, provide detailed feedback on their strengths and weaknesses, along with any suggestions for improvement. You should first output a single line containing two scores on the scale of 1-10 (1: no code/no sense; 10: perfect) for Assistant 1 and 2, respectively. Then give extra comments starting from the next line."}, "description": "Prompt for coding questions"}
3 | {"prompt_id": 3, "system_prompt": "You are a helpful and precise assistant for checking the quality of the answer.", "prompt_template": "[Question]\n{question}\n\n[Assistant 1]\n{answer_1}\n\n[End of Assistant 1]\n\n[Assistant 2]\n{answer_2}\n\n[End of Assistant 2]\n\n[System]\n{prompt}\n\n", "defaults": {"prompt": "We would like to request your feedback on the mathematical proficiency of two AI assistants regarding the given user question.\nFirstly, please solve the problem independently, without referring to the answers provided by Assistant 1 and Assistant 2.\nAfterward, please examine the problem-solving process of Assistant 1 and Assistant 2 step-by-step to ensure their correctness, identifying any incorrect steps if present. Your evaluation should take into account not only the answer but also the problem-solving steps.\nFinally, please output a Python tuple containing two numerical scores for Assistant 1 and Assistant 2, ranging from 1 to 10, respectively. If applicable, explain the reasons for any variations in their scores and determine which assistant performed better."}, "description": "Prompt for math questions"}
4 | {"prompt_id": 4, "system_prompt": "You are a helpful and precise assistant for checking the quality of the answer.", "prompt_template": "[Visual Context]\n{context}\n[Question]\n{question}\n\n[Assistant 1]\n{answer_1}\n\n[End of Assistant 1]\n\n[Assistant 2]\n{answer_2}\n\n[End of Assistant 2]\n\n[System]\n{prompt}\n\n", "defaults": {"prompt": "We would like to request your feedback on the performance of two AI assistants in response to the user question displayed above. The user asks the question on observing an image. For your reference, the visual content in the image is represented with five descriptive sentences describing the same image and the bounding box coordinates of each object in the scene. These coordinates are in the form of bounding boxes, represented as (x1, y1, x2, y2) with floating numbers ranging from 0 to 1. These values correspond to the top left x, top left y, bottom right x, and bottom right y. \nPlease rate the helpfulness, relevance, accuracy, level of details of their responses. Each assistant receives an overall score on a scale of 1 to 10, where a higher score indicates better overall performance.\nPlease first output a single line containing only two values indicating the scores for Assistant 1 and 2, respectively. The two scores are separated by a space.\nIn the subsequent line, please provide a comprehensive explanation of your evaluation, avoiding any potential bias and ensuring that the order in which the responses were presented does not affect your judgment."}, "description": "Prompt for visual questions"}
5 |
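Note: a review request is built by filling prompt_template with the question, the two answers, and the default instruction, while system_prompt is sent as the system message (the reviewer script itself is not part of this excerpt). A minimal sketch of that substitution, with illustrative question and answer text:

import json

with open("llava/eval/table/prompt.jsonl") as f:
    prompt = json.loads(f.readline())          # prompt_id 1: general questions

filled = prompt["prompt_template"].format(
    question="What is shown in the image?",
    answer_1="A dog running on a beach.",
    answer_2="A cat sitting indoors.",
    prompt=prompt["defaults"]["prompt"],
)
print(prompt["system_prompt"])
print(filled)
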
--------------------------------------------------------------------------------
/llava/eval/table/reviewer.jsonl:
--------------------------------------------------------------------------------
1 | {"reviewer_id": "gpt-4-0328-default", "prompt_id": 1, "metadata": {"temperature": 0.2, "max_tokens": 1024}, "description": "GPT-4 for general questions"}
2 | {"reviewer_id": "gpt-4-0328-coding", "prompt_id": 2, "metadata": {"temperature": 0.2, "max_tokens": 1024}, "description": "GPT-4 for coding questions"}
3 | {"reviewer_id": "gpt-4-0328-math", "prompt_id": 3, "metadata": {"temperature": 0.2, "max_tokens": 1024}, "description": "GPT-4 for math questions"}
4 | {"reviewer_id": "gpt-4-0417-visual", "prompt_id": 4, "metadata": {"temperature": 0.2, "max_tokens": 1024}, "description": "GPT-4 for math questions"}
5 |
--------------------------------------------------------------------------------
/llava/eval/webpage/figures/alpaca.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/prometheus-eval/prometheus-vision/fe668b5e7e120e5bea8a3821eeb468ddf719c23c/llava/eval/webpage/figures/alpaca.png
--------------------------------------------------------------------------------
/llava/eval/webpage/figures/bard.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/prometheus-eval/prometheus-vision/fe668b5e7e120e5bea8a3821eeb468ddf719c23c/llava/eval/webpage/figures/bard.jpg
--------------------------------------------------------------------------------
/llava/eval/webpage/figures/chatgpt.svg:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/llava/eval/webpage/figures/llama.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/prometheus-eval/prometheus-vision/fe668b5e7e120e5bea8a3821eeb468ddf719c23c/llava/eval/webpage/figures/llama.jpg
--------------------------------------------------------------------------------
/llava/eval/webpage/figures/swords_FILL0_wght300_GRAD0_opsz48.svg:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/llava/eval/webpage/figures/vicuna.jpeg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/prometheus-eval/prometheus-vision/fe668b5e7e120e5bea8a3821eeb468ddf719c23c/llava/eval/webpage/figures/vicuna.jpeg
--------------------------------------------------------------------------------
/llava/eval/webpage/styles.css:
--------------------------------------------------------------------------------
1 | body {
2 | font-family: 'Segoe UI', Tahoma, Geneva, Verdana, sans-serif;
3 | background-color: #f8f9fa;
4 | }
5 |
6 | .navbar-dark .navbar-nav .nav-link {
7 | color: #f1cf68;
8 | font-size: 1.1rem;
9 | padding: 0.5rem 0.6rem;
10 | }
11 |
12 | .card-header {
13 | font-weight: bold;
14 | }
15 |
16 | .card {
17 | box-shadow: 0 4px 8px rgba(0, 0, 0, 0.1);
18 | transition: 0.3s;
19 | }
20 |
21 | .card:hover {
22 | box-shadow: 0 8px 16px rgba(0, 0, 0, 0.2);
23 | }
24 |
25 | button {
26 | transition: background-color 0.3s;
27 | }
28 |
29 | button:hover {
30 | background-color: #007bff;
31 | }
32 |
33 | @media (max-width: 767px) {
34 | .form-row .form-group {
35 | margin-bottom: 10px;
36 | }
37 | }
38 |
39 | /* Extra styles */
40 |
41 | .expandable-card .card-text-container {
42 | max-height: 200px;
43 | overflow-y: hidden;
44 | position: relative;
45 | }
46 |
47 | .expandable-card.expanded .card-text-container {
48 | max-height: none;
49 | }
50 |
51 | .expand-btn {
52 | position: relative;
53 | display: none;
54 | background-color: rgba(255, 255, 255, 0.8);
55 | color: #510c75;
56 | border-color: transparent;
57 | }
58 |
59 | .expand-btn:hover {
60 | background-color: rgba(200, 200, 200, 0.8);
61 | text-decoration: none;
62 | border-color: transparent;
63 | color: #510c75;
64 | }
65 |
66 | .expand-btn:focus {
67 | outline: none;
68 | text-decoration: none;
69 | }
70 |
71 | .expandable-card:not(.expanded) .card-text-container:after {
72 | content: "";
73 | position: absolute;
74 | bottom: 0;
75 | left: 0;
76 | width: 100%;
77 | height: 90px;
78 | background: linear-gradient(rgba(255, 255, 255, 0.2), rgba(255, 255, 255, 1));
79 | }
80 |
81 | .expandable-card:not(.expanded) .expand-btn {
82 | margin-top: -40px;
83 | }
84 |
85 | .card-body {
86 | padding-bottom: 5px;
87 | }
88 |
89 | .vertical-flex-layout {
90 | justify-content: center;
91 | align-items: center;
92 | height: 100%;
93 | display: flex;
94 | flex-direction: column;
95 | gap: 5px;
96 | }
97 |
98 | .figure-img {
99 | max-width: 100%;
100 | height: auto;
101 | }
102 |
103 | .adjustable-font-size {
104 | font-size: calc(0.5rem + 2vw);
105 | }
106 |
--------------------------------------------------------------------------------
/llava/mm_utils.py:
--------------------------------------------------------------------------------
1 | from PIL import Image
2 | from io import BytesIO
3 | import base64
4 |
5 | import torch
6 | from transformers import StoppingCriteria
7 | from llava.constants import IMAGE_TOKEN_INDEX
8 |
9 |
10 | def load_image_from_base64(image):
11 | return Image.open(BytesIO(base64.b64decode(image)))
12 |
13 |
14 | def expand2square(pil_img, background_color):
15 | width, height = pil_img.size
16 | if width == height:
17 | return pil_img
18 | elif width > height:
19 | result = Image.new(pil_img.mode, (width, width), background_color)
20 | result.paste(pil_img, (0, (width - height) // 2))
21 | return result
22 | else:
23 | result = Image.new(pil_img.mode, (height, height), background_color)
24 | result.paste(pil_img, ((height - width) // 2, 0))
25 | return result
26 |
27 |
28 | def process_images(images, image_processor, model_cfg):
29 | image_aspect_ratio = getattr(model_cfg, "image_aspect_ratio", None)
30 | new_images = []
31 | if image_aspect_ratio == 'pad':
32 | for image in images:
33 | image = expand2square(image, tuple(int(x*255) for x in image_processor.image_mean))
34 | image = image_processor.preprocess(image, return_tensors='pt')['pixel_values'][0]
35 | new_images.append(image)
36 | else:
37 | return image_processor(images, return_tensors='pt')['pixel_values']
38 | if all(x.shape == new_images[0].shape for x in new_images):
39 | new_images = torch.stack(new_images, dim=0)
40 | return new_images
41 |
42 |
43 | def tokenizer_image_token(prompt, tokenizer, image_token_index=IMAGE_TOKEN_INDEX, return_tensors=None):
44 |     prompt_chunks = [tokenizer(chunk).input_ids for chunk in prompt.split('<image>')]
45 |
46 | def insert_separator(X, sep):
47 | return [ele for sublist in zip(X, [sep]*len(X)) for ele in sublist][:-1]
48 |
49 | input_ids = []
50 | offset = 0
51 | if len(prompt_chunks) > 0 and len(prompt_chunks[0]) > 0 and prompt_chunks[0][0] == tokenizer.bos_token_id:
52 | offset = 1
53 | input_ids.append(prompt_chunks[0][0])
54 |
55 | for x in insert_separator(prompt_chunks, [image_token_index] * (offset + 1)):
56 | input_ids.extend(x[offset:])
57 |
58 | if return_tensors is not None:
59 | if return_tensors == 'pt':
60 | return torch.tensor(input_ids, dtype=torch.long)
61 | raise ValueError(f'Unsupported tensor type: {return_tensors}')
62 | return input_ids
63 |
64 |
65 | def get_model_name_from_path(model_path):
66 | model_path = model_path.strip("/")
67 | model_paths = model_path.split("/")
68 | if model_paths[-1].startswith('checkpoint-'):
69 | return model_paths[-2] + "_" + model_paths[-1]
70 | else:
71 | return model_paths[-1]
72 |
73 | class KeywordsStoppingCriteria(StoppingCriteria):
74 | def __init__(self, keywords, tokenizer, input_ids):
75 | self.keywords = keywords
76 | self.keyword_ids = []
77 | self.max_keyword_len = 0
78 | for keyword in keywords:
79 | cur_keyword_ids = tokenizer(keyword).input_ids
80 | if len(cur_keyword_ids) > 1 and cur_keyword_ids[0] == tokenizer.bos_token_id:
81 | cur_keyword_ids = cur_keyword_ids[1:]
82 | if len(cur_keyword_ids) > self.max_keyword_len:
83 | self.max_keyword_len = len(cur_keyword_ids)
84 | self.keyword_ids.append(torch.tensor(cur_keyword_ids))
85 | self.tokenizer = tokenizer
86 | self.start_len = input_ids.shape[1]
87 |
88 | def call_for_batch(self, output_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
89 | offset = min(output_ids.shape[1] - self.start_len, self.max_keyword_len)
90 | self.keyword_ids = [keyword_id.to(output_ids.device) for keyword_id in self.keyword_ids]
91 | for keyword_id in self.keyword_ids:
92 | if (output_ids[0, -keyword_id.shape[0]:] == keyword_id).all():
93 | return True
94 | outputs = self.tokenizer.batch_decode(output_ids[:, -offset:], skip_special_tokens=True)[0]
95 | for keyword in self.keywords:
96 | if keyword in outputs:
97 | return True
98 | return False
99 |
100 | def __call__(self, output_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
101 | outputs = []
102 | for i in range(output_ids.shape[0]):
103 | outputs.append(self.call_for_batch(output_ids[i].unsqueeze(0), scores))
104 | return all(outputs)
105 |
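Note: expand2square is the helper behind the 'pad' aspect-ratio mode above: a non-square image is pasted, centred, onto a square canvas filled with the processor's mean colour before CLIP preprocessing. A self-contained illustration of the same padding (sizes and the fill colour are arbitrary):

from PIL import Image

img = Image.new("RGB", (40, 20), (255, 0, 0))             # wide image
canvas = Image.new("RGB", (40, 40), (124, 116, 104))       # square canvas with a mean-like fill
canvas.paste(img, (0, (40 - 20) // 2))                     # centred vertically, as in expand2square
print(canvas.size)                                         # (40, 40)
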
--------------------------------------------------------------------------------
/llava/model/__init__.py:
--------------------------------------------------------------------------------
1 | from .language_model.llava_llama import LlavaLlamaForCausalLM, LlavaConfig
2 | from .language_model.llava_mpt import LlavaMPTForCausalLM, LlavaMPTConfig
3 |
--------------------------------------------------------------------------------
/llava/model/apply_delta.py:
--------------------------------------------------------------------------------
1 | """
2 | Usage:
3 | python3 -m llava.model.apply_delta --base ~/model_weights/llama-7b --target ~/model_weights/vicuna-7b --delta lmsys/vicuna-7b-delta
4 | """
5 | import argparse
6 |
7 | import torch
8 | from tqdm import tqdm
9 | from transformers import AutoTokenizer, AutoModelForCausalLM
10 | from llava import LlavaLlamaForCausalLM
11 |
12 |
13 | def apply_delta(base_model_path, target_model_path, delta_path):
14 | print("Loading base model")
15 | base = AutoModelForCausalLM.from_pretrained(
16 | base_model_path, torch_dtype=torch.float16, low_cpu_mem_usage=True)
17 |
18 | print("Loading delta")
19 | delta = LlavaLlamaForCausalLM.from_pretrained(delta_path, torch_dtype=torch.float16, low_cpu_mem_usage=True)
20 | delta_tokenizer = AutoTokenizer.from_pretrained(delta_path)
21 |
22 | print("Applying delta")
23 | for name, param in tqdm(delta.state_dict().items(), desc="Applying delta"):
24 | if name not in base.state_dict():
25 | assert name in ['model.mm_projector.weight', 'model.mm_projector.bias'], f'{name} not in base model'
26 | continue
27 | if param.data.shape == base.state_dict()[name].shape:
28 | param.data += base.state_dict()[name]
29 | else:
30 | assert name in ['model.embed_tokens.weight', 'lm_head.weight'], \
31 | f'{name} dimension mismatch: {param.data.shape} vs {base.state_dict()[name].shape}'
32 | bparam = base.state_dict()[name]
33 | param.data[:bparam.shape[0], :bparam.shape[1]] += bparam
34 |
35 | print("Saving target model")
36 | delta.save_pretrained(target_model_path)
37 | delta_tokenizer.save_pretrained(target_model_path)
38 |
39 |
40 | if __name__ == "__main__":
41 | parser = argparse.ArgumentParser()
42 | parser.add_argument("--base-model-path", type=str, required=True)
43 | parser.add_argument("--target-model-path", type=str, required=True)
44 | parser.add_argument("--delta-path", type=str, required=True)
45 |
46 | args = parser.parse_args()
47 |
48 | apply_delta(args.base_model_path, args.target_model_path, args.delta_path)
49 |
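Note: the delta is applied in place by adding base weights into the delta; for the resized embed_tokens/lm_head matrices only the overlapping top-left block is summed, so the extra rows added for new multimodal tokens stay exactly as stored in the delta. A toy tensor sketch of that shape-mismatch branch:

import torch

base = torch.ones(4, 3)                          # stand-in for a base weight matrix
delta = torch.zeros(6, 3)                        # delta with two extra rows for new tokens
delta[:base.shape[0], :base.shape[1]] += base    # only the overlapping block is summed
print(delta[:4].eq(1).all().item(), delta[4:].eq(0).all().item())   # True True
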
--------------------------------------------------------------------------------
/llava/model/consolidate.py:
--------------------------------------------------------------------------------
1 | """
2 | Usage:
3 | python3 -m llava.model.consolidate --src ~/model_weights/llava-7b --dst ~/model_weights/llava-7b_consolidate
4 | """
5 | import argparse
6 |
7 | import torch
8 | from transformers import AutoTokenizer, AutoModelForCausalLM
9 | from llava.model import *
10 | from llava.model.utils import auto_upgrade
11 |
12 |
13 | def consolidate_ckpt(src_path, dst_path):
14 | print("Loading model")
15 | auto_upgrade(src_path)
16 | src_model = AutoModelForCausalLM.from_pretrained(src_path, torch_dtype=torch.float16, low_cpu_mem_usage=True)
17 | src_tokenizer = AutoTokenizer.from_pretrained(src_path, use_fast=False)
18 | src_model.save_pretrained(dst_path)
19 | src_tokenizer.save_pretrained(dst_path)
20 |
21 |
22 | if __name__ == "__main__":
23 | parser = argparse.ArgumentParser()
24 | parser.add_argument("--src", type=str, required=True)
25 | parser.add_argument("--dst", type=str, required=True)
26 |
27 | args = parser.parse_args()
28 |
29 | consolidate_ckpt(args.src, args.dst)
30 |
--------------------------------------------------------------------------------
/llava/model/language_model/llava_llama.py:
--------------------------------------------------------------------------------
1 | # Copyright 2023 Haotian Liu
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 |
16 | from typing import List, Optional, Tuple, Union
17 |
18 | import torch
19 | import torch.nn as nn
20 |
21 | from transformers import AutoConfig, AutoModelForCausalLM, \
22 | LlamaConfig, LlamaModel, LlamaForCausalLM
23 |
24 | from transformers.modeling_outputs import CausalLMOutputWithPast
25 |
26 | from ..llava_arch import LlavaMetaModel, LlavaMetaForCausalLM
27 |
28 |
29 | class LlavaConfig(LlamaConfig):
30 | model_type = "llava"
31 |
32 |
33 | class LlavaLlamaModel(LlavaMetaModel, LlamaModel):
34 | config_class = LlavaConfig
35 |
36 | def __init__(self, config: LlamaConfig):
37 | super(LlavaLlamaModel, self).__init__(config)
38 |
39 |
40 | class LlavaLlamaForCausalLM(LlamaForCausalLM, LlavaMetaForCausalLM):
41 | config_class = LlavaConfig
42 |
43 | def __init__(self, config):
44 | super(LlamaForCausalLM, self).__init__(config)
45 | self.model = LlavaLlamaModel(config)
46 | self.pretraining_tp = config.pretraining_tp
47 | self.vocab_size = config.vocab_size
48 | self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
49 |
50 | # Initialize weights and apply final processing
51 | self.post_init()
52 |
53 | def get_model(self):
54 | return self.model
55 |
56 | def forward(
57 | self,
58 | input_ids: torch.LongTensor = None,
59 | attention_mask: Optional[torch.Tensor] = None,
60 | position_ids: Optional[torch.LongTensor] = None,
61 | past_key_values: Optional[List[torch.FloatTensor]] = None,
62 | inputs_embeds: Optional[torch.FloatTensor] = None,
63 | labels: Optional[torch.LongTensor] = None,
64 | use_cache: Optional[bool] = None,
65 | output_attentions: Optional[bool] = None,
66 | output_hidden_states: Optional[bool] = None,
67 | images: Optional[torch.FloatTensor] = None,
68 | return_dict: Optional[bool] = None,
69 | ) -> Union[Tuple, CausalLMOutputWithPast]:
70 |
71 | if inputs_embeds is None:
72 | (
73 | input_ids,
74 | position_ids,
75 | attention_mask,
76 | past_key_values,
77 | inputs_embeds,
78 | labels
79 | ) = self.prepare_inputs_labels_for_multimodal(
80 | input_ids,
81 | position_ids,
82 | attention_mask,
83 | past_key_values,
84 | labels,
85 | images
86 | )
87 |
88 | return super().forward(
89 | input_ids=input_ids,
90 | attention_mask=attention_mask,
91 | position_ids=position_ids,
92 | past_key_values=past_key_values,
93 | inputs_embeds=inputs_embeds,
94 | labels=labels,
95 | use_cache=use_cache,
96 | output_attentions=output_attentions,
97 | output_hidden_states=output_hidden_states,
98 | return_dict=return_dict
99 | )
100 |
101 | def prepare_inputs_for_generation(self, input_ids, past_key_values=None, inputs_embeds=None, **kwargs):
102 | images = kwargs.pop("images", None)
103 | _inputs = super().prepare_inputs_for_generation(
104 | input_ids, past_key_values=past_key_values, inputs_embeds=inputs_embeds, **kwargs
105 | )
106 | if images is not None:
107 | _inputs['images'] = images
108 | return _inputs
109 |
110 | AutoConfig.register("llava", LlavaConfig)
111 | AutoModelForCausalLM.register(LlavaConfig, LlavaLlamaForCausalLM)
112 |
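Note: the two register() calls at the bottom make configs whose model_type is "llava" resolve to LlavaConfig/LlavaLlamaForCausalLM through the transformers Auto classes (the repo's own load_pretrained_model builder remains the intended entry point). A hedged sketch of what the registration enables, assuming the llava package is importable and the repo's pinned transformers version (newer transformers releases ship their own built-in "llava" model type, which would clash with this registration):

from transformers import AutoConfig, AutoModelForCausalLM

import llava.model.language_model.llava_llama    # executes the register() calls above

cfg = AutoConfig.for_model("llava")              # resolves to LlavaConfig
print(type(cfg).__name__)                        # LlavaConfig
# With a real checkpoint whose config declares model_type "llava" (path illustrative):
# model = AutoModelForCausalLM.from_pretrained("path/to/llava-checkpoint")
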
--------------------------------------------------------------------------------
/llava/model/language_model/llava_mpt.py:
--------------------------------------------------------------------------------
1 | # Copyright 2023 Haotian Liu
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 |
16 | from typing import List, Optional, Tuple
17 | import warnings
18 |
19 | import torch
20 | import torch.nn.functional as F
21 | import math
22 |
23 | from transformers import AutoConfig, AutoModelForCausalLM
24 | from transformers.modeling_outputs import CausalLMOutputWithPast
25 |
26 | from .mpt.modeling_mpt import MPTConfig, MPTForCausalLM, MPTModel
27 | from llava.model.llava_arch import LlavaMetaModel, LlavaMetaForCausalLM
28 |
29 |
30 | class LlavaMPTConfig(MPTConfig):
31 | model_type = "llava_mpt"
32 |
33 |
34 | class LlavaMPTModel(LlavaMetaModel, MPTModel):
35 | config_class = LlavaMPTConfig
36 |
37 | def __init__(self, config: MPTConfig):
38 | config.hidden_size = config.d_model
39 | super(LlavaMPTModel, self).__init__(config)
40 |
41 | def embed_tokens(self, x):
42 | return self.wte(x)
43 |
44 |
45 | class LlavaMPTForCausalLM(MPTForCausalLM, LlavaMetaForCausalLM):
46 | config_class = LlavaMPTConfig
47 | supports_gradient_checkpointing = True
48 |
49 | def __init__(self, config):
50 | super(MPTForCausalLM, self).__init__(config)
51 |
52 | if not config.tie_word_embeddings:
53 | raise ValueError('MPTForCausalLM only supports tied word embeddings')
54 | self.transformer = LlavaMPTModel(config)
55 | self.logit_scale = None
56 | if config.logit_scale is not None:
57 | logit_scale = config.logit_scale
58 | if isinstance(logit_scale, str):
59 | if logit_scale == 'inv_sqrt_d_model':
60 | logit_scale = 1 / math.sqrt(config.d_model)
61 | else:
62 | raise ValueError(f"logit_scale={logit_scale!r} is not recognized as an option; use numeric value or 'inv_sqrt_d_model'.")
63 | self.logit_scale = logit_scale
64 |
65 | def get_model(self):
66 | return self.transformer
67 |
68 | def _set_gradient_checkpointing(self, module, value=False):
69 | if isinstance(module, LlavaMPTModel):
70 | module.gradient_checkpointing = value
71 |
72 | def forward(self, input_ids: torch.LongTensor, past_key_values: Optional[List[Tuple[torch.FloatTensor]]]=None, attention_mask: Optional[torch.ByteTensor]=None, prefix_mask: Optional[torch.ByteTensor]=None, sequence_id: Optional[torch.LongTensor]=None, labels: Optional[torch.LongTensor]=None, return_dict: Optional[bool]=None, output_attentions: Optional[bool]=None, output_hidden_states: Optional[bool]=None, use_cache: Optional[bool]=None, images=None):
73 | return_dict = return_dict if return_dict is not None else self.config.return_dict
74 | use_cache = use_cache if use_cache is not None else self.config.use_cache
75 |
76 | input_ids, _, attention_mask, past_key_values, inputs_embeds, labels = self.prepare_inputs_labels_for_multimodal(input_ids, None, attention_mask, past_key_values, labels, images)
77 | outputs = self.transformer(input_ids=input_ids, inputs_embeds=inputs_embeds, past_key_values=past_key_values, attention_mask=attention_mask, prefix_mask=prefix_mask, sequence_id=sequence_id, return_dict=return_dict, output_attentions=output_attentions, output_hidden_states=output_hidden_states, use_cache=use_cache)
78 | # FIXME: this is a hack to fix the multiple gpu inference issue in https://github.com/haotian-liu/LLaVA/issues/338
79 | logits = F.linear(outputs.last_hidden_state.to(self.transformer.wte.weight.device), self.transformer.wte.weight)
80 | if self.logit_scale is not None:
81 | if self.logit_scale == 0:
82 | warnings.warn(f'Multiplying logits by self.logit_scale={self.logit_scale!r}. This will produce uniform (uninformative) outputs.')
83 | logits *= self.logit_scale
84 | loss = None
85 | if labels is not None:
86 | labels = torch.roll(labels, shifts=-1)
87 | labels[:, -1] = -100
88 | loss = F.cross_entropy(logits.view(-1, logits.size(-1)), labels.to(logits.device).view(-1))
89 | return CausalLMOutputWithPast(loss=loss, logits=logits, past_key_values=outputs.past_key_values, hidden_states=outputs.hidden_states)
90 |
91 | def prepare_inputs_for_generation(self, input_ids, past_key_values=None, inputs_embeds=None, **kwargs):
92 | if inputs_embeds is not None:
93 | raise NotImplementedError('inputs_embeds is not implemented for MPT yet')
94 | attention_mask = kwargs['attention_mask'].bool()
95 | if attention_mask[:, -1].sum() != attention_mask.shape[0]:
96 | raise NotImplementedError('MPT does not support generation with right padding.')
97 | if self.transformer.attn_uses_sequence_id and self.training:
98 | sequence_id = torch.zeros_like(input_ids[:1])
99 | else:
100 | sequence_id = None
101 | if past_key_values is not None:
102 | input_ids = input_ids[:, -1].unsqueeze(-1)
103 | if self.transformer.prefix_lm:
104 | prefix_mask = torch.ones_like(attention_mask)
105 | if kwargs.get('use_cache') == False:
106 | raise NotImplementedError('MPT with prefix_lm=True does not support use_cache=False.')
107 | else:
108 | prefix_mask = None
109 | return {'input_ids': input_ids, 'attention_mask': attention_mask, 'prefix_mask': prefix_mask, 'sequence_id': sequence_id, 'past_key_values': past_key_values, 'use_cache': kwargs.get('use_cache', True), "images": kwargs.get("images", None)}
110 |
111 |
112 | AutoConfig.register("llava_mpt", LlavaMPTConfig)
113 | AutoModelForCausalLM.register(LlavaMPTConfig, LlavaMPTForCausalLM)
114 |
--------------------------------------------------------------------------------
/llava/model/language_model/mpt/adapt_tokenizer.py:
--------------------------------------------------------------------------------
1 | from typing import Union
2 | from transformers import AutoTokenizer, PreTrainedTokenizer, PreTrainedTokenizerFast
3 | Tokenizer = Union[PreTrainedTokenizer, PreTrainedTokenizerFast]
4 | NUM_SENTINEL_TOKENS: int = 100
5 |
6 | def adapt_tokenizer_for_denoising(tokenizer: Tokenizer):
7 | """Adds sentinel tokens and padding token (if missing).
8 |
9 | Expands the tokenizer vocabulary to include sentinel tokens
10 | used in mixture-of-denoiser tasks as well as a padding token.
11 |
12 | All added tokens are added as special tokens. No tokens are
13 | added if sentinel tokens and padding token already exist.
14 | """
15 |     sentinels_to_add = [f'<extra_id_{i}>' for i in range(NUM_SENTINEL_TOKENS)]
16 | tokenizer.add_tokens(sentinels_to_add, special_tokens=True)
17 | if tokenizer.pad_token is None:
18 |         tokenizer.add_tokens('<pad>', special_tokens=True)
19 |         tokenizer.pad_token = '<pad>'
20 | assert tokenizer.pad_token_id is not None
21 |     sentinels = ''.join([f'<extra_id_{i}>' for i in range(NUM_SENTINEL_TOKENS)])
22 | _sentinel_token_ids = tokenizer(sentinels, add_special_tokens=False).input_ids
23 | tokenizer.sentinel_token_ids = _sentinel_token_ids
24 |
25 | class AutoTokenizerForMOD(AutoTokenizer):
26 | """AutoTokenizer + Adaptation for MOD.
27 |
28 | A simple wrapper around AutoTokenizer to make instantiating
29 | an MOD-adapted tokenizer a bit easier.
30 |
31 |     MOD-adapted tokenizers have sentinel tokens (e.g., <extra_id_0>),
32 | a padding token, and a property to get the token ids of the
33 | sentinel tokens.
34 | """
35 |
36 | @classmethod
37 | def from_pretrained(cls, *args, **kwargs):
38 | """See `AutoTokenizer.from_pretrained` docstring."""
39 | tokenizer = super().from_pretrained(*args, **kwargs)
40 | adapt_tokenizer_for_denoising(tokenizer)
41 | return tokenizer
--------------------------------------------------------------------------------
/llava/model/language_model/mpt/blocks.py:
--------------------------------------------------------------------------------
1 | """GPT Blocks used for the GPT Model."""
2 | from typing import Dict, Optional, Tuple
3 | import torch
4 | import torch.nn as nn
5 | from .attention import ATTN_CLASS_REGISTRY
6 | from .norm import NORM_CLASS_REGISTRY
7 |
8 | class MPTMLP(nn.Module):
9 |
10 | def __init__(self, d_model: int, expansion_ratio: int, device: Optional[str]=None):
11 | super().__init__()
12 | self.up_proj = nn.Linear(d_model, expansion_ratio * d_model, device=device)
13 | self.act = nn.GELU(approximate='none')
14 | self.down_proj = nn.Linear(expansion_ratio * d_model, d_model, device=device)
15 | self.down_proj._is_residual = True
16 |
17 | def forward(self, x):
18 | return self.down_proj(self.act(self.up_proj(x)))
19 |
20 | class MPTBlock(nn.Module):
21 |
22 | def __init__(self, d_model: int, n_heads: int, expansion_ratio: int, attn_config: Dict={'attn_type': 'multihead_attention', 'attn_pdrop': 0.0, 'attn_impl': 'triton', 'qk_ln': False, 'clip_qkv': None, 'softmax_scale': None, 'prefix_lm': False, 'attn_uses_sequence_id': False, 'alibi': False, 'alibi_bias_max': 8}, resid_pdrop: float=0.0, norm_type: str='low_precision_layernorm', verbose: int=0, device: Optional[str]=None, **kwargs):
23 | del kwargs
24 | super().__init__()
25 | norm_class = NORM_CLASS_REGISTRY[norm_type.lower()]
26 | attn_class = ATTN_CLASS_REGISTRY[attn_config['attn_type']]
27 | self.norm_1 = norm_class(d_model, device=device)
28 | self.attn = attn_class(attn_impl=attn_config['attn_impl'], clip_qkv=attn_config['clip_qkv'], qk_ln=attn_config['qk_ln'], softmax_scale=attn_config['softmax_scale'], attn_pdrop=attn_config['attn_pdrop'], d_model=d_model, n_heads=n_heads, verbose=verbose, device=device)
29 | self.norm_2 = norm_class(d_model, device=device)
30 | self.ffn = MPTMLP(d_model=d_model, expansion_ratio=expansion_ratio, device=device)
31 | self.resid_attn_dropout = nn.Dropout(resid_pdrop)
32 | self.resid_ffn_dropout = nn.Dropout(resid_pdrop)
33 |
34 | def forward(self, x: torch.Tensor, past_key_value: Optional[Tuple[torch.Tensor]]=None, attn_bias: Optional[torch.Tensor]=None, attention_mask: Optional[torch.ByteTensor]=None, is_causal: bool=True) -> Tuple[torch.Tensor, Optional[Tuple[torch.Tensor]]]:
35 | a = self.norm_1(x)
36 | (b, attn_weights, past_key_value) = self.attn(a, past_key_value=past_key_value, attn_bias=attn_bias, attention_mask=attention_mask, is_causal=is_causal)
37 | x = x + self.resid_attn_dropout(b)
38 | m = self.norm_2(x)
39 | n = self.ffn(m)
40 | x = x + self.resid_ffn_dropout(n)
41 | return (x, attn_weights, past_key_value)
--------------------------------------------------------------------------------
/llava/model/language_model/mpt/custom_embedding.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 | from torch import Tensor
5 |
6 | class SharedEmbedding(nn.Embedding):
7 |
8 | def forward(self, input: Tensor, unembed: bool=False) -> Tensor:
9 | if unembed:
10 | return F.linear(input, self.weight)
11 | return super().forward(input)
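
Note: SharedEmbedding lets MPT tie the input embedding and the output projection: the same weight embeds token ids on the way in and, with unembed=True, produces logits via F.linear on the way out. A tiny runnable check of that weight sharing using a plain nn.Embedding:

import torch
import torch.nn as nn
import torch.nn.functional as F

emb = nn.Embedding(10, 4)                 # vocab of 10, hidden size 4
ids = torch.tensor([[1, 2, 3]])
hidden = emb(ids)                         # embed path: shape (1, 3, 4)
logits = F.linear(hidden, emb.weight)     # unembed=True path: shape (1, 3, 10)
print(hidden.shape, logits.shape)
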
--------------------------------------------------------------------------------
/llava/model/language_model/mpt/meta_init_context.py:
--------------------------------------------------------------------------------
1 | from contextlib import contextmanager
2 | import torch
3 | import torch.nn as nn
4 |
5 | @contextmanager
6 | def init_empty_weights(include_buffers: bool=False):
7 | """Meta initialization context manager.
8 |
9 | A context manager under which models are initialized with all parameters
10 | on the meta device, therefore creating an empty model. Useful when just
11 | initializing the model would blow the available RAM.
12 |
13 | Args:
14 | include_buffers (`bool`, *optional*, defaults to `False`): Whether or
15 | not to also put all buffers on the meta device while initializing.
16 |
17 | Example:
18 | ```python
19 | import torch.nn as nn
20 |
21 |     # Initialize a model with 100 billion parameters in no time and without using any RAM.
22 | with init_empty_weights():
23 | tst = nn.Sequential(*[nn.Linear(10000, 10000) for _ in range(1000)])
24 | ```
25 |
26 |
27 |
28 | Any model created under this context manager has no weights. As such you can't do something like
29 | `model.to(some_device)` with it. To load weights inside your empty model, see [`load_checkpoint_and_dispatch`].
30 |
31 |
32 | """
33 | with init_on_device(torch.device('meta'), include_buffers=include_buffers) as f:
34 | yield f
35 |
36 | @contextmanager
37 | def init_on_device(device: torch.device, include_buffers: bool=False):
38 | """Device initialization context manager.
39 |
40 | A context manager under which models are initialized with all parameters
41 | on the specified device.
42 |
43 | Args:
44 | device (`torch.device`): Device to initialize all parameters on.
45 | include_buffers (`bool`, *optional*, defaults to `False`): Whether or
46 | not to also put all buffers on the meta device while initializing.
47 |
48 | Example:
49 | ```python
50 | import torch.nn as nn
51 |
52 | with init_on_device(device=torch.device("cuda")):
53 |         tst = nn.Linear(100, 100)  # on `cuda` device
54 | ```
55 | """
56 | old_register_parameter = nn.Module.register_parameter
57 | if include_buffers:
58 | old_register_buffer = nn.Module.register_buffer
59 |
60 | def register_empty_parameter(module, name, param):
61 | old_register_parameter(module, name, param)
62 | if param is not None:
63 | param_cls = type(module._parameters[name])
64 | kwargs = module._parameters[name].__dict__
65 | module._parameters[name] = param_cls(module._parameters[name].to(device), **kwargs)
66 |
67 | def register_empty_buffer(module, name, buffer):
68 | old_register_buffer(module, name, buffer)
69 | if buffer is not None:
70 | module._buffers[name] = module._buffers[name].to(device)
71 | if include_buffers:
72 | tensor_constructors_to_patch = {torch_function_name: getattr(torch, torch_function_name) for torch_function_name in ['empty', 'zeros', 'ones', 'full']}
73 | else:
74 | tensor_constructors_to_patch = {}
75 |
76 | def patch_tensor_constructor(fn):
77 |
78 | def wrapper(*args, **kwargs):
79 | kwargs['device'] = device
80 | return fn(*args, **kwargs)
81 | return wrapper
82 | try:
83 | nn.Module.register_parameter = register_empty_parameter
84 | if include_buffers:
85 | nn.Module.register_buffer = register_empty_buffer
86 | for torch_function_name in tensor_constructors_to_patch.keys():
87 | setattr(torch, torch_function_name, patch_tensor_constructor(getattr(torch, torch_function_name)))
88 | yield
89 | finally:
90 | nn.Module.register_parameter = old_register_parameter
91 | if include_buffers:
92 | nn.Module.register_buffer = old_register_buffer
93 | for (torch_function_name, old_torch_function) in tensor_constructors_to_patch.items():
94 | setattr(torch, torch_function_name, old_torch_function)
--------------------------------------------------------------------------------
/llava/model/language_model/mpt/norm.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 | def _cast_if_autocast_enabled(tensor):
4 | if torch.is_autocast_enabled():
5 | if tensor.device.type == 'cuda':
6 | dtype = torch.get_autocast_gpu_dtype()
7 | elif tensor.device.type == 'cpu':
8 | dtype = torch.get_autocast_cpu_dtype()
9 | else:
10 | raise NotImplementedError()
11 | return tensor.to(dtype=dtype)
12 | return tensor
13 |
14 | class LPLayerNorm(torch.nn.LayerNorm):
15 |
16 | def __init__(self, normalized_shape, eps=1e-05, elementwise_affine=True, device=None, dtype=None):
17 | super().__init__(normalized_shape=normalized_shape, eps=eps, elementwise_affine=elementwise_affine, device=device, dtype=dtype)
18 |
19 | def forward(self, x):
20 | module_device = x.device
21 | downcast_x = _cast_if_autocast_enabled(x)
22 | downcast_weight = _cast_if_autocast_enabled(self.weight) if self.weight is not None else self.weight
23 | downcast_bias = _cast_if_autocast_enabled(self.bias) if self.bias is not None else self.bias
24 | with torch.autocast(enabled=False, device_type=module_device.type):
25 | return torch.nn.functional.layer_norm(downcast_x, self.normalized_shape, downcast_weight, downcast_bias, self.eps)
26 |
27 | def rms_norm(x, weight=None, eps=1e-05):
28 | output = x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + eps)
29 | if weight is not None:
30 | return output * weight
31 | return output
32 |
33 | class RMSNorm(torch.nn.Module):
34 |
35 | def __init__(self, normalized_shape, eps=1e-05, weight=True, dtype=None, device=None):
36 | super().__init__()
37 | self.eps = eps
38 | if weight:
39 | self.weight = torch.nn.Parameter(torch.ones(normalized_shape, dtype=dtype, device=device))
40 | else:
41 | self.register_parameter('weight', None)
42 |
43 | def forward(self, x):
44 | return rms_norm(x.float(), self.weight, self.eps).to(dtype=x.dtype)
45 |
46 | class LPRMSNorm(RMSNorm):
47 |
48 | def __init__(self, normalized_shape, eps=1e-05, weight=True, dtype=None, device=None):
49 | super().__init__(normalized_shape=normalized_shape, eps=eps, weight=weight, dtype=dtype, device=device)
50 |
51 | def forward(self, x):
52 | downcast_x = _cast_if_autocast_enabled(x)
53 | downcast_weight = _cast_if_autocast_enabled(self.weight) if self.weight is not None else self.weight
54 | with torch.autocast(enabled=False, device_type=x.device.type):
55 | return rms_norm(downcast_x, downcast_weight, self.eps).to(dtype=x.dtype)
56 | NORM_CLASS_REGISTRY = {'layernorm': torch.nn.LayerNorm, 'low_precision_layernorm': LPLayerNorm, 'rmsnorm': RMSNorm, 'low_precision_rmsnorm': LPRMSNorm}
--------------------------------------------------------------------------------
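A short usage sketch for the registry defined in `norm.py` above (illustration only; the hidden size and batch shape are arbitrary):

```python
import torch
from llava.model.language_model.mpt.norm import NORM_CLASS_REGISTRY

# Pick a norm class by its config key and run it over a batch of hidden states.
norm_cls = NORM_CLASS_REGISTRY['rmsnorm']        # -> RMSNorm
norm = norm_cls(normalized_shape=512)

x = torch.randn(2, 16, 512)
y = norm(x)                                      # rms_norm computed in fp32, cast back to x.dtype
print(y.shape)                                   # torch.Size([2, 16, 512])
```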
/llava/model/make_delta.py:
--------------------------------------------------------------------------------
1 | """
2 | Usage:
3 | python3 -m llava.model.make_delta --base ~/model_weights/llama-7b --target ~/model_weights/llava-7b --delta ~/model_weights/llava-7b-delta --hub-repo-id liuhaotian/llava-7b-delta
4 | """
5 | import argparse
6 |
7 | import torch
8 | from tqdm import tqdm
9 | from transformers import AutoTokenizer, AutoModelForCausalLM
10 | from llava.model.utils import auto_upgrade
11 |
12 |
13 | def make_delta(base_model_path, target_model_path, delta_path, hub_repo_id):
14 | print("Loading base model")
15 | base = AutoModelForCausalLM.from_pretrained(
16 | base_model_path, torch_dtype=torch.float16, low_cpu_mem_usage=True)
17 |
18 | print("Loading target model")
19 | auto_upgrade(target_model_path)
20 | target = AutoModelForCausalLM.from_pretrained(target_model_path, torch_dtype=torch.float16, low_cpu_mem_usage=True)
21 |
22 | print("Calculating delta")
23 | for name, param in tqdm(target.state_dict().items(), desc="Calculating delta"):
24 | if name not in base.state_dict():
25 | assert name in ['model.mm_projector.weight', 'model.mm_projector.bias'], f'{name} not in base model'
26 | continue
27 | if param.data.shape == base.state_dict()[name].shape:
28 | param.data -= base.state_dict()[name]
29 | else:
30 | assert name in ['model.embed_tokens.weight', 'lm_head.weight'], f'{name} dimension mismatch: {param.data.shape} vs {base.state_dict()[name].shape}'
31 | bparam = base.state_dict()[name]
32 | param.data[:bparam.shape[0], :bparam.shape[1]] -= bparam
33 |
34 | print("Saving delta")
35 | if hub_repo_id:
36 | kwargs = {"push_to_hub": True, "repo_id": hub_repo_id}
37 | else:
38 | kwargs = {}
39 | target.save_pretrained(delta_path, **kwargs)
40 | target_tokenizer = AutoTokenizer.from_pretrained(target_model_path)
41 | target_tokenizer.save_pretrained(delta_path, **kwargs)
42 |
43 |
44 | if __name__ == "__main__":
45 | parser = argparse.ArgumentParser()
46 | parser.add_argument("--base-model-path", type=str, required=True)
47 | parser.add_argument("--target-model-path", type=str, required=True)
48 | parser.add_argument("--delta-path", type=str, required=True)
49 | parser.add_argument("--hub-repo-id", type=str, default=None)
50 | args = parser.parse_args()
51 |
52 | make_delta(args.base_model_path, args.target_model_path, args.delta_path, args.hub_repo_id)
53 |
--------------------------------------------------------------------------------
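To clarify the shape-mismatch branch in `make_delta` above: when the target vocabulary is larger than the base one (e.g. `model.embed_tokens.weight`), only the overlapping slice becomes a delta and the extra rows keep their raw target values. A self-contained illustration with toy tensors (not repository code):

```python
import torch

base_embed = torch.randn(32000, 8)
target_embed = torch.randn(32003, 8)             # e.g. a few extra special tokens appended

# Delta: subtract the base from the overlapping slice only (mirrors `param.data[:h, :w] -= bparam`).
delta = target_embed.clone()
delta[:base_embed.shape[0], :base_embed.shape[1]] -= base_embed

# Reconstruction: add the base back onto the same slice to recover the target weights.
restored = delta.clone()
restored[:base_embed.shape[0], :base_embed.shape[1]] += base_embed
assert torch.allclose(restored, target_embed)
```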
/llava/model/multimodal_encoder/builder.py:
--------------------------------------------------------------------------------
1 | import os
2 | from .clip_encoder import CLIPVisionTower
3 |
4 |
5 | def build_vision_tower(vision_tower_cfg, **kwargs):
6 | vision_tower = getattr(vision_tower_cfg, 'mm_vision_tower', getattr(vision_tower_cfg, 'vision_tower', None))
7 | is_absolute_path_exists = os.path.exists(vision_tower)
8 | if is_absolute_path_exists or vision_tower.startswith("openai") or vision_tower.startswith("laion"):
9 | return CLIPVisionTower(vision_tower, args=vision_tower_cfg, **kwargs)
10 |
11 | raise ValueError(f'Unknown vision tower: {vision_tower}')
12 |
--------------------------------------------------------------------------------
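A sketch of calling `build_vision_tower` above. The config attributes are the ones the builder and `CLIPVisionTower` read; the values are illustrative, not repository defaults:

```python
from types import SimpleNamespace
from llava.model.multimodal_encoder.builder import build_vision_tower

cfg = SimpleNamespace(
    mm_vision_tower="openai/clip-vit-large-patch14",  # an "openai"-prefixed name is accepted
    mm_vision_select_layer=-2,                        # hidden_states index used by feature_select
    mm_vision_select_feature="patch",                 # drop the CLS token, keep patch tokens
)

tower = build_vision_tower(cfg, delay_load=True)      # delay_load only fetches the CLIP config
print(tower.hidden_size, tower.num_patches)
```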
/llava/model/multimodal_encoder/clip_encoder.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 |
4 | from transformers import CLIPVisionModel, CLIPImageProcessor, CLIPVisionConfig
5 |
6 |
7 | class CLIPVisionTower(nn.Module):
8 | def __init__(self, vision_tower, args, delay_load=False):
9 | super().__init__()
10 |
11 | self.is_loaded = False
12 |
13 | self.vision_tower_name = vision_tower
14 | self.select_layer = args.mm_vision_select_layer
15 | self.select_feature = getattr(args, 'mm_vision_select_feature', 'patch')
16 |
17 | if not delay_load:
18 | self.load_model()
19 | else:
20 | self.cfg_only = CLIPVisionConfig.from_pretrained(self.vision_tower_name)
21 |
22 | def load_model(self):
23 | self.image_processor = CLIPImageProcessor.from_pretrained(self.vision_tower_name)
24 | self.vision_tower = CLIPVisionModel.from_pretrained(self.vision_tower_name)
25 | self.vision_tower.requires_grad_(False)
26 |
27 | self.is_loaded = True
28 |
29 | def feature_select(self, image_forward_outs):
30 | image_features = image_forward_outs.hidden_states[self.select_layer]
31 | if self.select_feature == 'patch':
32 | image_features = image_features[:, 1:]
33 | elif self.select_feature == 'cls_patch':
34 | image_features = image_features
35 | else:
36 | raise ValueError(f'Unexpected select feature: {self.select_feature}')
37 | return image_features
38 |
39 | @torch.no_grad()
40 | def forward(self, images):
41 | if type(images) is list:
42 | image_features = []
43 | for image in images:
44 | image_forward_out = self.vision_tower(image.to(device=self.device, dtype=self.dtype).unsqueeze(0), output_hidden_states=True)
45 | image_feature = self.feature_select(image_forward_out).to(image.dtype)
46 | image_features.append(image_feature)
47 | else:
48 | image_forward_outs = self.vision_tower(images.to(device=self.device, dtype=self.dtype), output_hidden_states=True)
49 | image_features = self.feature_select(image_forward_outs).to(images.dtype)
50 |
51 | return image_features
52 |
53 | @property
54 | def dummy_feature(self):
55 | return torch.zeros(1, self.hidden_size, device=self.device, dtype=self.dtype)
56 |
57 | @property
58 | def dtype(self):
59 | return self.vision_tower.dtype
60 |
61 | @property
62 | def device(self):
63 | return self.vision_tower.device
64 |
65 | @property
66 | def config(self):
67 | if self.is_loaded:
68 | return self.vision_tower.config
69 | else:
70 | return self.cfg_only
71 |
72 | @property
73 | def hidden_size(self):
74 | return self.config.hidden_size
75 |
76 | @property
77 | def num_patches(self):
78 | return (self.config.image_size // self.config.patch_size) ** 2
79 |
--------------------------------------------------------------------------------
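A usage sketch for `CLIPVisionTower` above, extracting patch features for a single placeholder image (illustrative values; loading the model downloads CLIP weights from the Hugging Face hub):

```python
from types import SimpleNamespace
from PIL import Image
from llava.model.multimodal_encoder.clip_encoder import CLIPVisionTower

args = SimpleNamespace(mm_vision_select_layer=-2, mm_vision_select_feature="patch")
tower = CLIPVisionTower("openai/clip-vit-large-patch14", args=args)   # loads eagerly

image = Image.new("RGB", (336, 336))                                  # placeholder black image
pixel_values = tower.image_processor(image, return_tensors="pt")["pixel_values"]
features = tower(pixel_values)                                        # (1, num_patches, hidden_size)
print(features.shape, tower.num_patches)
```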
/llava/model/multimodal_projector/builder.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import re
4 |
5 |
6 | class IdentityMap(nn.Module):
7 | def __init__(self):
8 | super().__init__()
9 |
10 | def forward(self, x, *args, **kwargs):
11 | return x
12 |
13 | @property
14 | def config(self):
15 | return {"mm_projector_type": 'identity'}
16 |
17 |
18 | class SimpleResBlock(nn.Module):
19 | def __init__(self, channels):
20 | super().__init__()
21 | self.pre_norm = nn.LayerNorm(channels)
22 |
23 | self.proj = nn.Sequential(
24 | nn.Linear(channels, channels),
25 | nn.GELU(),
26 | nn.Linear(channels, channels)
27 | )
28 | def forward(self, x):
29 | x = self.pre_norm(x)
30 | return x + self.proj(x)
31 |
32 |
33 | def build_vision_projector(config, delay_load=False, **kwargs):
34 | projector_type = getattr(config, 'mm_projector_type', 'linear')
35 |
36 | if projector_type == 'linear':
37 | return nn.Linear(config.mm_hidden_size, config.hidden_size)
38 |
39 | mlp_gelu_match = re.match(r'^mlp(\d+)x_gelu$', projector_type)
40 | if mlp_gelu_match:
41 | mlp_depth = int(mlp_gelu_match.group(1))
42 | modules = [nn.Linear(config.mm_hidden_size, config.hidden_size)]
43 | for _ in range(1, mlp_depth):
44 | modules.append(nn.GELU())
45 | modules.append(nn.Linear(config.hidden_size, config.hidden_size))
46 | return nn.Sequential(*modules)
47 |
48 | if projector_type == 'identity':
49 | return IdentityMap()
50 |
51 | raise ValueError(f'Unknown projector type: {projector_type}')
52 |
--------------------------------------------------------------------------------
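A usage sketch for `build_vision_projector` above, with illustrative sizes. The key `"mlp2x_gelu"` matches the regex and expands to Linear → GELU → Linear:

```python
import torch
from types import SimpleNamespace
from llava.model.multimodal_projector.builder import build_vision_projector

cfg = SimpleNamespace(mm_projector_type="mlp2x_gelu", mm_hidden_size=1024, hidden_size=4096)
projector = build_vision_projector(cfg)               # Sequential(Linear, GELU, Linear)

image_features = torch.randn(1, 256, cfg.mm_hidden_size)
print(projector(image_features).shape)                # torch.Size([1, 256, 4096])
```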
/llava/model/utils.py:
--------------------------------------------------------------------------------
1 | from transformers import AutoConfig
2 |
3 |
4 | def auto_upgrade(config):
5 | cfg = AutoConfig.from_pretrained(config)
6 | if 'llava' in config and 'llava' not in cfg.model_type:
7 | assert cfg.model_type == 'llama'
8 | print("You are using newer LLaVA code base, while the checkpoint of v0 is from older code base.")
9 | print("You must upgrade the checkpoint to the new code base (this can be done automatically).")
10 | confirm = input("Please confirm that you want to upgrade the checkpoint. [Y/N]")
11 | if confirm.lower() in ["y", "yes"]:
12 | print("Upgrading checkpoint...")
13 | assert len(cfg.architectures) == 1
14 | setattr(cfg.__class__, "model_type", "llava")
15 | cfg.architectures[0] = 'LlavaLlamaForCausalLM'
16 | cfg.save_pretrained(config)
17 | print("Checkpoint upgraded.")
18 | else:
19 | print("Checkpoint upgrade aborted.")
20 | exit(1)
21 |
--------------------------------------------------------------------------------
/llava/serve/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/prometheus-eval/prometheus-vision/fe668b5e7e120e5bea8a3821eeb468ddf719c23c/llava/serve/__init__.py
--------------------------------------------------------------------------------
/llava/serve/cli.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import torch
3 |
4 | from llava.constants import IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN
5 | from llava.conversation import conv_templates, SeparatorStyle
6 | from llava.model.builder import load_pretrained_model
7 | from llava.utils import disable_torch_init
8 | from llava.mm_utils import process_images, tokenizer_image_token, get_model_name_from_path, KeywordsStoppingCriteria
9 |
10 | from PIL import Image
11 |
12 | import requests
13 | from PIL import Image
14 | from io import BytesIO
15 | from transformers import TextStreamer
16 |
17 |
18 | def load_image(image_file):
19 | if image_file.startswith('http://') or image_file.startswith('https://'):
20 | response = requests.get(image_file)
21 | image = Image.open(BytesIO(response.content)).convert('RGB')
22 | else:
23 | image = Image.open(image_file).convert('RGB')
24 | return image
25 |
26 |
27 | def main(args):
28 | # Model
29 | disable_torch_init()
30 |
31 | model_name = get_model_name_from_path(args.model_path)
32 | tokenizer, model, image_processor, context_len = load_pretrained_model(args.model_path, args.model_base, model_name, args.load_8bit, args.load_4bit, device=args.device)
33 |
34 | if 'llama-2' in model_name.lower():
35 | conv_mode = "llava_llama_2"
36 | elif "v1" in model_name.lower():
37 | conv_mode = "llava_v1"
38 | elif "mpt" in model_name.lower():
39 | conv_mode = "mpt"
40 | else:
41 | conv_mode = "llava_v0"
42 |
43 | if args.conv_mode is not None and conv_mode != args.conv_mode:
44 | print('[WARNING] the auto inferred conversation mode is {}, while `--conv-mode` is {}, using {}'.format(conv_mode, args.conv_mode, args.conv_mode))
45 | else:
46 | args.conv_mode = conv_mode
47 |
48 | conv = conv_templates[args.conv_mode].copy()
49 | if "mpt" in model_name.lower():
50 | roles = ('user', 'assistant')
51 | else:
52 | roles = conv.roles
53 |
54 | image = load_image(args.image_file)
55 | # Similar operation in model_worker.py
56 | image_tensor = process_images([image], image_processor, model.config)
57 | if type(image_tensor) is list:
58 | image_tensor = [image.to(model.device, dtype=torch.float16) for image in image_tensor]
59 | else:
60 | image_tensor = image_tensor.to(model.device, dtype=torch.float16)
61 |
62 | while True:
63 | try:
64 | inp = input(f"{roles[0]}: ")
65 | except EOFError:
66 | inp = ""
67 | if not inp:
68 | print("exit...")
69 | break
70 |
71 | print(f"{roles[1]}: ", end="")
72 |
73 | if image is not None:
74 | # first message
75 | if model.config.mm_use_im_start_end:
76 | inp = DEFAULT_IM_START_TOKEN + DEFAULT_IMAGE_TOKEN + DEFAULT_IM_END_TOKEN + '\n' + inp
77 | else:
78 | inp = DEFAULT_IMAGE_TOKEN + '\n' + inp
79 | conv.append_message(conv.roles[0], inp)
80 | image = None
81 | else:
82 | # later messages
83 | conv.append_message(conv.roles[0], inp)
84 | conv.append_message(conv.roles[1], None)
85 | prompt = conv.get_prompt()
86 |
87 | input_ids = tokenizer_image_token(prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors='pt').unsqueeze(0).to(model.device)
88 | stop_str = conv.sep if conv.sep_style != SeparatorStyle.TWO else conv.sep2
89 | keywords = [stop_str]
90 | stopping_criteria = KeywordsStoppingCriteria(keywords, tokenizer, input_ids)
91 | streamer = TextStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
92 |
93 | with torch.inference_mode():
94 | output_ids = model.generate(
95 | input_ids,
96 | images=image_tensor,
97 | do_sample=True if args.temperature > 0 else False,
98 | temperature=args.temperature,
99 | max_new_tokens=args.max_new_tokens,
100 | streamer=streamer,
101 | use_cache=True,
102 | stopping_criteria=[stopping_criteria])
103 |
104 | outputs = tokenizer.decode(output_ids[0, input_ids.shape[1]:]).strip()
105 | conv.messages[-1][-1] = outputs
106 |
107 | if args.debug:
108 | print("\n", {"prompt": prompt, "outputs": outputs}, "\n")
109 |
110 |
111 | if __name__ == "__main__":
112 | parser = argparse.ArgumentParser()
113 | parser.add_argument("--model-path", type=str, default="facebook/opt-350m")
114 | parser.add_argument("--model-base", type=str, default=None)
115 | parser.add_argument("--image-file", type=str, required=True)
116 | parser.add_argument("--device", type=str, default="cuda")
117 | parser.add_argument("--conv-mode", type=str, default=None)
118 | parser.add_argument("--temperature", type=float, default=0.2)
119 | parser.add_argument("--max-new-tokens", type=int, default=512)
120 | parser.add_argument("--load-8bit", action="store_true")
121 | parser.add_argument("--load-4bit", action="store_true")
122 | parser.add_argument("--debug", action="store_true")
123 | args = parser.parse_args()
124 | main(args)
125 |
--------------------------------------------------------------------------------
/llava/serve/examples/extreme_ironing.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/prometheus-eval/prometheus-vision/fe668b5e7e120e5bea8a3821eeb468ddf719c23c/llava/serve/examples/extreme_ironing.jpg
--------------------------------------------------------------------------------
/llava/serve/examples/waterview.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/prometheus-eval/prometheus-vision/fe668b5e7e120e5bea8a3821eeb468ddf719c23c/llava/serve/examples/waterview.jpg
--------------------------------------------------------------------------------
/llava/serve/register_worker.py:
--------------------------------------------------------------------------------
1 | """
2 | Manually register workers.
3 |
4 | Usage:
5 | python3 -m llava.serve.register_worker --controller-address http://localhost:21001 --worker-name http://localhost:21002
6 | """
7 |
8 | import argparse
9 |
10 | import requests
11 |
12 | if __name__ == "__main__":
13 | parser = argparse.ArgumentParser()
14 | parser.add_argument("--controller-address", type=str)
15 | parser.add_argument("--worker-name", type=str)
16 | parser.add_argument("--check-heart-beat", action="store_true")
17 | args = parser.parse_args()
18 |
19 | url = args.controller_address + "/register_worker"
20 | data = {
21 | "worker_name": args.worker_name,
22 | "check_heart_beat": args.check_heart_beat,
23 | "worker_status": None,
24 | }
25 | r = requests.post(url, json=data)
26 | assert r.status_code == 200
27 |
--------------------------------------------------------------------------------
/llava/serve/test_message.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import json
3 |
4 | import requests
5 |
6 | from llava.conversation import default_conversation
7 |
8 |
9 | def main():
10 | if args.worker_address:
11 | worker_addr = args.worker_address
12 | else:
13 | controller_addr = args.controller_address
14 | ret = requests.post(controller_addr + "/refresh_all_workers")
15 | ret = requests.post(controller_addr + "/list_models")
16 | models = ret.json()["models"]
17 | models.sort()
18 | print(f"Models: {models}")
19 |
20 | ret = requests.post(controller_addr + "/get_worker_address",
21 | json={"model": args.model_name})
22 | worker_addr = ret.json()["address"]
23 | print(f"worker_addr: {worker_addr}")
24 |
25 | if worker_addr == "":
26 | return
27 |
28 | conv = default_conversation.copy()
29 | conv.append_message(conv.roles[0], args.message)
30 | prompt = conv.get_prompt()
31 |
32 | headers = {"User-Agent": "LLaVA Client"}
33 | pload = {
34 | "model": args.model_name,
35 | "prompt": prompt,
36 | "max_new_tokens": args.max_new_tokens,
37 | "temperature": 0.7,
38 | "stop": conv.sep,
39 | }
40 | response = requests.post(worker_addr + "/worker_generate_stream", headers=headers,
41 | json=pload, stream=True)
42 |
43 | print(prompt.replace(conv.sep, "\n"), end="")
44 | for chunk in response.iter_lines(chunk_size=8192, decode_unicode=False, delimiter=b"\0"):
45 | if chunk:
46 | data = json.loads(chunk.decode("utf-8"))
47 | output = data["text"].split(conv.sep)[-1]
48 | print(output, end="\r")
49 | print("")
50 |
51 |
52 | if __name__ == "__main__":
53 | parser = argparse.ArgumentParser()
54 | parser.add_argument("--controller-address", type=str, default="http://localhost:21001")
55 | parser.add_argument("--worker-address", type=str)
56 | parser.add_argument("--model-name", type=str, default="facebook/opt-350m")
57 | parser.add_argument("--max-new-tokens", type=int, default=32)
58 | parser.add_argument("--message", type=str, default=
59 | "Tell me a story with more than 1000 words.")
60 | args = parser.parse_args()
61 |
62 | main()
63 |
--------------------------------------------------------------------------------
/llava/train/llama_flash_attn_monkey_patch.py:
--------------------------------------------------------------------------------
1 | from typing import Optional, Tuple
2 | import warnings
3 |
4 | import torch
5 |
6 | import transformers
7 | from transformers.models.llama.modeling_llama import apply_rotary_pos_emb, repeat_kv
8 |
9 | try:
10 | from flash_attn.flash_attn_interface import flash_attn_unpadded_qkvpacked_func
11 | except ImportError:
12 | from flash_attn.flash_attn_interface import flash_attn_varlen_qkvpacked_func as flash_attn_unpadded_qkvpacked_func
13 | from flash_attn.bert_padding import unpad_input, pad_input
14 |
15 |
16 | def forward(
17 | self,
18 | hidden_states: torch.Tensor,
19 | attention_mask: Optional[torch.Tensor] = None,
20 | position_ids: Optional[torch.Tensor] = None,
21 | past_key_value: Optional[Tuple[torch.Tensor]] = None,
22 | output_attentions: bool = False,
23 | use_cache: bool = False,
24 | ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
25 | if output_attentions:
26 | warnings.warn(
27 | "Output attentions is not supported for patched `LlamaAttention`, returning `None` instead."
28 | )
29 |
30 | bsz, q_len, _ = hidden_states.size()
31 |
32 | query_states = (
33 | self.q_proj(hidden_states)
34 | .view(bsz, q_len, self.num_heads, self.head_dim)
35 | .transpose(1, 2)
36 | )
37 | key_states = (
38 | self.k_proj(hidden_states)
39 | .view(bsz, q_len, self.num_key_value_heads, self.head_dim)
40 | .transpose(1, 2)
41 | )
42 | value_states = (
43 | self.v_proj(hidden_states)
44 | .view(bsz, q_len, self.num_key_value_heads, self.head_dim)
45 | .transpose(1, 2)
46 | ) # shape: (b, num_heads, s, head_dim)
47 |
48 | kv_seq_len = key_states.shape[-2]
49 | if past_key_value is not None:
50 | kv_seq_len += past_key_value[0].shape[-2]
51 |
52 | cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
53 | query_states, key_states = apply_rotary_pos_emb(
54 | query_states, key_states, cos, sin, position_ids
55 | )
56 |
57 | if past_key_value is not None:
58 | # reuse k, v
59 | key_states = torch.cat([past_key_value[0], key_states], dim=2)
60 | value_states = torch.cat([past_key_value[1], value_states], dim=2)
61 |
62 | past_key_value = (key_states, value_states) if use_cache else None
63 |
64 | # repeat k/v heads if n_kv_heads < n_heads
65 | key_states = repeat_kv(key_states, self.num_key_value_groups)
66 | value_states = repeat_kv(value_states, self.num_key_value_groups)
67 |
68 | # Transform the data into the format required by flash attention
69 | qkv = torch.stack([query_states, key_states, value_states], dim=2)
70 | qkv = qkv.transpose(1, 3) # shape: [b, s, 3, num_heads, head_dim]
71 | key_padding_mask = attention_mask
72 |
73 | if key_padding_mask is None:
74 | qkv = qkv.reshape(-1, 3, self.num_heads, self.head_dim)
75 | cu_q_lens = torch.arange(
76 | 0, (bsz + 1) * q_len, step=q_len, dtype=torch.int32, device=qkv.device
77 | )
78 | max_s = q_len
79 | output = flash_attn_unpadded_qkvpacked_func(
80 | qkv, cu_q_lens, max_s, 0.0, softmax_scale=None, causal=True
81 | )
82 | output = output.view(bsz, q_len, -1)
83 | else:
84 | qkv = qkv.reshape(bsz, q_len, -1)
85 | qkv, indices, cu_q_lens, max_s = unpad_input(qkv, key_padding_mask)
86 | qkv = qkv.view(-1, 3, self.num_heads, self.head_dim)
87 | output_unpad = flash_attn_unpadded_qkvpacked_func(
88 | qkv, cu_q_lens, max_s, 0.0, softmax_scale=None, causal=True
89 | )
90 | output_unpad = output_unpad.reshape(-1, self.num_heads * self.head_dim)
91 | output = pad_input(output_unpad, indices, bsz, q_len)
92 |
93 | return self.o_proj(output), None, past_key_value
94 |
95 |
96 | # Disable the transformation of the attention mask in LlamaModel as the flash attention
97 | # requires the attention mask to be the same as the key_padding_mask
98 | def _prepare_decoder_attention_mask(
99 | self, attention_mask, input_shape, inputs_embeds, past_key_values_length
100 | ):
101 | # [bsz, seq_len]
102 | return attention_mask
103 |
104 |
105 | def replace_llama_attn_with_flash_attn():
106 | cuda_major, cuda_minor = torch.cuda.get_device_capability()
107 | if cuda_major < 8:
108 | warnings.warn(
109 | "Flash attention is only supported on A100 or H100 GPU during training due to head dim > 64 backward."
110 | "ref: https://github.com/HazyResearch/flash-attention/issues/190#issuecomment-1523359593"
111 | )
112 | transformers.models.llama.modeling_llama.LlamaModel._prepare_decoder_attention_mask = (
113 | _prepare_decoder_attention_mask
114 | )
115 | transformers.models.llama.modeling_llama.LlamaAttention.forward = forward
116 |
--------------------------------------------------------------------------------
/llava/train/llama_xformers_attn_monkey_patch.py:
--------------------------------------------------------------------------------
1 | """
2 | Directly copied the code from https://raw.githubusercontent.com/oobabooga/text-generation-webui/main/modules/llama_attn_hijack.py and made some adjustments
3 | """
4 |
5 | import logging
6 | import math
7 | from typing import Optional, Tuple
8 |
9 | import torch
10 | import transformers.models.llama.modeling_llama
11 | from torch import nn
12 |
13 | try:
14 | import xformers.ops
15 | except ImportError:
16 | logging.error("xformers not found! Please install it before trying to use it.")
17 |
18 |
19 | def replace_llama_attn_with_xformers_attn():
20 | transformers.models.llama.modeling_llama.LlamaAttention.forward = xformers_forward
21 |
22 |
23 | def xformers_forward(
24 | self,
25 | hidden_states: torch.Tensor,
26 | attention_mask: Optional[torch.Tensor] = None,
27 | position_ids: Optional[torch.LongTensor] = None,
28 | past_key_value: Optional[Tuple[torch.Tensor]] = None,
29 | output_attentions: bool = False,
30 | use_cache: bool = False,
31 | ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
32 | # pylint: disable=duplicate-code
33 | bsz, q_len, _ = hidden_states.size()
34 |
35 | query_states = (
36 | self.q_proj(hidden_states)
37 | .view(bsz, q_len, self.num_heads, self.head_dim)
38 | .transpose(1, 2)
39 | )
40 | key_states = (
41 | self.k_proj(hidden_states)
42 | .view(bsz, q_len, self.num_heads, self.head_dim)
43 | .transpose(1, 2)
44 | )
45 | value_states = (
46 | self.v_proj(hidden_states)
47 | .view(bsz, q_len, self.num_heads, self.head_dim)
48 | .transpose(1, 2)
49 | )
50 |
51 | kv_seq_len = key_states.shape[-2]
52 | if past_key_value is not None:
53 | kv_seq_len += past_key_value[0].shape[-2]
54 | cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
55 | (
56 | query_states,
57 | key_states,
58 | ) = transformers.models.llama.modeling_llama.apply_rotary_pos_emb(
59 | query_states, key_states, cos, sin, position_ids
60 | )
61 | # [bsz, nh, t, hd]
62 |
63 | if past_key_value is not None:
64 | # reuse k, v, self_attention
65 | key_states = torch.cat([past_key_value[0], key_states], dim=2)
66 | value_states = torch.cat([past_key_value[1], value_states], dim=2)
67 |
68 | past_key_value = (key_states, value_states) if use_cache else None
69 |
70 | # We only apply xformers optimizations if we don't need to output the whole attention matrix
71 | if not output_attentions:
72 | query_states = query_states.transpose(1, 2)
73 | key_states = key_states.transpose(1, 2)
74 | value_states = value_states.transpose(1, 2)
75 |
76 | # This is a nasty hack. We know attention_mask in transformers is either LowerTriangular or all Zeros.
77 | # We therefore check if one element in the upper triangular portion is zero. If it is, then the mask is all zeros.
78 | if attention_mask is None or attention_mask[0, 0, 0, 1] == 0:
79 | # input and output should be of form (bsz, q_len, num_heads, head_dim)
80 | attn_output = xformers.ops.memory_efficient_attention(
81 | query_states, key_states, value_states, attn_bias=None
82 | )
83 | else:
84 | # input and output should be of form (bsz, q_len, num_heads, head_dim)
85 | attn_output = xformers.ops.memory_efficient_attention(
86 | query_states,
87 | key_states,
88 | value_states,
89 | attn_bias=xformers.ops.LowerTriangularMask(),
90 | )
91 | attn_weights = None
92 | else:
93 | attn_weights = torch.matmul(
94 | query_states, key_states.transpose(2, 3)
95 | ) / math.sqrt(self.head_dim)
96 |
97 | if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len):
98 | raise ValueError(
99 | f"Attention weights should be of size {(bsz * self.num_heads, q_len, kv_seq_len)}, but is"
100 | f" {attn_weights.size()}"
101 | )
102 |
103 | if attention_mask is not None:
104 | if attention_mask.size() != (bsz, 1, q_len, kv_seq_len):
105 | raise ValueError(
106 | f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}"
107 | )
108 | attn_weights = attn_weights + attention_mask
109 | attn_weights = torch.max(
110 | attn_weights, torch.tensor(torch.finfo(attn_weights.dtype).min)
111 | )
112 |
113 | # upcast attention to fp32
114 | attn_weights = nn.functional.softmax(
115 | attn_weights, dim=-1, dtype=torch.float32
116 | ).to(query_states.dtype)
117 | attn_output = torch.matmul(attn_weights, value_states)
118 |
119 | if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
120 | raise ValueError(
121 | f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is"
122 | f" {attn_output.size()}"
123 | )
124 |
125 | attn_output = attn_output.transpose(1, 2)
126 |
127 | attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
128 | attn_output = self.o_proj(attn_output)
129 | return attn_output, attn_weights, past_key_value
130 |
--------------------------------------------------------------------------------
/llava/train/train_mem.py:
--------------------------------------------------------------------------------
1 | # Adapted from https://github.com/lm-sys/FastChat. Below is the original copyright:
2 | # Adapted from tatsu-lab@stanford_alpaca. Below is the original copyright:
3 | # Make it more memory efficient by monkey patching the LLaMA model with FlashAttn.
4 |
5 | # Need to call this before importing transformers.
6 | from llava.train.llama_flash_attn_monkey_patch import replace_llama_attn_with_flash_attn
7 |
8 | replace_llama_attn_with_flash_attn()
9 |
10 | from llava.train.train import train
11 |
12 | if __name__ == "__main__":
13 | train()
14 |
--------------------------------------------------------------------------------
/llava/train/train_xformers.py:
--------------------------------------------------------------------------------
1 | # Make it more memory efficient by monkey patching the LLaMA model with xformers attention.
2 |
3 | # Need to call this before importing transformers.
4 | from llava.train.llama_xformers_attn_monkey_patch import (
5 | replace_llama_attn_with_xformers_attn,
6 | )
7 |
8 | replace_llama_attn_with_xformers_attn()
9 |
10 | from llava.train.train import train
11 |
12 | if __name__ == "__main__":
13 | train()
14 |
--------------------------------------------------------------------------------
/llava/utils.py:
--------------------------------------------------------------------------------
1 | import datetime
2 | import logging
3 | import logging.handlers
4 | import os
5 | import sys
6 |
7 | import requests
8 |
9 | from llava.constants import LOGDIR
10 |
11 | server_error_msg = "**NETWORK ERROR DUE TO HIGH TRAFFIC. PLEASE REGENERATE OR REFRESH THIS PAGE.**"
12 | moderation_msg = "YOUR INPUT VIOLATES OUR CONTENT MODERATION GUIDELINES. PLEASE TRY AGAIN."
13 |
14 | handler = None
15 |
16 |
17 | def build_logger(logger_name, logger_filename):
18 | global handler
19 |
20 | formatter = logging.Formatter(
21 | fmt="%(asctime)s | %(levelname)s | %(name)s | %(message)s",
22 | datefmt="%Y-%m-%d %H:%M:%S",
23 | )
24 |
25 | # Set the format of root handlers
26 | if not logging.getLogger().handlers:
27 | logging.basicConfig(level=logging.INFO)
28 | logging.getLogger().handlers[0].setFormatter(formatter)
29 |
30 | # Redirect stdout and stderr to loggers
31 | stdout_logger = logging.getLogger("stdout")
32 | stdout_logger.setLevel(logging.INFO)
33 | sl = StreamToLogger(stdout_logger, logging.INFO)
34 | sys.stdout = sl
35 |
36 | stderr_logger = logging.getLogger("stderr")
37 | stderr_logger.setLevel(logging.ERROR)
38 | sl = StreamToLogger(stderr_logger, logging.ERROR)
39 | sys.stderr = sl
40 |
41 | # Get logger
42 | logger = logging.getLogger(logger_name)
43 | logger.setLevel(logging.INFO)
44 |
45 | # Add a file handler for all loggers
46 | if handler is None:
47 | os.makedirs(LOGDIR, exist_ok=True)
48 | filename = os.path.join(LOGDIR, logger_filename)
49 | handler = logging.handlers.TimedRotatingFileHandler(
50 | filename, when='D', utc=True, encoding='UTF-8')
51 | handler.setFormatter(formatter)
52 |
53 | for name, item in logging.root.manager.loggerDict.items():
54 | if isinstance(item, logging.Logger):
55 | item.addHandler(handler)
56 |
57 | return logger
58 |
59 |
60 | class StreamToLogger(object):
61 | """
62 | Fake file-like stream object that redirects writes to a logger instance.
63 | """
64 | def __init__(self, logger, log_level=logging.INFO):
65 | self.terminal = sys.stdout
66 | self.logger = logger
67 | self.log_level = log_level
68 | self.linebuf = ''
69 |
70 | def __getattr__(self, attr):
71 | return getattr(self.terminal, attr)
72 |
73 | def write(self, buf):
74 | temp_linebuf = self.linebuf + buf
75 | self.linebuf = ''
76 | for line in temp_linebuf.splitlines(True):
77 | # From the io.TextIOWrapper docs:
78 | # On output, if newline is None, any '\n' characters written
79 | # are translated to the system default line separator.
80 | # By default sys.stdout.write() expects '\n' newlines and then
81 | # translates them so this is still cross platform.
82 | if line[-1] == '\n':
83 | self.logger.log(self.log_level, line.rstrip())
84 | else:
85 | self.linebuf += line
86 |
87 | def flush(self):
88 | if self.linebuf != '':
89 | self.logger.log(self.log_level, self.linebuf.rstrip())
90 | self.linebuf = ''
91 |
92 |
93 | def disable_torch_init():
94 | """
95 | Disable the redundant torch default initialization to accelerate model creation.
96 | """
97 | import torch
98 | setattr(torch.nn.Linear, "reset_parameters", lambda self: None)
99 | setattr(torch.nn.LayerNorm, "reset_parameters", lambda self: None)
100 |
101 |
102 | def violates_moderation(text):
103 | """
104 | Check whether the text violates OpenAI moderation API.
105 | """
106 | url = "https://api.openai.com/v1/moderations"
107 | headers = {"Content-Type": "application/json",
108 | "Authorization": "Bearer " + os.environ["OPENAI_API_KEY"]}
109 | text = text.replace("\n", "")
110 | data = "{" + '"input": ' + f'"{text}"' + "}"
111 | data = data.encode("utf-8")
112 | try:
113 | ret = requests.post(url, headers=headers, data=data, timeout=5)
114 | flagged = ret.json()["results"][0]["flagged"]
115 | except requests.exceptions.RequestException as e:
116 | flagged = False
117 | except KeyError as e:
118 | flagged = False
119 |
120 | return flagged
121 |
122 |
123 | def pretty_print_semaphore(semaphore):
124 | if semaphore is None:
125 | return "None"
126 | return f"Semaphore(value={semaphore._value}, locked={semaphore.locked()})"
127 |
--------------------------------------------------------------------------------
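A brief sketch of how the helpers in `llava/utils.py` above are typically used (the logger name and filename are illustrative):

```python
from llava.utils import build_logger, disable_torch_init

# Create a logger whose records also go to a timed rotating file under LOGDIR,
# with stdout/stderr redirected through StreamToLogger as set up in build_logger.
logger = build_logger("demo_worker", "demo_worker.log")
logger.info("worker started")

# Skip the default Linear/LayerNorm parameter init before loading pretrained weights.
disable_torch_init()
```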
/mmmu_download.py:
--------------------------------------------------------------------------------
1 | from datasets import load_dataset
2 | from tqdm import tqdm
3 | import os
4 |
5 | cate = ['Accounting', 'Agriculture', 'Architecture_and_Engineering', 'Art', 'Art_Theory', 'Basic_Medical_Science', 'Biology', 'Chemistry', 'Clinical_Medicine', 'Computer_Science', 'Design', 'Diagnostics_and_Laboratory_Medicine', 'Economics', 'Electronics', 'Energy_and_Power', 'Finance', 'Geography', 'History', 'Literature', 'Manage', 'Marketing', 'Materials', 'Math', 'Mechanical_Engineering', 'Music', 'Pharmacy', 'Physics', 'Psychology', 'Public_Health', 'Sociology']
6 | temp = {  # category -> broad-domain mapping (defined but not used in this script)
7 | "Art": "art_and_design",
8 | "Design": "art_and_design",
9 | "Music": "art_and_design",
10 | "Art_Theory": "art_and_design",
11 | "Accounting": "business",
12 | "Economics": "business",
13 | "Finance": "business",
14 | "Manage": "business",
15 | "Marketing": "business",
16 | "Biology": "science",
17 | "Chemistry": "science",
18 | "Geography": "science",
19 | "Math": "science",
20 | "Physics": "science",
21 | "Basic_Medical_Science": "health_and_medicine",
22 | "Clinical_Medicine": "health_and_medicine",
23 | "Diagnostics_and_Laboratory_Medicine": "health_and_medicine",
24 | "Pharmacy": "health_and_medicine",
25 | "Public_Health": "health_and_medicine",
26 | "History": "humanities_and_social_sci",
27 | "Literature": "humanities_and_social_sci",
28 | "Psychology": "humanities_and_social_sci",
29 | "Sociology": "humanities_and_social_sci",
30 | "Agriculture": "tech_and_engineering",
31 | "Architecture_and_Engineering": "tech_and_engineering",
32 | "Computer_Science": "tech_and_engineering",
33 | "Electronics": "tech_and_engineering",
34 | "Energy_and_Power": "tech_and_engineering",
35 | "Materials": "tech_and_engineering",
36 | "Mechanical_Engineering": "tech_and_engineering"
37 |
38 | }
39 | ids = 0
40 | save_dir = "./mmmu"
41 | if not os.path.exists(save_dir):
42 | os.makedirs(save_dir)
43 | pattern = r"\['(.*?)'\]"  # regex defined but not used in this script
44 | for c in tqdm(cate):
45 | dataset = load_dataset("MMMU/MMMU", c)
46 | splits = ['dev', 'test', 'validation']
47 | for s in splits:
48 | images = dataset[s]['image_1']
49 | for img in images:
50 | path = f"{save_dir}/{ids}.png"
51 | img.save(path)
52 | ids += 1
53 |
--------------------------------------------------------------------------------
/pyproject.toml:
--------------------------------------------------------------------------------
1 | [build-system]
2 | requires = ["setuptools>=61.0"]
3 | build-backend = "setuptools.build_meta"
4 |
5 | [project]
6 | name = "llava"
7 | version = "1.1.3"
8 | description = "Towards GPT-4 like large language and visual assistant."
9 | readme = "README.md"
10 | requires-python = ">=3.8"
11 | classifiers = [
12 | "Programming Language :: Python :: 3",
13 | "License :: OSI Approved :: Apache Software License",
14 | ]
15 | dependencies = [
16 | "torch==2.0.1", "torchvision==0.15.2",
17 | "transformers==4.31.0", "tokenizers>=0.12.1,<0.14", "sentencepiece==0.1.99", "shortuuid",
18 | "accelerate==0.21.0", "peft==0.4.0", "bitsandbytes==0.41.0",
19 | "pydantic<2,>=1", "markdown2[all]", "numpy", "scikit-learn==1.2.2",
20 | "gradio==3.35.2", "gradio_client==0.2.9",
21 | "requests", "httpx==0.24.0", "uvicorn", "fastapi",
22 | "einops==0.6.1", "einops-exts==0.0.4", "timm==0.6.13",
23 | ]
24 |
25 | [project.optional-dependencies]
26 | train = ["deepspeed==0.9.5", "ninja", "wandb"]
27 |
28 | [project.urls]
29 | "Homepage" = "https://llava-vl.github.io"
30 | "Bug Tracker" = "https://github.com/haotian-liu/LLaVA/issues"
31 |
32 | [tool.setuptools.packages.find]
33 | exclude = ["assets*", "benchmark*", "docs", "dist*", "playground*", "scripts*", "tests*"]
34 |
35 | [tool.wheel]
36 | exclude = ["assets*", "benchmark*", "docs", "dist*", "playground*", "scripts*", "tests*"]
37 |
--------------------------------------------------------------------------------