├── .DS_Store
├── .gitignore
├── LICENSE
├── README.md
├── assets
│   ├── .DS_Store
│   ├── audios
│   │   └── bird_audio.wav
│   └── images
│       ├── CT.png
│       ├── bird_image.jpg
│       └── waterview.jpg
├── header.py
├── llava
│   ├── .DS_Store
│   ├── __init__.py
│   ├── constants.py
│   ├── conversation.py
│   ├── data
│   │   ├── __init__.py
│   │   ├── alpaca-converter.py
│   │   ├── clean_sharegpt.py
│   │   ├── inspect.py
│   │   ├── optional_clean.py
│   │   ├── pretty_json.py
│   │   └── split_long_conversation.py
│   ├── eval
│   │   ├── .DS_Store
│   │   ├── eval_gpt_review.py
│   │   ├── eval_gpt_review_visual.py
│   │   ├── eval_science_qa.py
│   │   ├── eval_science_qa_gpt4.py
│   │   ├── eval_science_qa_gpt4_requery.py
│   │   ├── generate_webpage_data_from_table.py
│   │   ├── model_qa.py
│   │   ├── model_vqa.py
│   │   ├── model_vqa_science.py
│   │   ├── qa_baseline_gpt35.py
│   │   ├── run_llava.py
│   │   ├── summarize_gpt_review.py
│   │   ├── table
│   │   │   ├── answer
│   │   │   │   ├── answer_alpaca-13b.jsonl
│   │   │   │   ├── answer_bard.jsonl
│   │   │   │   ├── answer_gpt35.jsonl
│   │   │   │   ├── answer_llama-13b.jsonl
│   │   │   │   └── answer_vicuna-13b.jsonl
│   │   │   ├── caps_boxes_coco2014_val_80.jsonl
│   │   │   ├── model.jsonl
│   │   │   ├── prompt.jsonl
│   │   │   ├── question.jsonl
│   │   │   ├── results
│   │   │   │   └── test_sqa_llava_13b_v0.json
│   │   │   ├── review
│   │   │   │   ├── review_alpaca-13b_vicuna-13b.jsonl
│   │   │   │   ├── review_bard_vicuna-13b.jsonl
│   │   │   │   ├── review_gpt35_vicuna-13b.jsonl
│   │   │   │   └── review_llama-13b_vicuna-13b.jsonl
│   │   │   ├── reviewer.jsonl
│   │   │   └── rule.json
│   │   └── webpage
│   │       ├── figures
│   │       │   ├── alpaca.png
│   │       │   ├── bard.jpg
│   │       │   ├── chatgpt.svg
│   │       │   ├── llama.jpg
│   │       │   ├── swords_FILL0_wght300_GRAD0_opsz48.svg
│   │       │   └── vicuna.jpeg
│   │       ├── index.html
│   │       ├── script.js
│   │       └── styles.css
│   ├── llava_injection.py
│   ├── model
│   │   ├── .DS_Store
│   │   ├── __init__.py
│   │   ├── apply_delta.py
│   │   ├── consolidate.py
│   │   ├── llava.py
│   │   ├── llava_mpt.py
│   │   ├── make_delta.py
│   │   ├── mpt
│   │   │   ├── adapt_tokenizer.py
│   │   │   ├── attention.py
│   │   │   ├── blocks.py
│   │   │   ├── configuration_mpt.py
│   │   │   ├── hf_prefixlm_converter.py
│   │   │   ├── meta_init_context.py
│   │   │   ├── modeling_mpt.py
│   │   │   ├── norm.py
│   │   │   └── param_init_fns.py
│   │   └── utils.py
│   ├── pyproject.toml
│   ├── serve
│   │   ├── .DS_Store
│   │   ├── __init__.py
│   │   ├── cli.py
│   │   ├── controller.py
│   │   ├── examples
│   │   │   ├── CornellTech.png
│   │   │   ├── crying_boy.jpg
│   │   │   ├── extreme_ironing.jpg
│   │   │   ├── london.jpg
│   │   │   ├── math.jpg
│   │   │   ├── math.png
│   │   │   ├── poo.jpg
│   │   │   ├── porsche911.jpeg
│   │   │   ├── porsche911.png
│   │   │   ├── steak.jpg
│   │   │   ├── thief.jpg
│   │   │   └── waterview.jpg
│   │   ├── gateway
│   │   │   ├── README.md
│   │   │   └── nginx.conf
│   │   ├── gradio_css.py
│   │   ├── gradio_patch.py
│   │   ├── gradio_web_server.py
│   │   ├── model_worker.py
│   │   ├── register_worker.py
│   │   └── test_message.py
│   ├── train
│   │   ├── llama_flash_attn_monkey_patch.py
│   │   ├── llava_trainer.py
│   │   ├── train.py
│   │   └── train_mem.py
│   └── utils.py
├── models
│   └── .gitkeep
├── original_images
│   ├── CT.png
│   └── waterview.jpg
├── pandagpt
│   ├── .DS_Store
│   ├── .gitkeep
│   ├── code
│   │   ├── assets
│   │   │   ├── audios
│   │   │   │   ├── bird_audio.wav
│   │   │   │   ├── car_audio.wav
│   │   │   │   └── dog_audio.wav
│   │   │   ├── images
│   │   │   │   ├── bird_image.jpg
│   │   │   │   ├── car_image.jpg
│   │   │   │   └── dog_image.jpg
│   │   │   ├── thermals
│   │   │   │   ├── 190662.jpg
│   │   │   │   └── 210009.jpg
│   │   │   └── videos
│   │   │       ├── a.mp4
│   │   │       └── world.mp4
│   │   ├── config
│   │   │   ├── __init__.py
│   │   │   ├── base.yaml
│   │   │   └── openllama_peft.yaml
│   │   ├── datasets
│   │   │   ├── __init__.py
│   │   │   ├── samplers.py
│   │   │   └── sft_dataset.py
│   │   ├── dsconfig
│   │   │   └── openllama_peft_stage_1.json
│   │   ├── header.py
│   │   ├── model
│   │   │   ├── ImageBind
│   │   │   │   ├── CODE_OF_CONDUCT.md
│   │   │   │   ├── CONTRIBUTING.md
│   │   │   │   ├── LICENSE
│   │   │   │   ├── README.md
│   │   │   │   ├── __init__.py
│   │   │   │   ├── bpe
│   │   │   │   │   └── bpe_simple_vocab_16e6.txt.gz
│   │   │   │   ├── data.py
│   │   │   │   ├── model_card.md
│   │   │   │   ├── models
│   │   │   │   │   ├── __init__.py
│   │   │   │   │   ├── helpers.py
│   │   │   │   │   ├── imagebind_model.py
│   │   │   │   │   ├── multimodal_preprocessors.py
│   │   │   │   │   └── transformer.py
│   │   │   │   └── requirements.txt
│   │   │   ├── __init__.py
│   │   │   ├── agent.py
│   │   │   ├── modeling_llama.py
│   │   │   └── openllama.py
│   │   ├── scripts
│   │   │   └── train.sh
│   │   ├── train_sft.py
│   │   └── web_demo.py
│   ├── data
│   │   └── empty.txt
│   ├── pandagpt_injection.py
│   ├── pretrained_ckpt
│   │   ├── README.md
│   │   ├── imagebind_ckpt
│   │   │   └── empty.txt
│   │   ├── pandagpt_ckpt
│   │   │   ├── 13b
│   │   │   │   └── empty.txt
│   │   │   └── 7b
│   │   │       └── empty.txt
│   │   └── vicuna_ckpt
│   │       ├── 13b_v0
│   │       │   └── empty.txt
│   │       └── 7b_v0
│   │           └── empty.txt
│   └── requirements.txt
├── result_audios
│   ├── .gitkeep
│   ├── bird_malicious.pt
│   ├── bird_malicious.wav
│   ├── panda-italy-baseline.png
│   └── panda-italy.png
├── result_images
│   ├── .DS_Store
│   ├── llava-baby-baseline.png
│   ├── llava-crying-baby.png
│   ├── llava-pirate.png
│   ├── llava-potter.png
│   ├── llava
│   │   ├── harrypotter_partial.png
│   │   ├── harrypotter_partial.pt
│   │   ├── perturb_full_X.jpg
│   │   ├── perturb_full_X.pt
│   │   ├── perturb_partial_X.jpg
│   │   └── perturb_partial_X.pt
│   ├── panda-audio-phishing.png
│   └── pandagpt
│       ├── panda_cow_full.jpg
│       ├── panda_cow_full.pt
│       ├── panda_cow_partial.jpg
│       └── panda_cow_partial.pt
├── run_llava_injection.ipynb
└── run_pandagpt_injection.ipynb
/.DS_Store:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ebagdasa/multimodal_injection/be8f118c06d7783e4f458046947d1271e7b435d6/.DS_Store
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 |
6 | # C extensions
7 | *.so
8 |
9 | # Distribution / packaging
10 | .Python
11 | build/
12 | develop-eggs/
13 | dist/
14 | downloads/
15 | eggs/
16 | .eggs/
17 | lib/
18 | lib64/
19 | parts/
20 | sdist/
21 | var/
22 | wheels/
23 | share/python-wheels/
24 | *.egg-info/
25 | .installed.cfg
26 | *.egg
27 | MANIFEST
28 |
29 | # PyInstaller
30 | # Usually these files are written by a python script from a template
31 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
32 | *.manifest
33 | *.spec
34 |
35 | # Installer logs
36 | pip-log.txt
37 | pip-delete-this-directory.txt
38 |
39 | # Unit test / coverage reports
40 | htmlcov/
41 | .tox/
42 | .nox/
43 | .coverage
44 | .coverage.*
45 | .cache
46 | nosetests.xml
47 | coverage.xml
48 | *.cover
49 | *.py,cover
50 | .hypothesis/
51 | .pytest_cache/
52 | cover/
53 |
54 | # Translations
55 | *.mo
56 | *.pot
57 |
58 | # Django stuff:
59 | *.log
60 | local_settings.py
61 | db.sqlite3
62 | db.sqlite3-journal
63 |
64 | # Flask stuff:
65 | instance/
66 | .webassets-cache
67 |
68 | # Scrapy stuff:
69 | .scrapy
70 |
71 | # Sphinx documentation
72 | docs/_build/
73 |
74 | # PyBuilder
75 | .pybuilder/
76 | target/
77 |
78 | # Jupyter Notebook
79 | .ipynb_checkpoints
80 |
81 | # IPython
82 | profile_default/
83 | ipython_config.py
84 |
85 | # pyenv
86 | # For a library or package, you might want to ignore these files since the code is
87 | # intended to run in multiple environments; otherwise, check them in:
88 | # .python-version
89 |
90 | # pipenv
91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
94 | # install all needed dependencies.
95 | #Pipfile.lock
96 |
97 | # poetry
98 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
99 | # This is especially recommended for binary packages to ensure reproducibility, and is more
100 | # commonly ignored for libraries.
101 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
102 | #poetry.lock
103 |
104 | # pdm
105 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
106 | #pdm.lock
107 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
108 | # in version control.
109 | # https://pdm.fming.dev/#use-with-ide
110 | .pdm.toml
111 |
112 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
113 | __pypackages__/
114 |
115 | # Celery stuff
116 | celerybeat-schedule
117 | celerybeat.pid
118 |
119 | # SageMath parsed files
120 | *.sage.py
121 |
122 | # Environments
123 | .env
124 | .venv
125 | env/
126 | venv/
127 | ENV/
128 | env.bak/
129 | venv.bak/
130 |
131 | # Spyder project settings
132 | .spyderproject
133 | .spyproject
134 |
135 | # Rope project settings
136 | .ropeproject
137 |
138 | # mkdocs documentation
139 | /site
140 |
141 | # mypy
142 | .mypy_cache/
143 | .dmypy.json
144 | dmypy.json
145 |
146 | # Pyre type checker
147 | .pyre/
148 |
149 | # pytype static type analyzer
150 | .pytype/
151 |
152 | # Cython debug symbols
153 | cython_debug/
154 |
155 | # PyCharm
156 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can
157 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
158 | # and can be added to the global gitignore or merged into this file. For a more nuclear
159 | # option (not recommended) you can uncomment the following to ignore the entire idea folder.
160 | #.idea/
161 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2023 Eugene Bagdasaryan
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/assets/.DS_Store:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ebagdasa/multimodal_injection/be8f118c06d7783e4f458046947d1271e7b435d6/assets/.DS_Store
--------------------------------------------------------------------------------
/assets/audios/bird_audio.wav:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ebagdasa/multimodal_injection/be8f118c06d7783e4f458046947d1271e7b435d6/assets/audios/bird_audio.wav
--------------------------------------------------------------------------------
/assets/images/CT.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ebagdasa/multimodal_injection/be8f118c06d7783e4f458046947d1271e7b435d6/assets/images/CT.png
--------------------------------------------------------------------------------
/assets/images/bird_image.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ebagdasa/multimodal_injection/be8f118c06d7783e4f458046947d1271e7b435d6/assets/images/bird_image.jpg
--------------------------------------------------------------------------------
/assets/images/waterview.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ebagdasa/multimodal_injection/be8f118c06d7783e4f458046947d1271e7b435d6/assets/images/waterview.jpg
--------------------------------------------------------------------------------
/header.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import datetime
3 | import types
4 | import deepspeed
5 | from transformers.deepspeed import HfDeepSpeedConfig
6 | import transformers
7 | import numpy as np
8 | from collections import OrderedDict
9 | from torch.utils.data import Dataset, DataLoader
10 | from torch.nn.utils import clip_grad_norm_
11 | from torch.cuda.amp import autocast, GradScaler
12 | from torch.nn import DataParallel
13 | from torch.optim import lr_scheduler
14 | import torch.optim as optim
15 | import torch.nn as nn
16 | import torch.nn.functional as F
17 | from tqdm import tqdm
18 | import os
19 | import re
20 | import math
21 | import random
22 | import json
23 | import time
24 | import logging
25 | from copy import deepcopy
26 | import ipdb
27 | import argparse
28 | import data
29 | from transformers import LlamaTokenizer, LlamaForCausalLM, LlamaConfig
30 | from torch.nn.utils.rnn import pad_sequence
31 | from peft import LoraConfig, TaskType, get_peft_model
32 |
33 | logging.getLogger("transformers").setLevel(logging.WARNING)
34 | logging.getLogger("transformers.tokenization_utils").setLevel(logging.ERROR)
35 | os.environ['TOKENIZERS_PARALLELISM'] = 'false'
36 |
--------------------------------------------------------------------------------
/llava/.DS_Store:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ebagdasa/multimodal_injection/be8f118c06d7783e4f458046947d1271e7b435d6/llava/.DS_Store
--------------------------------------------------------------------------------
/llava/__init__.py:
--------------------------------------------------------------------------------
1 | from .model import LlavaLlamaForCausalLM
2 |
--------------------------------------------------------------------------------
/llava/constants.py:
--------------------------------------------------------------------------------
1 | CONTROLLER_HEART_BEAT_EXPIRATION = 30
2 | WORKER_HEART_BEAT_INTERVAL = 15
3 |
4 | LOGDIR = "."
5 |
--------------------------------------------------------------------------------
/llava/data/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ebagdasa/multimodal_injection/be8f118c06d7783e4f458046947d1271e7b435d6/llava/data/__init__.py
--------------------------------------------------------------------------------
/llava/data/alpaca-converter.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import json
3 | import pathlib
4 |
5 | # Prompt from stanford alpaca's training script
6 | PROMPT_DICT = {
7 | "prompt_input": (
8 | "Below is an instruction that describes a task, paired with an input that provides further context. "
9 | "Write a response that appropriately completes the request.\n\n"
10 | "### Instruction:\n{instruction}\n\n### Input:\n{input}\n\n### Response:"
11 | ),
12 | "prompt_no_input": (
13 | "Below is an instruction that describes a task. "
14 | "Write a response that appropriately completes the request.\n\n"
15 | "### Instruction:\n{instruction}\n\n### Response:"
16 | ),
17 | }
18 |
19 |
20 | def main(args):
21 | data_path = pathlib.Path(args.data_path)
22 | with data_path.open() as f:
23 | data = json.load(f)
24 |
25 | prompt_input, prompt_no_input = PROMPT_DICT["prompt_input"], PROMPT_DICT["prompt_no_input"]
26 | sources = [
27 | prompt_input.format_map(example) if example.get("input", "") != "" else prompt_no_input.format_map(example)
28 | for example in data
29 | ]
30 | targets = [example['output'] for example in data]
31 |
32 | new_data = []
33 | cnt = 1
34 | for s, t in zip(sources, targets):
35 | new_data.append({
36 | 'id': str(cnt),
37 | 'conversations': [
38 | {
39 | 'from': 'human',
40 | 'value': s,
41 | },
42 | {
43 | 'from': 'gpt',
44 | 'value': t,
45 | }
46 | ]
47 | })
48 | cnt += 1
49 |
50 | json.dump(new_data, open(args.output_path, 'w'), indent=2)
51 |
52 | if __name__ == '__main__':
53 | parser = argparse.ArgumentParser()
54 | parser.add_argument('--data_path', type=str, default='alpaca-data.json')
55 | parser.add_argument('--output_path', type=str, default='alpaca-data-conversation.json')
56 | args = parser.parse_args()
57 | main(args)
58 |
59 |
--------------------------------------------------------------------------------
/llava/data/clean_sharegpt.py:
--------------------------------------------------------------------------------
1 | """
2 | - Convert html to markdown with basic data cleaning.
3 | - Deduplication.
4 |
5 | Usage:
6 | python3 -m fastchat.data.clean_sharegpt --in sharegpt_html.json --out sharegpt_clean.json
7 | """
8 | import argparse
9 | from concurrent.futures import ProcessPoolExecutor
10 | import json
11 | import logging
12 | import re
13 | from typing import Dict, Union
14 |
15 | import bs4
16 | import markdownify # == 0.11.6
17 | from tqdm import tqdm
18 |
19 |
20 | div_pattern = re.compile("<div.*?>")
21 | span_pattern = re.compile("<span.*?>")
22 | code_lang_pattern = re.compile(
23 | "```\s*" + "(.*?)" + "(?:Copy code)+" + "(.+?)" + "\s*?```", re.DOTALL
24 | )
25 | code_lang_format = "```\g<1>\n\g<2>\n```"
26 | regenerate_pattern = re.compile("\d+ / \d+")
27 | copy_chars_pattern = re.compile("Copy\d+ chars / \d+ words")
28 | copy_code_pattern = re.compile("```(.*?)Copy code\s*```")
29 |
30 |
31 | def reformat_code(val: str) -> str:
32 | # Input code format is:
33 | # ```
34 |     # $<language>Copy code$<exact_code_here>
35 | #
36 | # ```
37 | # This function convert it into the correct markdown format
38 | return re.sub(code_lang_pattern, code_lang_format, val)
39 |
40 |
41 | def html_to_markdown(val: str) -> str:
42 |     # Remove all <div>. This is required to make intent work in code blocks.
43 | val = re.sub(div_pattern, "", val)
44 |     # Remove all <span>. This is required to make underscores work in code blocks.
45 | val = re.sub(span_pattern, "", val)
46 | # Markdown to html
47 | val = markdownify.markdownify(val).strip()
48 | # Reformat code
49 | val = reformat_code(val)
50 |
51 | # Remove noisy "[number] / [number]" at the beginning
52 | noise = re.search(regenerate_pattern, val)
53 | if noise and noise.start() == 0:
54 | val = val[noise.end() :]
55 | # Remove noisy "Copy[number] chars / [number] words"
56 | val = re.sub(copy_chars_pattern, "", val)
57 | # Remove empty code block ```\nCopy code\n```
58 | val = re.sub(copy_code_pattern, "", val)
59 |
60 | # Strip
61 | val = val.replace("\n\n\n", "\n").strip()
62 |
63 | return val
64 |
65 |
66 | def contain_blocked_words(val: str) -> bool:
67 | blocked_words = ["openai", "chatgpt"]
68 | for w in blocked_words:
69 | if w in val.lower():
70 | return True
71 | return False
72 |
73 |
74 | def clean_html_one_sample(sample):
75 | roles = ["human", "gpt"]
76 |
77 | if len(sample["conversations"]) <= 1:
78 | return (sample, 1)
79 |
80 | # Adjust the offset for cases like https://sharegpt.com/c/VyaZlh4
81 | if sample["conversations"][0]["from"] != "human":
82 | sample["conversations"] = sample["conversations"][1:]
83 | if len(sample["conversations"]) <= 1:
84 | return (sample, 1)
85 |
86 | if sample["conversations"][-1]["from"] == "human":
87 | sample["conversations"] = sample["conversations"][:-1]
88 | if len(sample["conversations"]) <= 1:
89 | return (sample, 1)
90 |
91 | for i, c in enumerate(sample["conversations"]):
92 | if c["from"] != roles[i % 2]:
93 | return (sample, 2)
94 |
95 | if contain_blocked_words(c["value"]):
96 | return (sample, 3)
97 |
98 | try:
99 | new_val = html_to_markdown(c["value"])
100 | except (bs4.builder.ParserRejectedMarkup, AssertionError):
101 | return (sample, 4)
102 |
103 | c["value"] = new_val
104 |
105 | return (sample, 0)
106 |
107 |
108 | def clean_html_all(content, begin, end):
109 | """
110 | Clean the source html files.
111 | """
112 | cnt_skip = 0
113 | cnt_blocked_words = 0
114 | cnt_wrong_format = 0
115 | cnt_parser_error = 0
116 | cnt_too_short = 0
117 | cnt_id_duplication = 0
118 | cnt_value_duplication = 0
119 | cnt_tag = 0
120 |
121 | content = content[begin:end]
122 | processed = []
123 | with ProcessPoolExecutor() as executor:
124 | for result in tqdm(
125 | executor.map(clean_html_one_sample, content), total=len(content)
126 | ):
127 | processed.append(result)
128 |
129 | visited = {}
130 | new_content = []
131 | for sample, error_code in tqdm(processed):
132 | cid = sample["id"]
133 | skipped = True
134 |
135 | if error_code != 0:
136 | if error_code == 1:
137 | print(f"id {cid} is too short")
138 | cnt_too_short += 1
139 | elif error_code == 2:
140 | print(f"id {cid} has a wrong format")
141 | cnt_wrong_format += 1
142 | elif error_code == 3:
143 | print(f"id {cid} contains blocked words")
144 | cnt_blocked_words += 1
145 | elif error_code == 4:
146 | print(f"id {cid} contains parser errors")
147 | cnt_parser_error += 1
148 | else:
149 | raise ValueError(f"Invalid error_code: {error_code}")
150 | elif cid in visited:
151 | print(f"id {cid} is an id duplication of {visited[cid]}")
152 | cnt_id_duplication += 1
153 | elif (
154 | sample["conversations"][1]["value"],
155 | len(sample["conversations"]),
156 | ) in visited:
157 | key = (sample["conversations"][1]["value"], len(sample["conversations"]))
158 | print(f"id {cid} is a value duplication of {visited[key]}")
159 | cnt_value_duplication += 1
160 | else:
161 | key = (sample["conversations"][1]["value"], len(sample["conversations"]))
162 | visited[cid] = visited[key] = cid
163 | skipped = False
164 |
165 | if not skipped:
166 | new_content.append(sample)
167 | else:
168 | cnt_skip += 1
169 |
170 | print(
171 | f"total: {len(content)}, skip: {cnt_skip}, new: {len(new_content)}, "
172 | f"cnt_blocked_words: {cnt_blocked_words}, cnt_parser_error: {cnt_parser_error}, "
173 | f"cnt_wrong_format: {cnt_wrong_format}, "
174 | f"cnt_too_short: {cnt_too_short}, cnt_id_duplication: {cnt_id_duplication}, "
175 | f"cnt_value_duplication: {cnt_value_duplication}, "
176 | )
177 |
178 | return new_content
179 |
180 |
181 | def main(args):
182 | content = json.load(open(args["in_file"], "r"))
183 | content = clean_html_all(content, args["begin"], args["end"])
184 | json.dump(content, open(args["out_file"], "w"), indent=2)
185 |
186 |
187 | if __name__ == "__main__":
188 | parser = argparse.ArgumentParser()
189 | parser.add_argument("--in-file", type=str, required=True)
190 | parser.add_argument("--out-file", type=str, default="sharegpt_clean.json")
191 | parser.add_argument("--begin", type=int)
192 | parser.add_argument("--end", type=int)
193 | parser.add_argument("--debug", action="store_true")
194 | args = parser.parse_args()
195 | main(vars(args))
196 |
--------------------------------------------------------------------------------
/llava/data/inspect.py:
--------------------------------------------------------------------------------
1 | """
2 | Usage:
3 | python3 -m fastchat.data.inspect --in sharegpt_20230322_clean_lang_split.json
4 | """
5 | import argparse
6 | import json
7 |
8 | import tqdm
9 |
10 |
11 | if __name__ == "__main__":
12 | parser = argparse.ArgumentParser()
13 | parser.add_argument("--in-file", type=str, required=True)
14 | parser.add_argument("--begin", type=int)
15 | args = parser.parse_args()
16 |
17 | content = json.load(open(args.in_file, "r"))
18 | for sample in tqdm.tqdm(content[args.begin:]):
19 | print(f"id: {sample['id']}")
20 | for conv in sample["conversations"]:
21 | print(conv["from"] + ": ")
22 | print(conv["value"])
23 | input()
24 |
--------------------------------------------------------------------------------
/llava/data/optional_clean.py:
--------------------------------------------------------------------------------
1 | """
2 | Usage:
3 | python3 -m fastchat.data.optional_clean --lang en --reduce-rep --in sharegpt_clean.json --out output.json
4 | python3 -m fastchat.data.optional_clean --skip-lang en --reduce-rep --in sharegpt_clean.json --out output.json
5 | """
6 | import argparse
7 | import json
8 | import re
9 |
10 | import polyglot
11 | from polyglot.detect import Detector
12 | import pycld2
13 | from tqdm import tqdm
14 |
15 |
16 | def skip(conv, args):
17 | # Remove certain languages
18 | if args.lang != "all" or args.skip_lang is not None:
19 | text = "\n".join([x["value"] for x in conv["conversations"]])
20 | try:
21 | lang_code = Detector(text).language.code
22 | except (pycld2.error, polyglot.detect.base.UnknownLanguage):
23 | lang_code = "unknown"
24 |
25 | if args.lang != "all" and lang_code != args.lang:
26 | return True
27 |
28 | if lang_code == args.skip_lang:
29 | return True
30 |
31 | # Remove repetitive numbers
32 | if args.reduce_rep:
33 | for sentence in conv["conversations"]:
34 | val = sentence["value"]
35 | sub = re.search(r"(\d)\1{8}", val)
36 | if sub is not None:
37 | return True
38 |
39 | return False
40 |
41 |
42 | if __name__ == "__main__":
43 | parser = argparse.ArgumentParser()
44 | parser.add_argument("--in-file", type=str, required=True)
45 | parser.add_argument("--out-file", type=str, default="")
46 | parser.add_argument("--lang", type=str, default="all",
47 | choices=["all", "en"])
48 | parser.add_argument("--skip-lang", type=str)
49 | # NOTE: Be careful about reduce_rep which may remove some good data.
50 | # For example, addresses could have long consecutive 0's
51 | parser.add_argument("--reduce-rep", action="store_true")
52 | args = parser.parse_args()
53 |
54 | in_file = args.in_file
55 | out_file = args.out_file
56 | lang = args.lang
57 | skip_lang = args.skip_lang
58 | reduce_rep = args.reduce_rep
59 | assert (lang == "all" or skip_lang is None)
60 |
61 | if out_file == "":
62 | out_file = "sharegpt_clean"
63 | if lang != "all":
64 | out_file += "_" + lang
65 | if skip_lang is not None:
66 | out_file += "_skip_" + skip_lang
67 | if reduce_rep:
68 | out_file += "_reduce_rep"
69 | out_file += ".json"
70 |
71 | content = json.load(open(in_file, "r"))
72 | num_conv = len(content)
73 |
74 | new_content = []
75 | for conv in tqdm(content):
76 | if not skip(conv, args):
77 | new_content.append(conv)
78 |
79 | print(f"return {len(new_content)} out of {len(content)}, start dump ...")
80 | json.dump(new_content, open(out_file, "w"), indent=2)
81 |
--------------------------------------------------------------------------------
/llava/data/pretty_json.py:
--------------------------------------------------------------------------------
1 | """
2 | Usage:
3 | python3 pretty_json.py --in in.json --out out.json
4 | """
5 |
6 | import argparse
7 | import json
8 |
9 |
10 | if __name__ == "__main__":
11 | parser = argparse.ArgumentParser()
12 | parser.add_argument("--in-file", type=str, required=True)
13 | parser.add_argument("--out-file", type=str, required=True)
14 | args = parser.parse_args()
15 |
16 | with open(args.in_file, "r") as fin:
17 | data = json.load(fin)
18 |
19 | with open(args.out_file, "w") as fout:
20 | json.dump(data, fout, indent=2)
21 |
--------------------------------------------------------------------------------
/llava/data/split_long_conversation.py:
--------------------------------------------------------------------------------
1 | """
2 | Split long conversations based on certain max length.
3 |
4 | Usage: python3 -m fastchat.data.split_long_conversation \
5 | --in sharegpt_clean.json \
6 | --out sharegpt_split.json \
7 |     --model-name-or-path $<model-name>
8 | """
9 | import argparse
10 | import json
11 | from typing import Dict, Sequence, Optional
12 |
13 | import transformers
14 | import tqdm
15 |
16 | from llava import conversation as conversation_lib
17 |
18 | DEFAULT_PAD_TOKEN = "[PAD]"
19 | BEGIN_SIGNAL = "### "
20 | END_SIGNAL = "\n"
21 |
22 |
23 | def split_sample(sample, start_idx, end_idx):
24 | # only ends in the bot because otherwise the last human part is useless.
25 | end_speaker = sample["conversations"][end_idx]["from"]
26 | end_idx = end_idx + 1 if end_speaker != "human" else end_idx
27 | return {
28 | "id": sample["id"] + "_" + str(start_idx),
29 | "conversations": sample["conversations"][start_idx:end_idx]
30 | }
31 |
32 |
33 | def split_contents(content, begin, end, tokenizer, max_length):
34 | """
35 | Keep the maximum round of conversations within the max token length constraint
36 | """
37 | content = content[begin:end]
38 | new_content = []
39 |
40 | for sample in tqdm.tqdm(content):
41 | tokenized_lens = []
42 |
43 | for c in sample["conversations"]:
44 | from_str = c["from"]
45 | if from_str.lower() == "human":
46 | from_str = conversation_lib.default_conversation.roles[0]
47 | elif from_str.lower() == "gpt":
48 | from_str = conversation_lib.default_conversation.roles[1]
49 | else:
50 | from_str = 'unknown'
51 |
52 | sentence = (BEGIN_SIGNAL + from_str + ": " + c["value"] +
53 | END_SIGNAL)
54 | length = tokenizer(sentence, return_tensors="pt", padding="longest"
55 | ).input_ids.ne(tokenizer.pad_token_id).sum().item()
56 | tokenized_lens.append(length)
57 |
58 | num_tokens = 0
59 | start_idx = 0
60 | for idx, l in enumerate(tokenized_lens):
61 | # TODO: shall we also only starts from a specific speaker?
62 | if num_tokens + l > max_length:
63 | new_content.append(split_sample(sample, start_idx, idx))
64 | start_idx = idx
65 | num_tokens = l
66 | else:
67 | num_tokens += l
68 | if idx == len(tokenized_lens) - 1:
69 | new_content.append(split_sample(sample, start_idx, idx))
70 |
71 | print(f"total: {len(content)}, new: {len(new_content)}")
72 | return new_content
73 |
74 |
75 | def main(args):
76 | content = json.load(open(args.in_file, "r"))
77 | tokenizer = transformers.AutoTokenizer.from_pretrained(
78 | args.model_name_or_path,
79 | model_max_length=args.max_length,
80 | padding_side="right",
81 | use_fast=False,
82 | )
83 | if tokenizer.pad_token is None:
84 | tokenizer.add_special_tokens(dict(pad_token=DEFAULT_PAD_TOKEN))
85 | content = split_contents(content, args.begin, args.end,
86 | tokenizer, args.max_length)
87 | json.dump(content, open(args.out_file, "w"), indent=2)
88 |
89 |
90 | if __name__ == "__main__":
91 | parser = argparse.ArgumentParser()
92 | parser.add_argument("--in-file", type=str, required=True)
93 | parser.add_argument("--out-file", type=str, default="sharegpt_split.json")
94 | parser.add_argument("--begin", type=int)
95 | parser.add_argument("--end", type=int)
96 | parser.add_argument("--model-name-or-path", type=str, required=True)
97 | parser.add_argument("--max-length", type=int, default=2304)
98 | args = parser.parse_args()
99 | main(args)
100 |
--------------------------------------------------------------------------------
/llava/eval/.DS_Store:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ebagdasa/multimodal_injection/be8f118c06d7783e4f458046947d1271e7b435d6/llava/eval/.DS_Store
--------------------------------------------------------------------------------
/llava/eval/eval_gpt_review.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import json
3 | import os
4 |
5 | import openai
6 | import tqdm
7 | import ray
8 | import time
9 |
10 | @ray.remote(num_cpus=4)
11 | def get_eval(content: str, max_tokens: int):
12 | while True:
13 | try:
14 | response = openai.ChatCompletion.create(
15 | model='gpt-4',
16 | messages=[{
17 | 'role': 'system',
18 | 'content': 'You are a helpful and precise assistant for checking the quality of the answer.'
19 | }, {
20 | 'role': 'user',
21 | 'content': content,
22 | }],
23 | temperature=0.2, # TODO: figure out which temperature is best for evaluation
24 | max_tokens=max_tokens,
25 | )
26 | break
27 | except openai.error.RateLimitError:
28 | pass
29 | except Exception as e:
30 | print(e)
31 | time.sleep(1)
32 |
33 | print('success!')
34 | return response['choices'][0]['message']['content']
35 |
36 |
37 | def parse_score(review):
38 | try:
39 | score_pair = review.split('\n')[0]
40 | score_pair = score_pair.replace(',', ' ')
41 | sp = score_pair.split(' ')
42 | if len(sp) == 2:
43 | return [float(sp[0]), float(sp[1])]
44 | else:
45 | print('error', review)
46 | return [-1, -1]
47 | except Exception as e:
48 | print(e)
49 | print('error', review)
50 | return [-1, -1]
51 |
52 |
53 | if __name__ == '__main__':
54 | parser = argparse.ArgumentParser(description='ChatGPT-based QA evaluation.')
55 | parser.add_argument('-q', '--question')
56 | # parser.add_argument('-a', '--answer')
57 | parser.add_argument('-a', '--answer-list', nargs='+', default=[])
58 | parser.add_argument('-r', '--rule')
59 | parser.add_argument('-o', '--output')
60 | parser.add_argument('--max-tokens', type=int, default=1024, help='maximum number of tokens produced in the output')
61 | args = parser.parse_args()
62 |
63 | ray.init()
64 |
65 | f_q = open(os.path.expanduser(args.question))
66 | f_ans1 = open(os.path.expanduser(args.answer_list[0]))
67 | f_ans2 = open(os.path.expanduser(args.answer_list[1]))
68 | rule_dict = json.load(open(os.path.expanduser(args.rule), 'r'))
69 |
70 | review_file = open(f'{args.output}', 'w')
71 |
72 | js_list = []
73 | handles = []
74 | idx = 0
75 | for ques_js, ans1_js, ans2_js in zip(f_q, f_ans1, f_ans2):
76 | # if idx == 1:
77 | # break
78 |
79 | ques = json.loads(ques_js)
80 | ans1 = json.loads(ans1_js)
81 | ans2 = json.loads(ans2_js)
82 |
83 | category = json.loads(ques_js)['category']
84 | if category in rule_dict:
85 | rule = rule_dict[category]
86 | else:
87 | rule = rule_dict['default']
88 | prompt = rule['prompt']
89 | role = rule['role']
90 | content = (f'[Question]\n{ques["text"]}\n\n'
91 | f'[{role} 1]\n{ans1["text"]}\n\n[End of {role} 1]\n\n'
92 | f'[{role} 2]\n{ans2["text"]}\n\n[End of {role} 2]\n\n'
93 | f'[System]\n{prompt}\n\n')
94 | js_list.append({
95 | 'id': idx+1,
96 | 'question_id': ques['question_id'],
97 | 'answer1_id': ans1['answer_id'],
98 | 'answer2_id': ans2['answer_id'],
99 | 'category': category})
100 | idx += 1
101 | handles.append(get_eval.remote(content, args.max_tokens))
102 | # To avoid the rate limit set by OpenAI
103 | time.sleep(1)
104 |
105 | reviews = ray.get(handles)
106 | for idx, review in enumerate(reviews):
107 | scores = parse_score(review)
108 | js_list[idx]['content'] = review
109 | js_list[idx]['tuple'] = scores
110 | review_file.write(json.dumps(js_list[idx]) + '\n')
111 | review_file.close()
112 |
--------------------------------------------------------------------------------
/llava/eval/eval_gpt_review_visual.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import json
3 | import os
4 |
5 | import openai
6 | import tqdm
7 | import ray
8 | import time
9 |
10 | @ray.remote(num_cpus=4)
11 | def get_eval(content: str, max_tokens: int):
12 | while True:
13 | try:
14 | response = openai.ChatCompletion.create(
15 | model='gpt-4',
16 | messages=[{
17 | 'role': 'system',
18 | 'content': 'You are a helpful and precise assistant for checking the quality of the answer.'
19 | }, {
20 | 'role': 'user',
21 | 'content': content,
22 | }],
23 | temperature=0.2, # TODO: figure out which temperature is best for evaluation
24 | max_tokens=max_tokens,
25 | )
26 | break
27 | except openai.error.RateLimitError:
28 | pass
29 | except Exception as e:
30 | print(e)
31 | time.sleep(1)
32 |
33 | print('success!')
34 | return response['choices'][0]['message']['content']
35 |
36 |
37 | def parse_score(review):
38 | try:
39 | score_pair = review.split('\n')[0]
40 | score_pair = score_pair.replace(',', ' ')
41 | sp = score_pair.split(' ')
42 | if len(sp) == 2:
43 | return [float(sp[0]), float(sp[1])]
44 | else:
45 | print('error', review)
46 | return [-1, -1]
47 | except Exception as e:
48 | print(e)
49 | print('error', review)
50 | return [-1, -1]
51 |
52 |
53 | if __name__ == '__main__':
54 | parser = argparse.ArgumentParser(description='ChatGPT-based QA evaluation.')
55 | parser.add_argument('-q', '--question')
56 | parser.add_argument('-c', '--context')
57 | parser.add_argument('-a', '--answer-list', nargs='+', default=[])
58 | parser.add_argument('-r', '--rule')
59 | parser.add_argument('-o', '--output')
60 | parser.add_argument('--max-tokens', type=int, default=1024, help='maximum number of tokens produced in the output')
61 | args = parser.parse_args()
62 |
63 | ray.init()
64 |
65 | f_q = open(os.path.expanduser(args.question))
66 | f_ans1 = open(os.path.expanduser(args.answer_list[0]))
67 | f_ans2 = open(os.path.expanduser(args.answer_list[1]))
68 | rule_dict = json.load(open(os.path.expanduser(args.rule), 'r'))
69 |
70 | review_file = open(f'{args.output}', 'w')
71 |
72 | context_list = [json.loads(line) for line in open(os.path.expanduser(args.context))]
73 | image_to_context = {context['image']: context for context in context_list}
74 |
75 | js_list = []
76 | handles = []
77 | idx = 0
78 | for ques_js, ans1_js, ans2_js in zip(f_q, f_ans1, f_ans2):
79 | ques = json.loads(ques_js)
80 | ans1 = json.loads(ans1_js)
81 | ans2 = json.loads(ans2_js)
82 |
83 | inst = image_to_context[ques['image']]
84 | cap_str = '\n'.join(inst['captions'])
85 | box_str = '\n'.join([f'{instance["category"]}: {instance["bbox"]}' for instance in inst['instances']])
86 |
87 | category = json.loads(ques_js)['category']
88 | if category in rule_dict:
89 | rule = rule_dict[category]
90 | else:
91 | assert False, f"Visual QA category not found in rule file: {category}."
92 | prompt = rule['prompt']
93 | role = rule['role']
94 | content = (f'[Context]\n{cap_str}\n\n{box_str}\n\n'
95 | f'[Question]\n{ques["text"]}\n\n'
96 | f'[{role} 1]\n{ans1["text"]}\n\n[End of {role} 1]\n\n'
97 | f'[{role} 2]\n{ans2["text"]}\n\n[End of {role} 2]\n\n'
98 | f'[System]\n{prompt}\n\n')
99 | js_list.append({
100 | 'id': idx+1,
101 | 'question_id': ques['question_id'],
102 | 'answer1_id': ans1.get('answer_id', ans1['question_id']),
103 | 'answer2_id': ans2.get('answer_id', ans2['answer_id']),
104 | 'category': category})
105 | idx += 1
106 | handles.append(get_eval.remote(content, args.max_tokens))
107 | # To avoid the rate limit set by OpenAI
108 | time.sleep(1)
109 |
110 | reviews = ray.get(handles)
111 | for idx, review in enumerate(reviews):
112 | scores = parse_score(review)
113 | js_list[idx]['content'] = review
114 | js_list[idx]['tuple'] = scores
115 | review_file.write(json.dumps(js_list[idx]) + '\n')
116 | review_file.close()
117 |
--------------------------------------------------------------------------------
/llava/eval/eval_science_qa.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import json
3 | import os
4 | import re
5 | import random
6 |
7 |
8 | def get_args():
9 | parser = argparse.ArgumentParser()
10 | parser.add_argument('--base-dir', type=str)
11 | parser.add_argument('--result-file', type=str)
12 | parser.add_argument('--output-file', type=str)
13 | parser.add_argument('--output-result', type=str)
14 | parser.add_argument('--split', type=str, default='test')
15 | parser.add_argument('--options', type=list, default=["A", "B", "C", "D", "E"])
16 | return parser.parse_args()
17 |
18 |
19 | def convert_caps(results):
20 | fakecaps = []
21 | for result in results:
22 | image_id = result['question_id']
23 | caption = result['text']
24 | fakecaps.append({"image_id": int(image_id), "caption": caption})
25 | return fakecaps
26 |
27 |
28 | def get_pred_idx(prediction, choices, options):
29 | """
30 | Get the index (e.g. 2) from the prediction (e.g. 'C')
31 | """
32 | if prediction in options[:len(choices)]:
33 | return options.index(prediction)
34 | else:
35 | return random.choice(range(len(choices)))
36 |
37 |
38 | if __name__ == "__main__":
39 | args = get_args()
40 |
41 | base_dir = args.base_dir
42 | split_indices = json.load(open(os.path.join(base_dir, "pid_splits.json")))[args.split]
43 | problems = json.load(open(os.path.join(base_dir, "problems.json")))
44 | predictions = [json.loads(line) for line in open(args.result_file)]
45 | predictions = {pred['question_id']: pred for pred in predictions}
46 | split_problems = {idx: problems[idx] for idx in split_indices}
47 |
48 | results = {'correct': [], 'incorrect': []}
49 | sqa_results = {}
50 | sqa_results['acc'] = None
51 | sqa_results['correct'] = None
52 | sqa_results['count'] = None
53 | sqa_results['results'] = {}
54 | sqa_results['outputs'] = {}
55 |
56 | for prob_id, prob in split_problems.items():
57 | if prob_id not in predictions:
58 | continue
59 | pred = predictions[prob_id]
60 | pred_text = pred['text']
61 |
62 | pattern = re.compile(r'The answer is ([A-Z]).')
63 | res = pattern.findall(pred_text)
64 | if len(res) == 1:
65 | answer = res[0] # 'A', 'B', ...
66 | else:
67 | answer = "FAILED"
68 |
69 | pred_idx = get_pred_idx(answer, prob['choices'], args.options)
70 |
71 | analysis = {
72 | 'question_id': prob_id,
73 | 'parsed_ans': answer,
74 | 'ground_truth': args.options[prob['answer']],
75 | 'question': pred['prompt'],
76 | 'pred': pred_text,
77 |         'is_multimodal': '<image>' in pred['prompt'],
78 | }
79 |
80 | sqa_results['results'][prob_id] = get_pred_idx(answer, prob['choices'], args.options)
81 | sqa_results['outputs'][prob_id] = pred_text
82 |
83 | if pred_idx == prob['answer']:
84 | results['correct'].append(analysis)
85 | else:
86 | results['incorrect'].append(analysis)
87 |
88 | correct = len(results['correct'])
89 | total = len(results['correct']) + len(results['incorrect'])
90 | print(f'Total: {total}, Correct: {correct}, Accuracy: {correct / total * 100:.2f}%')
91 |
92 | sqa_results['acc'] = correct / total * 100
93 | sqa_results['correct'] = correct
94 | sqa_results['count'] = total
95 |
96 | with open(args.output_file, 'w') as f:
97 | json.dump(results, f, indent=2)
98 | with open(args.output_result, 'w') as f:
99 | json.dump(sqa_results, f, indent=2)
100 |
--------------------------------------------------------------------------------
/llava/eval/eval_science_qa_gpt4.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import json
3 | import os
4 | import re
5 | import random
6 | from collections import defaultdict
7 |
8 |
9 | def get_args():
10 | parser = argparse.ArgumentParser()
11 | parser.add_argument('--base-dir', type=str)
12 | parser.add_argument('--gpt4-result', type=str)
13 | parser.add_argument('--our-result', type=str)
14 | parser.add_argument('--split', type=str, default='test')
15 | parser.add_argument('--options', type=list, default=["A", "B", "C", "D", "E"])
16 | return parser.parse_args()
17 |
18 |
19 | def convert_caps(results):
20 | fakecaps = []
21 | for result in results:
22 | image_id = result['question_id']
23 | caption = result['text']
24 | fakecaps.append({"image_id": int(image_id), "caption": caption})
25 | return fakecaps
26 |
27 |
28 | def get_pred_idx(prediction, choices, options):
29 | """
30 | Get the index (e.g. 2) from the prediction (e.g. 'C')
31 | """
32 | if prediction in options[:len(choices)]:
33 | return options.index(prediction)
34 | else:
35 | return random.choice(range(len(choices)))
36 |
37 |
38 | if __name__ == "__main__":
39 | args = get_args()
40 |
41 | base_dir = args.base_dir
42 | split_indices = json.load(open(os.path.join(base_dir, "pid_splits.json")))[args.split]
43 | problems = json.load(open(os.path.join(base_dir, "problems.json")))
44 | our_predictions = [json.loads(line) for line in open(args.our_result)]
45 | our_predictions = {pred['question_id']: pred for pred in our_predictions}
46 | split_problems = {idx: problems[idx] for idx in split_indices}
47 |
48 | gpt4_predictions = json.load(open(args.gpt4_result))['outputs']
49 |
50 | results = defaultdict(lambda: 0)
51 |
52 | for prob_id, prob in split_problems.items():
53 | if prob_id not in our_predictions:
54 | continue
55 | if prob_id not in gpt4_predictions:
56 | continue
57 | our_pred = our_predictions[prob_id]['text']
58 | gpt4_pred = gpt4_predictions[prob_id]
59 |
60 | pattern = re.compile(r'The answer is ([A-Z]).')
61 | our_res = pattern.findall(our_pred)
62 | if len(our_res) == 1:
63 | our_answer = our_res[0] # 'A', 'B', ...
64 | else:
65 | our_answer = "FAILED"
66 | gpt4_res = pattern.findall(gpt4_pred)
67 | if len(gpt4_res) == 1:
68 | gpt4_answer = gpt4_res[0] # 'A', 'B', ...
69 | else:
70 | gpt4_answer = "FAILED"
71 |
72 | our_pred_idx = get_pred_idx(our_answer, prob['choices'], args.options)
73 | gpt4_pred_idx = get_pred_idx(gpt4_answer, prob['choices'], args.options)
74 |
75 | if gpt4_answer == 'FAILED':
76 | results['gpt4_failed'] += 1
77 | # continue
78 | gpt4_pred_idx = our_pred_idx
79 | # if our_pred_idx != prob['answer']:
80 | # print(our_predictions[prob_id]['prompt'])
81 | # print('-----------------')
82 | # print(f'LECTURE: {prob["lecture"]}')
83 | # print(f'SOLUTION: {prob["solution"]}')
84 | # print('=====================')
85 | else:
86 | # continue
87 | pass
88 | # gpt4_pred_idx = our_pred_idx
89 |
90 | if gpt4_pred_idx == prob['answer']:
91 | results['correct'] += 1
92 | else:
93 | results['incorrect'] += 1
94 |
95 |
96 | if gpt4_pred_idx == prob['answer'] or our_pred_idx == prob['answer']:
97 | results['correct_upperbound'] += 1
98 |
99 | correct = results['correct']
100 | total = results['correct'] + results['incorrect']
101 | print(f'Total: {total}, Correct: {correct}, Accuracy: {correct / total * 100:.2f}%')
102 | print(f'Total: {total}, Correct (upper): {results["correct_upperbound"]}, Accuracy: {results["correct_upperbound"] / total * 100:.2f}%')
103 | print(f'Total: {total}, GPT-4 NO-ANS (RANDOM): {results["gpt4_failed"]}, Percentage: {results["gpt4_failed"] / total * 100:.2f}%')
104 |
105 |
--------------------------------------------------------------------------------
/llava/eval/eval_science_qa_gpt4_requery.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import json
3 | import os
4 | import re
5 | import random
6 | from collections import defaultdict
7 |
8 |
9 | def get_args():
10 | parser = argparse.ArgumentParser()
11 | parser.add_argument('--base-dir', type=str)
12 | parser.add_argument('--gpt4-result', type=str)
13 | parser.add_argument('--requery-result', type=str)
14 | parser.add_argument('--our-result', type=str)
15 | parser.add_argument('--output-result', type=str)
16 | parser.add_argument('--split', type=str, default='test')
17 | parser.add_argument('--options', type=list, default=["A", "B", "C", "D", "E"])
18 | return parser.parse_args()
19 |
20 |
21 | def convert_caps(results):
22 | fakecaps = []
23 | for result in results:
24 | image_id = result['question_id']
25 | caption = result['text']
26 | fakecaps.append({"image_id": int(image_id), "caption": caption})
27 | return fakecaps
28 |
29 |
30 | def get_pred_idx(prediction, choices, options):
31 | """
32 | Get the index (e.g. 2) from the prediction (e.g. 'C')
33 | """
34 | if prediction in options[:len(choices)]:
35 | return options.index(prediction)
36 | else:
37 | return random.choice(range(len(choices)))
38 |
39 |
40 | if __name__ == "__main__":
41 | args = get_args()
42 |
43 | base_dir = args.base_dir
44 | split_indices = json.load(open(os.path.join(base_dir, "pid_splits.json")))[args.split]
45 | problems = json.load(open(os.path.join(base_dir, "problems.json")))
46 | our_predictions = [json.loads(line) for line in open(args.our_result)]
47 | our_predictions = {pred['question_id']: pred for pred in our_predictions}
48 | split_problems = {idx: problems[idx] for idx in split_indices}
49 |
50 | requery_predictions = [json.loads(line) for line in open(args.requery_result)]
51 | requery_predictions = {pred['question_id']: pred for pred in requery_predictions}
52 |
53 | gpt4_predictions = json.load(open(args.gpt4_result))['outputs']
54 |
55 | results = defaultdict(lambda: 0)
56 |
57 | sqa_results = {}
58 | sqa_results['acc'] = None
59 | sqa_results['correct'] = None
60 | sqa_results['count'] = None
61 | sqa_results['results'] = {}
62 | sqa_results['outputs'] = {}
63 |
64 | for prob_id, prob in split_problems.items():
65 | if prob_id not in our_predictions:
66 | assert False
67 | if prob_id not in gpt4_predictions:
68 | assert False
69 | our_pred = our_predictions[prob_id]['text']
70 | gpt4_pred = gpt4_predictions[prob_id]
71 | if prob_id not in requery_predictions:
72 | results['missing_requery'] += 1
73 | requery_pred = "MISSING"
74 | else:
75 | requery_pred = requery_predictions[prob_id]['text']
76 |
77 | pattern = re.compile(r'The answer is ([A-Z]).')
78 | our_res = pattern.findall(our_pred)
79 | if len(our_res) == 1:
80 | our_answer = our_res[0] # 'A', 'B', ...
81 | else:
82 | our_answer = "FAILED"
83 |
84 | requery_res = pattern.findall(requery_pred)
85 | if len(requery_res) == 1:
86 | requery_answer = requery_res[0] # 'A', 'B', ...
87 | else:
88 | requery_answer = "FAILED"
89 |
90 | gpt4_res = pattern.findall(gpt4_pred)
91 | if len(gpt4_res) == 1:
92 | gpt4_answer = gpt4_res[0] # 'A', 'B', ...
93 | else:
94 | gpt4_answer = "FAILED"
95 |
96 | our_pred_idx = get_pred_idx(our_answer, prob['choices'], args.options)
97 | gpt4_pred_idx = get_pred_idx(gpt4_answer, prob['choices'], args.options)
98 | requery_pred_idx = get_pred_idx(requery_answer, prob['choices'], args.options)
99 |
100 | results['total'] += 1
101 |
102 | if gpt4_answer == 'FAILED':
103 | results['gpt4_failed'] += 1
104 | if gpt4_pred_idx == prob['answer']:
105 | results['gpt4_correct'] += 1
106 | if our_pred_idx == prob['answer']:
107 | results['gpt4_ourvisual_correct'] += 1
108 | elif gpt4_pred_idx == prob['answer']:
109 | results['gpt4_correct'] += 1
110 | results['gpt4_ourvisual_correct'] += 1
111 |
112 | if our_pred_idx == prob['answer']:
113 | results['our_correct'] += 1
114 |
115 | if requery_answer == 'FAILED':
116 | sqa_results['results'][prob_id] = our_pred_idx
117 | if our_pred_idx == prob['answer']:
118 | results['requery_correct'] += 1
119 | else:
120 | sqa_results['results'][prob_id] = requery_pred_idx
121 | if requery_pred_idx == prob['answer']:
122 | results['requery_correct'] += 1
123 | else:
124 | print(f"""
125 | Question ({args.options[prob['answer']]}): {our_predictions[prob_id]['prompt']}
126 | Our ({our_answer}): {our_pred}
127 | GPT-4 ({gpt4_answer}): {gpt4_pred}
128 | Requery ({requery_answer}): {requery_pred}
129 | print("=====================================")
130 | """)
131 |
132 | if gpt4_pred_idx == prob['answer'] or our_pred_idx == prob['answer']:
133 | results['correct_upperbound'] += 1
134 |
135 | total = results['total']
136 | print(f'Total: {total}, Our-Correct: {results["our_correct"]}, Accuracy: {results["our_correct"] / total * 100:.2f}%')
137 | print(f'Total: {total}, GPT-4-Correct: {results["gpt4_correct"]}, Accuracy: {results["gpt4_correct"] / total * 100:.2f}%')
138 | print(f'Total: {total}, GPT-4 NO-ANS (RANDOM): {results["gpt4_failed"]}, Percentage: {results["gpt4_failed"] / total * 100:.2f}%')
139 | print(f'Total: {total}, GPT-4-OursVisual-Correct: {results["gpt4_ourvisual_correct"]}, Accuracy: {results["gpt4_ourvisual_correct"] / total * 100:.2f}%')
140 | print(f'Total: {total}, Requery-Correct: {results["requery_correct"]}, Accuracy: {results["requery_correct"] / total * 100:.2f}%')
141 | print(f'Total: {total}, Correct upper: {results["correct_upperbound"]}, Accuracy: {results["correct_upperbound"] / total * 100:.2f}%')
142 |
143 | sqa_results['acc'] = results["requery_correct"] / total * 100
144 | sqa_results['correct'] = results["requery_correct"]
145 | sqa_results['count'] = total
146 |
147 | with open(args.output_result, 'w') as f:
148 | json.dump(sqa_results, f, indent=2)
149 |
150 |
--------------------------------------------------------------------------------
/llava/eval/generate_webpage_data_from_table.py:
--------------------------------------------------------------------------------
1 | """Generate json file for webpage."""
2 | import json
3 | import os
4 | import re
5 |
6 | # models = ['llama', 'alpaca', 'gpt35', 'bard']
7 | models = ['vicuna']
8 |
9 |
10 | def read_jsonl(path: str, key: str=None):
11 | data = []
12 | with open(os.path.expanduser(path)) as f:
13 | for line in f:
14 | if not line:
15 | continue
16 | data.append(json.loads(line))
17 | if key is not None:
18 | data.sort(key=lambda x: x[key])
19 | data = {item[key]: item for item in data}
20 | return data
21 |
22 |
23 | def trim_hanging_lines(s: str, n: int) -> str:
24 | s = s.strip()
25 | for _ in range(n):
26 | s = s.split('\n', 1)[1].strip()
27 | return s
28 |
29 |
30 | if __name__ == '__main__':
31 | questions = read_jsonl('table/question.jsonl', key='question_id')
32 |
33 | # alpaca_answers = read_jsonl('table/answer/answer_alpaca-13b.jsonl', key='question_id')
34 | # bard_answers = read_jsonl('table/answer/answer_bard.jsonl', key='question_id')
35 | # gpt35_answers = read_jsonl('table/answer/answer_gpt35.jsonl', key='question_id')
36 | # llama_answers = read_jsonl('table/answer/answer_llama-13b.jsonl', key='question_id')
37 | vicuna_answers = read_jsonl('table/answer/answer_vicuna-13b.jsonl', key='question_id')
38 | ours_answers = read_jsonl('table/results/llama-13b-hf-alpaca.jsonl', key='question_id')
39 |
40 | review_vicuna = read_jsonl('table/review/review_vicuna-13b_llama-13b-hf-alpaca.jsonl', key='question_id')
41 | # review_alpaca = read_jsonl('table/review/review_alpaca-13b_vicuna-13b.jsonl', key='question_id')
42 | # review_bard = read_jsonl('table/review/review_bard_vicuna-13b.jsonl', key='question_id')
43 | # review_gpt35 = read_jsonl('table/review/review_gpt35_vicuna-13b.jsonl', key='question_id')
44 | # review_llama = read_jsonl('table/review/review_llama-13b_vicuna-13b.jsonl', key='question_id')
45 |
46 | records = []
47 | for qid in questions.keys():
48 | r = {
49 | 'id': qid,
50 | 'category': questions[qid]['category'],
51 | 'question': questions[qid]['text'],
52 | 'answers': {
53 | # 'alpaca': alpaca_answers[qid]['text'],
54 | # 'llama': llama_answers[qid]['text'],
55 | # 'bard': bard_answers[qid]['text'],
56 | # 'gpt35': gpt35_answers[qid]['text'],
57 | 'vicuna': vicuna_answers[qid]['text'],
58 | 'ours': ours_answers[qid]['text'],
59 | },
60 | 'evaluations': {
61 | # 'alpaca': review_alpaca[qid]['text'],
62 | # 'llama': review_llama[qid]['text'],
63 | # 'bard': review_bard[qid]['text'],
64 | 'vicuna': review_vicuna[qid]['content'],
65 | # 'gpt35': review_gpt35[qid]['text'],
66 | },
67 | 'scores': {
68 | 'vicuna': review_vicuna[qid]['tuple'],
69 | # 'alpaca': review_alpaca[qid]['score'],
70 | # 'llama': review_llama[qid]['score'],
71 | # 'bard': review_bard[qid]['score'],
72 | # 'gpt35': review_gpt35[qid]['score'],
73 | },
74 | }
75 |
76 | # cleanup data
77 | cleaned_evals = {}
78 | for k, v in r['evaluations'].items():
79 | v = v.strip()
80 | lines = v.split('\n')
81 | # trim the first line if it's a pair of numbers
82 | if re.match(r'\d+[, ]+\d+', lines[0]):
83 | lines = lines[1:]
84 | v = '\n'.join(lines)
85 | cleaned_evals[k] = v.replace('Assistant 1', "**Assistant 1**").replace('Assistant 2', '**Assistant 2**')
86 |
87 | r['evaluations'] = cleaned_evals
88 | records.append(r)
89 |
90 | # Reorder the records, this is optional
91 | for r in records:
92 | if r['id'] <= 20:
93 | r['id'] += 60
94 | else:
95 | r['id'] -= 20
96 | for r in records:
97 | if r['id'] <= 50:
98 | r['id'] += 10
99 | elif 50 < r['id'] <= 60:
100 | r['id'] -= 50
101 | for r in records:
102 | if r['id'] == 7:
103 | r['id'] = 1
104 | elif r['id'] < 7:
105 | r['id'] += 1
106 |
107 | records.sort(key=lambda x: x['id'])
108 |
109 | # Write to file
110 | with open('webpage/data.json', 'w') as f:
111 | json.dump({'questions': records, 'models': models}, f, indent=2)
112 |
--------------------------------------------------------------------------------
/llava/eval/model_qa.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | from transformers import AutoTokenizer, AutoModelForCausalLM, StoppingCriteria
3 | import torch
4 | import os
5 | import json
6 | from tqdm import tqdm
7 | import shortuuid
8 |
9 | from llava.conversation import default_conversation
10 | from llava.utils import disable_torch_init
11 |
12 |
13 | # new stopping implementation
14 | class KeywordsStoppingCriteria(StoppingCriteria):
15 | def __init__(self, keywords, tokenizer, input_ids):
16 | self.keywords = keywords
17 | self.tokenizer = tokenizer
18 | self.start_len = None
19 | self.input_ids = input_ids
20 |
21 | def __call__(self, output_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
22 | if self.start_len is None:
23 | self.start_len = self.input_ids.shape[1]
24 | else:
25 | outputs = self.tokenizer.batch_decode(output_ids[:, self.start_len:], skip_special_tokens=True)[0]
26 | for keyword in self.keywords:
27 | if keyword in outputs:
28 | return True
29 | return False
30 |
31 |
32 | @torch.inference_mode()
33 | def eval_model(model_name, questions_file, answers_file):
34 | # Model
35 | disable_torch_init()
36 | model_name = os.path.expanduser(model_name)
37 | tokenizer = AutoTokenizer.from_pretrained(model_name)
38 | model = AutoModelForCausalLM.from_pretrained(model_name,
39 | torch_dtype=torch.float16).cuda()
40 |
41 |
42 | ques_file = open(os.path.expanduser(questions_file), "r")
43 | ans_file = open(os.path.expanduser(answers_file), "w")
44 | for i, line in enumerate(tqdm(ques_file)):
45 | idx = json.loads(line)["question_id"]
46 | qs = json.loads(line)["text"]
47 | cat = json.loads(line)["category"]
48 | conv = default_conversation.copy()
49 | conv.append_message(conv.roles[0], qs)
50 | prompt = conv.get_prompt()
51 | inputs = tokenizer([prompt])
52 | input_ids = torch.as_tensor(inputs.input_ids).cuda()
53 | stopping_criteria = KeywordsStoppingCriteria([conv.sep], tokenizer, input_ids)
54 | output_ids = model.generate(
55 | input_ids,
56 | do_sample=True,
57 | temperature=0.7,
58 | max_new_tokens=1024,
59 | stopping_criteria=[stopping_criteria])
60 | outputs = tokenizer.batch_decode(output_ids, skip_special_tokens=True)[0]
61 | try:
62 | index = outputs.index(conv.sep, len(prompt))
63 | except ValueError:
64 | outputs += conv.sep
65 | index = outputs.index(conv.sep, len(prompt))
66 |
67 | outputs = outputs[len(prompt) + len(conv.roles[1]) + 2:index].strip()
68 | ans_id = shortuuid.uuid()
69 | ans_file.write(json.dumps({"question_id": idx,
70 | "text": outputs,
71 | "answer_id": ans_id,
72 | "model_id": model_name,
73 | "metadata": {}}) + "\n")
74 | ans_file.flush()
75 | ans_file.close()
76 |
77 | if __name__ == "__main__":
78 | parser = argparse.ArgumentParser()
79 | parser.add_argument("--model-name", type=str, default="facebook/opt-350m")
80 | parser.add_argument("--question-file", type=str, default="tables/question.jsonl")
81 | parser.add_argument("--answers-file", type=str, default="answer.jsonl")
82 | args = parser.parse_args()
83 |
84 | eval_model(args.model_name, args.question_file, args.answers_file)
85 |
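For reference, a minimal sketch (hypothetical values) of the JSONL records this script consumes from the question file and appends to the answers file, matching the fields accessed above.

```python
import json

# One line of the question file (fields read above: question_id, text, category).
question_line = json.dumps({
    "question_id": 1,
    "text": "How can I improve my time management skills?",
    "category": "generic",
})

# One line of the answers file (fields written above); the answer_id is a hypothetical shortuuid.
answer_line = json.dumps({
    "question_id": 1,
    "text": "Start by tracking how you spend your time ...",
    "answer_id": "3x5aBcD9eFgHiJkLmNoPq",
    "model_id": "facebook/opt-350m",
    "metadata": {},
})
```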
--------------------------------------------------------------------------------
/llava/eval/qa_baseline_gpt35.py:
--------------------------------------------------------------------------------
1 | """Generate answers with GPT-3.5"""
2 | # Note: you need to be using OpenAI Python v0.27.0 for the code below to work
3 | import argparse
4 | import json
5 | import os
6 | import time
7 | import concurrent.futures
8 |
9 | import openai
10 | import tqdm
11 | import shortuuid
12 |
13 | MODEL = 'gpt-3.5-turbo'
14 | MODEL_ID = 'gpt-3.5-turbo:20230327'
15 |
16 | def get_answer(question_id: int, question: str, max_tokens: int):
17 | ans = {
18 | 'answer_id': shortuuid.uuid(),
19 | 'question_id': question_id,
20 | 'model_id': MODEL_ID,
21 | }
22 | for _ in range(3):
23 | try:
24 | response = openai.ChatCompletion.create(
25 | model=MODEL,
26 | messages=[{
27 | 'role': 'system',
28 | 'content': 'You are a helpful assistant.'
29 | }, {
30 | 'role': 'user',
31 | 'content': question,
32 | }],
33 | max_tokens=max_tokens,
34 | )
35 | ans['text'] = response['choices'][0]['message']['content']
36 | return ans
37 | except Exception as e:
38 | print('[ERROR]', e)
39 | ans['text'] = '#ERROR#'
40 | time.sleep(1)
41 | return ans
42 |
43 |
44 | if __name__ == '__main__':
45 | parser = argparse.ArgumentParser(description='ChatGPT answer generation.')
46 | parser.add_argument('-q', '--question')
47 | parser.add_argument('-o', '--output')
48 | parser.add_argument('--max-tokens', type=int, default=1024, help='maximum number of tokens produced in the output')
49 | args = parser.parse_args()
50 |
51 | questions_dict = {}
52 | with open(os.path.expanduser(args.question)) as f:
53 | for line in f:
54 | if not line:
55 | continue
56 | q = json.loads(line)
57 | questions_dict[q['question_id']] = q['text']
58 |
59 | answers = []
60 |
61 | with concurrent.futures.ThreadPoolExecutor(max_workers=32) as executor:
62 | futures = []
63 | for qid, question in questions_dict.items():
64 | future = executor.submit(get_answer, qid, question, args.max_tokens)
65 | futures.append(future)
66 |
67 | for future in tqdm.tqdm(concurrent.futures.as_completed(futures), total=len(futures)):
68 | answers.append(future.result())
69 |
70 | answers.sort(key=lambda x: x['question_id'])
71 |
72 | with open(os.path.expanduser(args.output), 'w') as f:
73 | table = [json.dumps(ans) for ans in answers]
74 | f.write('\n'.join(table))
75 |
--------------------------------------------------------------------------------
/llava/eval/run_llava.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | from transformers import AutoTokenizer, AutoModelForCausalLM
3 | import torch
4 | import os
5 | from llava.conversation import conv_templates, SeparatorStyle
6 | from llava.utils import disable_torch_init
7 | from transformers import CLIPVisionModel, CLIPImageProcessor, StoppingCriteria
8 | from llava.model import *
9 | from llava.model.utils import KeywordsStoppingCriteria
10 |
11 | from PIL import Image
12 |
13 | import os
14 | import requests
15 | from PIL import Image
16 | from io import BytesIO
17 |
18 |
19 | DEFAULT_IMAGE_TOKEN = "<image>"
20 | DEFAULT_IMAGE_PATCH_TOKEN = "<im_patch>"
21 | DEFAULT_IM_START_TOKEN = "<im_start>"
22 | DEFAULT_IM_END_TOKEN = "<im_end>"
23 |
24 |
25 | def load_image(image_file):
26 | if image_file.startswith('http') or image_file.startswith('https'):
27 | response = requests.get(image_file)
28 | image = Image.open(BytesIO(response.content)).convert('RGB')
29 | else:
30 | image = Image.open(image_file).convert('RGB')
31 | return image
32 |
33 |
34 | def eval_model(args):
35 | # Model
36 | disable_torch_init()
37 | model_name = os.path.expanduser(args.model_name)
38 | tokenizer = AutoTokenizer.from_pretrained(model_name)
39 |
40 | if "mpt" in model_name.lower():
41 | model = LlavaMPTForCausalLM.from_pretrained(model_name, low_cpu_mem_usage=True, torch_dtype=torch.float16, use_cache=True).cuda()
42 | else:
43 | model = LlavaLlamaForCausalLM.from_pretrained(model_name, low_cpu_mem_usage=True, torch_dtype=torch.float16, use_cache=True).cuda()
44 | image_processor = CLIPImageProcessor.from_pretrained(model.config.mm_vision_tower, torch_dtype=torch.float16)
45 |
46 | mm_use_im_start_end = getattr(model.config, "mm_use_im_start_end", False)
47 | tokenizer.add_tokens([DEFAULT_IMAGE_PATCH_TOKEN], special_tokens=True)
48 | if mm_use_im_start_end:
49 | tokenizer.add_tokens([DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN], special_tokens=True)
50 |
51 | vision_tower = model.get_model().vision_tower[0]
52 | if vision_tower.device.type == 'meta':
53 | vision_tower = CLIPVisionModel.from_pretrained(vision_tower.config._name_or_path, torch_dtype=torch.float16, low_cpu_mem_usage=True).cuda()
54 | model.get_model().vision_tower[0] = vision_tower
55 | else:
56 | vision_tower.to(device='cuda', dtype=torch.float16)
57 | vision_config = vision_tower.config
58 | vision_config.im_patch_token = tokenizer.convert_tokens_to_ids([DEFAULT_IMAGE_PATCH_TOKEN])[0]
59 | vision_config.use_im_start_end = mm_use_im_start_end
60 | if mm_use_im_start_end:
61 | vision_config.im_start_token, vision_config.im_end_token = tokenizer.convert_tokens_to_ids([DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN])
62 | image_token_len = (vision_config.image_size // vision_config.patch_size) ** 2
63 |
64 | qs = args.query
65 | if mm_use_im_start_end:
66 | qs = qs + '\n' + DEFAULT_IM_START_TOKEN + DEFAULT_IMAGE_PATCH_TOKEN * image_token_len + DEFAULT_IM_END_TOKEN
67 | else:
68 | qs = qs + '\n' + DEFAULT_IMAGE_PATCH_TOKEN * image_token_len
69 |
70 | if "v1" in model_name.lower():
71 | conv_mode = "llava_v1"
72 | elif "mpt" in model_name.lower():
73 | conv_mode = "mpt_multimodal"
74 | else:
75 | conv_mode = "multimodal"
76 |
77 | if args.conv_mode is not None and conv_mode != args.conv_mode:
78 | print('[WARNING] the auto inferred conversation mode is {}, while `--conv-mode` is {}, using {}'.format(conv_mode, args.conv_mode, args.conv_mode))
79 | else:
80 | args.conv_mode = conv_mode
81 |
82 | conv = conv_templates[args.conv_mode].copy()
83 | conv.append_message(conv.roles[0], qs)
84 | conv.append_message(conv.roles[1], None)
85 | prompt = conv.get_prompt()
86 | inputs = tokenizer([prompt])
87 |
88 | image = load_image(args.image_file)
89 | image_tensor = image_processor.preprocess(image, return_tensors='pt')['pixel_values'][0]
90 |
91 | input_ids = torch.as_tensor(inputs.input_ids).cuda()
92 |
93 | stop_str = conv.sep if conv.sep_style != SeparatorStyle.TWO else conv.sep2
94 | keywords = [stop_str]
95 | stopping_criteria = KeywordsStoppingCriteria(keywords, tokenizer, input_ids)
96 |
97 | with torch.inference_mode():
98 | output_ids = model.generate(
99 | input_ids,
100 | images=image_tensor.unsqueeze(0).half().cuda(),
101 | do_sample=True,
102 | temperature=0.2,
103 | max_new_tokens=1024,
104 | stopping_criteria=[stopping_criteria])
105 |
106 | input_token_len = input_ids.shape[1]
107 | n_diff_input_output = (input_ids != output_ids[:, :input_token_len]).sum().item()
108 | if n_diff_input_output > 0:
109 | print(f'[Warning] {n_diff_input_output} output_ids are not the same as the input_ids')
110 | outputs = tokenizer.batch_decode(output_ids[:, input_token_len:], skip_special_tokens=True)[0]
111 | outputs = outputs.strip()
112 | if outputs.endswith(stop_str):
113 | outputs = outputs[:-len(stop_str)]
114 | outputs = outputs.strip()
115 | print(outputs)
116 |
117 | if __name__ == "__main__":
118 | parser = argparse.ArgumentParser()
119 | parser.add_argument("--model-name", type=str, default="facebook/opt-350m")
120 | parser.add_argument("--image-file", type=str, required=True)
121 | parser.add_argument("--query", type=str, required=True)
122 | parser.add_argument("--conv-mode", type=str, default=None)
123 | args = parser.parse_args()
124 |
125 | eval_model(args)
126 |
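The prompt construction above hinges on `image_token_len`: the image is represented by one `<im_patch>` placeholder per ViT patch. A minimal sketch of that arithmetic, assuming hypothetical CLIP ViT-L/14 config values at 224x224; the script itself reads the real values from `vision_tower.config`.

```python
# Hypothetical vision-tower config values; run_llava.py reads them from the loaded CLIP config.
image_size, patch_size = 224, 14
image_token_len = (image_size // patch_size) ** 2   # 16 * 16 = 256 patch tokens

DEFAULT_IMAGE_PATCH_TOKEN = "<im_patch>"
DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN = "<im_start>", "<im_end>"

query = "What is unusual about this image?"
# With mm_use_im_start_end, the run of patch tokens is wrapped in start/end markers.
qs = query + "\n" + DEFAULT_IM_START_TOKEN + DEFAULT_IMAGE_PATCH_TOKEN * image_token_len + DEFAULT_IM_END_TOKEN
```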
--------------------------------------------------------------------------------
/llava/eval/summarize_gpt_review.py:
--------------------------------------------------------------------------------
1 | import json
2 | import os
3 | from collections import defaultdict
4 |
5 | import numpy as np
6 |
7 |
8 | if __name__ == '__main__':
9 | base_dir = "vqa/reviews/coco2014_val80"
10 | review_files = [x for x in os.listdir(base_dir) if x.endswith('.jsonl') and x.startswith('gpt4_text')]
11 |
12 | for review_file in sorted(review_files):
13 | config = review_file.replace('gpt4_text_', '').replace('.jsonl', '')
14 | scores = defaultdict(list)
15 | print(f'GPT-4 vs. {config}')
16 | with open(os.path.join(base_dir, review_file)) as f:
17 | for review_str in f:
18 | review = json.loads(review_str)
19 | scores[review['category']].append(review['tuple'])
20 | scores['all'].append(review['tuple'])
21 | for k, v in scores.items():
22 | stats = np.asarray(v).mean(0).tolist()
23 | stats = [round(x, 3) for x in stats]
24 | print(k, stats, round(stats[1]/stats[0]*100, 1))
25 | print('=================================')
26 |
27 |
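Each review line carries a `tuple` of two scores; the script averages them per category and prints the second as a percentage of the first. A minimal worked example of that aggregation, with hypothetical scores.

```python
import numpy as np

# Hypothetical review tuples for one category: [Assistant 1 score, Assistant 2 score] per question.
tuples = [[8, 7], [9, 6], [7, 7]]

stats = np.asarray(tuples).mean(0).tolist()       # per-assistant mean scores
stats = [round(x, 3) for x in stats]              # e.g. [8.0, 6.667]
relative = round(stats[1] / stats[0] * 100, 1)    # Assistant 2 as a % of Assistant 1, e.g. 83.3
print(stats, relative)
```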
--------------------------------------------------------------------------------
/llava/eval/table/model.jsonl:
--------------------------------------------------------------------------------
1 | {"model_id": "vicuna-13b:20230322-clean-lang", "model_name": "vicuna-13b", "model_version": "20230322-clean-lang", "model_metadata": "vicuna-13b-20230322-clean-lang"}
2 | {"model_id": "alpaca-13b:v1", "model_name": "alpaca-13b", "model_version": "v1", "model_metadata": "alpaca-13b"}
3 | {"model_id": "llama-13b:v1", "model_name": "llama-13b", "model_version": "v1", "model_metadata": "hf-llama-13b"}
4 | {"model_id": "bard:20230327", "model_name": "bard", "model_version": "20230327", "model_metadata": "Google Bard 20230327"}
5 | {"model_id": "gpt-3.5-turbo:20230327", "model_name": "gpt-3.5-turbo", "model_version": "20230327", "model_metadata": "OpenAI ChatGPT gpt-3.5-turbo Chat Completion"}
6 |
--------------------------------------------------------------------------------
/llava/eval/table/prompt.jsonl:
--------------------------------------------------------------------------------
1 | {"prompt_id": 1, "system_prompt": "You are a helpful and precise assistant for checking the quality of the answer.", "prompt_template": "[Question]\n{question}\n\n[Assistant 1]\n{answer_1}\n\n[End of Assistant 1]\n\n[Assistant 2]\n{answer_2}\n\n[End of Assistant 2]\n\n[System]\n{prompt}\n\n", "defaults": {"prompt": "We would like to request your feedback on the performance of two AI assistants in response to the user question displayed above.\nPlease rate the helpfulness, relevance, accuracy, level of details of their responses. Each assistant receives an overall score on a scale of 1 to 10, where a higher score indicates better overall performance.\nPlease first output a single line containing only two values indicating the scores for Assistant 1 and 2, respectively. The two scores are separated by a space.\nIn the subsequent line, please provide a comprehensive explanation of your evaluation, avoiding any potential bias and ensuring that the order in which the responses were presented does not affect your judgment."}, "description": "Prompt for general questions"}
2 | {"prompt_id": 2, "system_prompt": "You are a helpful and precise assistant for checking the quality of the answer.", "prompt_template": "[Question]\n{question}\n\n[Assistant 1]\n{answer_1}\n\n[End of Assistant 1]\n\n[Assistant 2]\n{answer_2}\n\n[End of Assistant 2]\n\n[System]\n{prompt}\n\n", "defaults": {"prompt": "Your task is to evaluate the coding abilities of the above two assistants. They have been asked to implement a program to solve a given problem. Please review their code submissions, paying close attention to their problem-solving approach, code structure, readability, and the inclusion of helpful comments.\n\nPlease ensure that the assistants' submissions:\n\n1. Correctly implement the given problem statement.\n2. Contain accurate and efficient code.\n3. Include clear and concise comments that explain the code's logic and functionality.\n4. Adhere to proper coding standards and best practices.\n\nOnce you have carefully reviewed both submissions, provide detailed feedback on their strengths and weaknesses, along with any suggestions for improvement. You should first output a single line containing two scores on the scale of 1-10 (1: no code/no sense; 10: perfect) for Assistant 1 and 2, respectively. Then give extra comments starting from the next line."}, "description": "Prompt for coding questions"}
3 | {"prompt_id": 3, "system_prompt": "You are a helpful and precise assistant for checking the quality of the answer.", "prompt_template": "[Question]\n{question}\n\n[Assistant 1]\n{answer_1}\n\n[End of Assistant 1]\n\n[Assistant 2]\n{answer_2}\n\n[End of Assistant 2]\n\n[System]\n{prompt}\n\n", "defaults": {"prompt": "We would like to request your feedback on the mathematical proficiency of two AI assistants regarding the given user question.\nFirstly, please solve the problem independently, without referring to the answers provided by Assistant 1 and Assistant 2.\nAfterward, please examine the problem-solving process of Assistant 1 and Assistant 2 step-by-step to ensure their correctness, identifying any incorrect steps if present. Your evaluation should take into account not only the answer but also the problem-solving steps.\nFinally, please output a Python tuple containing two numerical scores for Assistant 1 and Assistant 2, ranging from 1 to 10, respectively. If applicable, explain the reasons for any variations in their scores and determine which assistant performed better."}, "description": "Prompt for math questions"}
4 | {"prompt_id": 4, "system_prompt": "You are a helpful and precise assistant for checking the quality of the answer.", "prompt_template": "[Visual Context]\n{context}\n[Question]\n{question}\n\n[Assistant 1]\n{answer_1}\n\n[End of Assistant 1]\n\n[Assistant 2]\n{answer_2}\n\n[End of Assistant 2]\n\n[System]\n{prompt}\n\n", "defaults": {"prompt": "We would like to request your feedback on the performance of two AI assistants in response to the user question displayed above. The user asks the question on observing an image. For your reference, the visual content in the image is represented with five descriptive sentences describing the same image and the bounding box coordinates of each object in the scene. These coordinates are in the form of bounding boxes, represented as (x1, y1, x2, y2) with floating numbers ranging from 0 to 1. These values correspond to the top left x, top left y, bottom right x, and bottom right y. \nPlease rate the helpfulness, relevance, accuracy, level of details of their responses. Each assistant receives an overall score on a scale of 1 to 10, where a higher score indicates better overall performance.\nPlease first output a single line containing only two values indicating the scores for Assistant 1 and 2, respectively. The two scores are separated by a space.\nIn the subsequent line, please provide a comprehensive explanation of your evaluation, avoiding any potential bias and ensuring that the order in which the responses were presented does not affect your judgment."}, "description": "Prompt for visual questions"}
5 |
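Each entry pairs a `prompt_template` with a default system message and rating instructions; a review request can be assembled by substituting the question and the two candidate answers into the template with `str.format`. A minimal sketch using the general-questions entry above (the file path is assumed relative to the repository root; the answers are hypothetical).

```python
import json

# Path assumed relative to the repository root.
with open('llava/eval/table/prompt.jsonl') as f:
    general = json.loads(f.readline())   # prompt_id 1: general questions

review_prompt = general['prompt_template'].format(
    question="What is the capital of France?",
    answer_1="Paris.",
    answer_2="The capital of France is Paris.",
    prompt=general['defaults']['prompt'],
)
print(general['system_prompt'])
print(review_prompt)
```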
--------------------------------------------------------------------------------
/llava/eval/table/reviewer.jsonl:
--------------------------------------------------------------------------------
1 | {"reviewer_id": "gpt-4-0328-default", "prompt_id": 1, "metadata": {"temperature": 0.2, "max_tokens": 1024}, "description": "GPT-4 for general questions"}
2 | {"reviewer_id": "gpt-4-0328-coding", "prompt_id": 2, "metadata": {"temperature": 0.2, "max_tokens": 1024}, "description": "GPT-4 for coding questions"}
3 | {"reviewer_id": "gpt-4-0328-math", "prompt_id": 3, "metadata": {"temperature": 0.2, "max_tokens": 1024}, "description": "GPT-4 for math questions"}
4 | {"reviewer_id": "gpt-4-0417-visual", "prompt_id": 4, "metadata": {"temperature": 0.2, "max_tokens": 1024}, "description": "GPT-4 for math questions"}
5 |
--------------------------------------------------------------------------------
/llava/eval/table/rule.json:
--------------------------------------------------------------------------------
1 | {
2 | "coding": {"role": "Assistant", "prompt": "Your task is to evaluate the coding abilities of the above two assistants. They have been asked to implement a program to solve a given problem. Please review their code submissions, paying close attention to their problem-solving approach, code structure, readability, and the inclusion of helpful comments.\n\nPlease ensure that the assistants' submissions:\n\n1. Correctly implement the given problem statement.\n2. Contain accurate and efficient code.\n3. Include clear and concise comments that explain the code's logic and functionality.\n4. Adhere to proper coding standards and best practices.\n\nOnce you have carefully reviewed both submissions, provide detailed feedback on their strengths and weaknesses, along with any suggestions for improvement. You should first output a single line containing two scores on the scale of 1-10 (1: no code/no sense; 10: perfect) for Assistant 1 and 2, respectively. Then give extra comments starting from the next line."},
3 | "math": {"role": "Assistant", "prompt": "We would like to request your feedback on the mathematical proficiency of two AI assistants regarding the given user question.\nFirstly, please solve the problem independently, without referring to the answers provided by Assistant 1 and Assistant 2.\nAfterward, please examine the problem-solving process of Assistant 1 and Assistant 2 step-by-step to ensure their correctness, identifying any incorrect steps if present. Your evaluation should take into account not only the answer but also the problem-solving steps.\nFinally, please output a Python tuple containing two numerical scores for Assistant 1 and Assistant 2, ranging from 1 to 10, respectively. If applicable, explain the reasons for any variations in their scores and determine which assistant performed better."},
4 | "default": {"role": "Assistant", "prompt": "We would like to request your feedback on the performance of two AI assistants in response to the user question displayed above.\nPlease rate the helpfulness, relevance, accuracy, level of details of their responses. Each assistant receives an overall score on a scale of 1 to 10, where a higher score indicates better overall performance.\nPlease first output a single line containing only two values indicating the scores for Assistant 1 and 2, respectively. The two scores are separated by a space.\nIn the subsequent line, please provide a comprehensive explanation of your evaluation, avoiding any potential bias and ensuring that the order in which the responses were presented does not affect your judgment."},
5 | "conv": {"role": "Assistant", "prompt": "We would like to request your feedback on the performance of two AI assistants in response to the user question displayed above. The user asks the question on observing an image. For your reference, the visual content in the image is represented with five descriptive sentences describing the same image and the bounding box coordinates of each object in the scene. These coordinates are in the form of bounding boxes, represented as (x1, y1, x2, y2) with floating numbers ranging from 0 to 1. These values correspond to the top left x, top left y, bottom right x, and bottom right y. \nPlease rate the helpfulness, relevance, accuracy, level of details of their responses. Each assistant receives an overall score on a scale of 1 to 10, where a higher score indicates better overall performance.\nPlease first output a single line containing only two values indicating the scores for Assistant 1 and 2, respectively. The two scores are separated by a space.\nIn the subsequent line, please provide a comprehensive explanation of your evaluation, avoiding any potential bias and ensuring that the order in which the responses were presented does not affect your judgment."},
6 | "detail": {"role": "Assistant", "prompt": "We would like to request your feedback on the performance of two AI assistants in response to the user question displayed above. The user asks the question on observing an image. For your reference, the visual content in the image is represented with five descriptive sentences describing the same image and the bounding box coordinates of each object in the scene. These coordinates are in the form of bounding boxes, represented as (x1, y1, x2, y2) with floating numbers ranging from 0 to 1. These values correspond to the top left x, top left y, bottom right x, and bottom right y. \nPlease rate the helpfulness, relevance, accuracy, level of details of their responses. Each assistant receives an overall score on a scale of 1 to 10, where a higher score indicates better overall performance.\nPlease first output a single line containing only two values indicating the scores for Assistant 1 and 2, respectively. The two scores are separated by a space.\nIn the subsequent line, please provide a comprehensive explanation of your evaluation, avoiding any potential bias and ensuring that the order in which the responses were presented does not affect your judgment."},
7 | "complex": {"role": "Assistant", "prompt": "We would like to request your feedback on the performance of two AI assistants in response to the user question displayed above. The user asks the question on observing an image. For your reference, the visual content in the image is represented with five descriptive sentences describing the same image and the bounding box coordinates of each object in the scene. These coordinates are in the form of bounding boxes, represented as (x1, y1, x2, y2) with floating numbers ranging from 0 to 1. These values correspond to the top left x, top left y, bottom right x, and bottom right y. \nPlease rate the helpfulness, relevance, accuracy, level of details of their responses. Each assistant receives an overall score on a scale of 1 to 10, where a higher score indicates better overall performance.\nPlease first output a single line containing only two values indicating the scores for Assistant 1 and 2, respectively. The two scores are separated by a space.\nIn the subsequent line, please provide a comprehensive explanation of your evaluation, avoiding any potential bias and ensuring that the order in which the responses were presented does not affect your judgment."}
8 | }
--------------------------------------------------------------------------------
/llava/eval/webpage/figures/alpaca.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ebagdasa/multimodal_injection/be8f118c06d7783e4f458046947d1271e7b435d6/llava/eval/webpage/figures/alpaca.png
--------------------------------------------------------------------------------
/llava/eval/webpage/figures/bard.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ebagdasa/multimodal_injection/be8f118c06d7783e4f458046947d1271e7b435d6/llava/eval/webpage/figures/bard.jpg
--------------------------------------------------------------------------------
/llava/eval/webpage/figures/chatgpt.svg:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/llava/eval/webpage/figures/llama.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ebagdasa/multimodal_injection/be8f118c06d7783e4f458046947d1271e7b435d6/llava/eval/webpage/figures/llama.jpg
--------------------------------------------------------------------------------
/llava/eval/webpage/figures/swords_FILL0_wght300_GRAD0_opsz48.svg:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/llava/eval/webpage/figures/vicuna.jpeg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ebagdasa/multimodal_injection/be8f118c06d7783e4f458046947d1271e7b435d6/llava/eval/webpage/figures/vicuna.jpeg
--------------------------------------------------------------------------------
/llava/eval/webpage/index.html:
--------------------------------------------------------------------------------
[HTML markup stripped during extraction. Recoverable page text: title "Who's GPT-4's favorite? Battles between State-of-the-Art Chatbots", a model-logo image placeholder ("other logo"), and the footer note "This website is co-authored with GPT-4."]
--------------------------------------------------------------------------------
/llava/eval/webpage/styles.css:
--------------------------------------------------------------------------------
1 | body {
2 | font-family: 'Segoe UI', Tahoma, Geneva, Verdana, sans-serif;
3 | background-color: #f8f9fa;
4 | }
5 |
6 | .navbar-dark .navbar-nav .nav-link {
7 | color: #f1cf68;
8 | font-size: 1.1rem;
9 | padding: 0.5rem 0.6rem;
10 | }
11 |
12 | .card-header {
13 | font-weight: bold;
14 | }
15 |
16 | .card {
17 | box-shadow: 0 4px 8px rgba(0, 0, 0, 0.1);
18 | transition: 0.3s;
19 | }
20 |
21 | .card:hover {
22 | box-shadow: 0 8px 16px rgba(0, 0, 0, 0.2);
23 | }
24 |
25 | button {
26 | transition: background-color 0.3s;
27 | }
28 |
29 | button:hover {
30 | background-color: #007bff;
31 | }
32 |
33 | @media (max-width: 767px) {
34 | .form-row .form-group {
35 | margin-bottom: 10px;
36 | }
37 | }
38 |
39 | /* Extra styles */
40 |
41 | .expandable-card .card-text-container {
42 | max-height: 200px;
43 | overflow-y: hidden;
44 | position: relative;
45 | }
46 |
47 | .expandable-card.expanded .card-text-container {
48 | max-height: none;
49 | }
50 |
51 | .expand-btn {
52 | position: relative;
53 | display: none;
54 | background-color: rgba(255, 255, 255, 0.8);
55 | color: #510c75;
56 | border-color: transparent;
57 | }
58 |
59 | .expand-btn:hover {
60 | background-color: rgba(200, 200, 200, 0.8);
61 | text-decoration: none;
62 | border-color: transparent;
63 | color: #510c75;
64 | }
65 |
66 | .expand-btn:focus {
67 | outline: none;
68 | text-decoration: none;
69 | }
70 |
71 | .expandable-card:not(.expanded) .card-text-container:after {
72 | content: "";
73 | position: absolute;
74 | bottom: 0;
75 | left: 0;
76 | width: 100%;
77 | height: 90px;
78 | background: linear-gradient(rgba(255, 255, 255, 0.2), rgba(255, 255, 255, 1));
79 | }
80 |
81 | .expandable-card:not(.expanded) .expand-btn {
82 | margin-top: -40px;
83 | }
84 |
85 | .card-body {
86 | padding-bottom: 5px;
87 | }
88 |
89 | .vertical-flex-layout {
90 | justify-content: center;
91 | align-items: center;
92 | height: 100%;
93 | display: flex;
94 | flex-direction: column;
95 | gap: 5px;
96 | }
97 |
98 | .figure-img {
99 | max-width: 100%;
100 | height: auto;
101 | }
102 |
103 | .adjustable-font-size {
104 | font-size: calc(0.5rem + 2vw);
105 | }
106 |
--------------------------------------------------------------------------------
/llava/model/.DS_Store:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ebagdasa/multimodal_injection/be8f118c06d7783e4f458046947d1271e7b435d6/llava/model/.DS_Store
--------------------------------------------------------------------------------
/llava/model/__init__.py:
--------------------------------------------------------------------------------
1 | from .llava import LlavaLlamaForCausalLM, LlavaConfig
2 | from .llava_mpt import LlavaMPTForCausalLM, LlavaMPTConfig
3 |
--------------------------------------------------------------------------------
/llava/model/apply_delta.py:
--------------------------------------------------------------------------------
1 | """
2 | Usage:
3 | python3 -m llava.model.apply_delta --base ~/model_weights/llama-7b --target ~/model_weights/llava-7b --delta liuhaotian/llava-7b-delta
4 | """
5 | import argparse
6 |
7 | import torch
8 | from tqdm import tqdm
9 | from transformers import AutoTokenizer, AutoModelForCausalLM
10 | from llava import LlavaLlamaForCausalLM
11 |
12 |
13 | def apply_delta(base_model_path, target_model_path, delta_path):
14 | print("Loading base model")
15 | base = AutoModelForCausalLM.from_pretrained(
16 | base_model_path, torch_dtype=torch.float16, low_cpu_mem_usage=True)
17 |
18 | print("Loading delta")
19 | delta = LlavaLlamaForCausalLM.from_pretrained(delta_path, torch_dtype=torch.float16, low_cpu_mem_usage=True)
20 | delta_tokenizer = AutoTokenizer.from_pretrained(delta_path)
21 |
22 | print("Applying delta")
23 | for name, param in tqdm(delta.state_dict().items(), desc="Applying delta"):
24 | if name not in base.state_dict():
25 | assert name in ['model.mm_projector.weight', 'model.mm_projector.bias'], f'{name} not in base model'
26 | continue
27 | if param.data.shape == base.state_dict()[name].shape:
28 | param.data += base.state_dict()[name]
29 | else:
30 | assert name in ['model.embed_tokens.weight', 'lm_head.weight'], \
31 | f'{name} dimension mismatch: {param.data.shape} vs {base.state_dict()[name].shape}'
32 | bparam = base.state_dict()[name]
33 | param.data[:bparam.shape[0], :bparam.shape[1]] += bparam
34 |
35 | print("Saving target model")
36 | delta.save_pretrained(target_model_path)
37 | delta_tokenizer.save_pretrained(target_model_path)
38 |
39 |
40 | if __name__ == "__main__":
41 | parser = argparse.ArgumentParser()
42 | parser.add_argument("--base-model-path", type=str, required=True)
43 | parser.add_argument("--target-model-path", type=str, required=True)
44 | parser.add_argument("--delta-path", type=str, required=True)
45 |
46 | args = parser.parse_args()
47 |
48 | apply_delta(args.base_model_path, args.target_model_path, args.delta_path)
49 |
--------------------------------------------------------------------------------
/llava/model/consolidate.py:
--------------------------------------------------------------------------------
1 | """
2 | Usage:
3 | python3 -m llava.model.consolidate --src ~/model_weights/llava-7b --dst ~/model_weights/llava-7b_consolidate
4 | """
5 | import argparse
6 |
7 | import torch
8 | from transformers import AutoTokenizer, AutoModelForCausalLM
9 | from llava.model import *
10 | from llava.model.utils import auto_upgrade
11 |
12 |
13 | def consolidate_ckpt(src_path, dst_path):
14 | print("Loading model")
15 | auto_upgrade(src_path)
16 | src_model = AutoModelForCausalLM.from_pretrained(src_path, torch_dtype=torch.float16, low_cpu_mem_usage=True)
17 | src_tokenizer = AutoTokenizer.from_pretrained(src_path)
18 | src_model.save_pretrained(dst_path)
19 | src_tokenizer.save_pretrained(dst_path)
20 |
21 |
22 | if __name__ == "__main__":
23 | parser = argparse.ArgumentParser()
24 | parser.add_argument("--src", type=str, required=True)
25 | parser.add_argument("--dst", type=str, required=True)
26 |
27 | args = parser.parse_args()
28 |
29 | consolidate_ckpt(args.src, args.dst)
30 |
--------------------------------------------------------------------------------
/llava/model/make_delta.py:
--------------------------------------------------------------------------------
1 | """
2 | Usage:
3 | python3 -m llava.model.make_delta --base ~/model_weights/llama-7b --target ~/model_weights/llava-7b --delta ~/model_weights/llava-7b-delta --hub-repo-id liuhaotian/llava-7b-delta
4 | """
5 | import argparse
6 |
7 | import torch
8 | from tqdm import tqdm
9 | from transformers import AutoTokenizer, AutoModelForCausalLM
10 | from llava.model.utils import auto_upgrade
11 |
12 |
13 | def make_delta(base_model_path, target_model_path, delta_path, hub_repo_id):
14 | print("Loading base model")
15 | base = AutoModelForCausalLM.from_pretrained(
16 | base_model_path, torch_dtype=torch.float16, low_cpu_mem_usage=True)
17 |
18 | print("Loading target model")
19 | auto_upgrade(target_model_path)
20 | target = AutoModelForCausalLM.from_pretrained(target_model_path, torch_dtype=torch.float16, low_cpu_mem_usage=True)
21 |
22 | print("Calculating delta")
23 | for name, param in tqdm(target.state_dict().items(), desc="Calculating delta"):
24 | if name not in base.state_dict():
25 | assert name in ['model.mm_projector.weight', 'model.mm_projector.bias'], f'{name} not in base model'
26 | continue
27 | if param.data.shape == base.state_dict()[name].shape:
28 | param.data -= base.state_dict()[name]
29 | else:
30 | assert name in ['model.embed_tokens.weight', 'lm_head.weight'], f'{name} dimension mismatch: {param.data.shape} vs {base.state_dict()[name].shape}'
31 | bparam = base.state_dict()[name]
32 | param.data[:bparam.shape[0], :bparam.shape[1]] -= bparam
33 |
34 | print("Saving delta")
35 | if hub_repo_id:
36 | kwargs = {"push_to_hub": True, "repo_id": hub_repo_id}
37 | else:
38 | kwargs = {}
39 | target.save_pretrained(delta_path, **kwargs)
40 | target_tokenizer = AutoTokenizer.from_pretrained(target_model_path)
41 | target_tokenizer.save_pretrained(delta_path, **kwargs)
42 |
43 |
44 | if __name__ == "__main__":
45 | parser = argparse.ArgumentParser()
46 | parser.add_argument("--base-model-path", type=str, required=True)
47 | parser.add_argument("--target-model-path", type=str, required=True)
48 | parser.add_argument("--delta-path", type=str, required=True)
49 | parser.add_argument("--hub-repo-id", type=str, default=None)
50 | args = parser.parse_args()
51 |
52 | make_delta(args.base_model_path, args.target_model_path, args.delta_path, args.hub_repo_id)
53 |
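make_delta subtracts the base weights from the target model, and apply_delta adds them back. A toy sketch of that round trip on a single tensor (not the scripts' entry points, just the per-parameter arithmetic they perform).

```python
import torch

base = torch.randn(4, 4)      # stands in for a base-model parameter
target = torch.randn(4, 4)    # stands in for the fine-tuned (target) parameter

delta = target - base         # what make_delta stores for a matching parameter
recovered = delta + base      # what apply_delta reconstructs

assert torch.allclose(recovered, target)

# For resized parameters (e.g. model.embed_tokens.weight), only the overlapping
# top-left slice is offset: delta[:b.shape[0], :b.shape[1]] -= b, and + b on apply.
```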
--------------------------------------------------------------------------------
/llava/model/mpt/adapt_tokenizer.py:
--------------------------------------------------------------------------------
1 | from typing import Union
2 | from transformers import AutoTokenizer, PreTrainedTokenizer, PreTrainedTokenizerFast
3 | Tokenizer = Union[PreTrainedTokenizer, PreTrainedTokenizerFast]
4 | NUM_SENTINEL_TOKENS: int = 100
5 |
6 | def adapt_tokenizer_for_denoising(tokenizer: Tokenizer):
7 | """Adds sentinel tokens and padding token (if missing).
8 |
9 | Expands the tokenizer vocabulary to include sentinel tokens
10 | used in mixture-of-denoiser tasks as well as a padding token.
11 |
12 | All added tokens are added as special tokens. No tokens are
13 | added if sentinel tokens and padding token already exist.
14 | """
15 | sentinels_to_add = [f'<extra_id_{i}>' for i in range(NUM_SENTINEL_TOKENS)]
16 | tokenizer.add_tokens(sentinels_to_add, special_tokens=True)
17 | if tokenizer.pad_token is None:
18 | tokenizer.add_tokens('<pad>', special_tokens=True)
19 | tokenizer.pad_token = '<pad>'
20 | assert tokenizer.pad_token_id is not None
21 | sentinels = ''.join([f'<extra_id_{i}>' for i in range(NUM_SENTINEL_TOKENS)])
22 | _sentinel_token_ids = tokenizer(sentinels, add_special_tokens=False).input_ids
23 | tokenizer.sentinel_token_ids = _sentinel_token_ids
24 |
25 | class AutoTokenizerForMOD(AutoTokenizer):
26 | """AutoTokenizer + Adaptation for MOD.
27 |
28 | A simple wrapper around AutoTokenizer to make instantiating
29 | an MOD-adapted tokenizer a bit easier.
30 |
31 | MOD-adapted tokenizers have sentinel tokens (e.g., <extra_id_0>),
32 | a padding token, and a property to get the token ids of the
33 | sentinel tokens.
34 | """
35 |
36 | @classmethod
37 | def from_pretrained(cls, *args, **kwargs):
38 | """See `AutoTokenizer.from_pretrained` docstring."""
39 | tokenizer = super().from_pretrained(*args, **kwargs)
40 | adapt_tokenizer_for_denoising(tokenizer)
41 | return tokenizer
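A minimal usage sketch, assuming the package is importable and the `gpt2` tokenizer is available from the Hugging Face hub; any tokenizer without a pad token behaves the same way.

```python
from transformers import AutoTokenizer
from llava.model.mpt.adapt_tokenizer import adapt_tokenizer_for_denoising  # assumes the package is installed

# Assumes network or cache access to the "gpt2" tokenizer, which has no pad token by default.
tok = AutoTokenizer.from_pretrained('gpt2')
adapt_tokenizer_for_denoising(tok)

print(tok.pad_token)                  # '<pad>'
print(len(tok.sentinel_token_ids))    # 100 sentinel ids, one per <extra_id_i>
```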
--------------------------------------------------------------------------------
/llava/model/mpt/blocks.py:
--------------------------------------------------------------------------------
1 | """GPT Blocks used for the GPT Model."""
2 | from typing import Dict, Optional, Tuple
3 | import torch
4 | import torch.nn as nn
5 | from .attention import ATTN_CLASS_REGISTRY
6 | from .norm import NORM_CLASS_REGISTRY
7 |
8 | class MPTMLP(nn.Module):
9 |
10 | def __init__(self, d_model: int, expansion_ratio: int, device: Optional[str]=None):
11 | super().__init__()
12 | self.up_proj = nn.Linear(d_model, expansion_ratio * d_model, device=device)
13 | self.act = nn.GELU(approximate='none')
14 | self.down_proj = nn.Linear(expansion_ratio * d_model, d_model, device=device)
15 | self.down_proj._is_residual = True
16 |
17 | def forward(self, x):
18 | return self.down_proj(self.act(self.up_proj(x)))
19 |
20 | class MPTBlock(nn.Module):
21 |
22 | def __init__(self, d_model: int, n_heads: int, expansion_ratio: int, attn_config: Dict={'attn_type': 'multihead_attention', 'attn_pdrop': 0.0, 'attn_impl': 'triton', 'qk_ln': False, 'clip_qkv': None, 'softmax_scale': None, 'prefix_lm': False, 'attn_uses_sequence_id': False, 'alibi': False, 'alibi_bias_max': 8}, resid_pdrop: float=0.0, norm_type: str='low_precision_layernorm', device: Optional[str]=None, **kwargs):
23 | del kwargs
24 | super().__init__()
25 | norm_class = NORM_CLASS_REGISTRY[norm_type.lower()]
26 | attn_class = ATTN_CLASS_REGISTRY[attn_config['attn_type']]
27 | self.norm_1 = norm_class(d_model, device=device)
28 | self.attn = attn_class(attn_impl=attn_config['attn_impl'], clip_qkv=attn_config['clip_qkv'], qk_ln=attn_config['qk_ln'], softmax_scale=attn_config['softmax_scale'], attn_pdrop=attn_config['attn_pdrop'], d_model=d_model, n_heads=n_heads, device=device)
29 | self.norm_2 = norm_class(d_model, device=device)
30 | self.ffn = MPTMLP(d_model=d_model, expansion_ratio=expansion_ratio, device=device)
31 | self.resid_attn_dropout = nn.Dropout(resid_pdrop)
32 | self.resid_ffn_dropout = nn.Dropout(resid_pdrop)
33 |
34 | def forward(self, x: torch.Tensor, past_key_value: Optional[Tuple[torch.Tensor]]=None, attn_bias: Optional[torch.Tensor]=None, attention_mask: Optional[torch.ByteTensor]=None, is_causal: bool=True) -> Tuple[torch.Tensor, Optional[Tuple[torch.Tensor]]]:
35 | a = self.norm_1(x)
36 | (b, _, past_key_value) = self.attn(a, past_key_value=past_key_value, attn_bias=attn_bias, attention_mask=attention_mask, is_causal=is_causal)
37 | x = x + self.resid_attn_dropout(b)
38 | m = self.norm_2(x)
39 | n = self.ffn(m)
40 | x = x + self.resid_ffn_dropout(n)
41 | return (x, past_key_value)
--------------------------------------------------------------------------------
/llava/model/mpt/meta_init_context.py:
--------------------------------------------------------------------------------
1 | from contextlib import contextmanager
2 | import torch
3 | import torch.nn as nn
4 |
5 | @contextmanager
6 | def init_empty_weights(include_buffers: bool=False):
7 | """Meta initialization context manager.
8 |
9 | A context manager under which models are initialized with all parameters
10 | on the meta device, therefore creating an empty model. Useful when just
11 | initializing the model would blow the available RAM.
12 |
13 | Args:
14 | include_buffers (`bool`, *optional*, defaults to `False`): Whether or
15 | not to also put all buffers on the meta device while initializing.
16 |
17 | Example:
18 | ```python
19 | import torch.nn as nn
20 |
21 | # Initialize a model with 100 billions parameters in no time and without using any RAM.
22 | with init_empty_weights():
23 | tst = nn.Sequential(*[nn.Linear(10000, 10000) for _ in range(1000)])
24 | ```
25 |
26 | <Tip warning={true}>
27 |
28 | Any model created under this context manager has no weights. As such you can't do something like
29 | `model.to(some_device)` with it. To load weights inside your empty model, see [`load_checkpoint_and_dispatch`].
30 |
31 | </Tip>
32 | """
33 | with init_on_device(torch.device('meta'), include_buffers=include_buffers) as f:
34 | yield f
35 |
36 | @contextmanager
37 | def init_on_device(device: torch.device, include_buffers: bool=False):
38 | """Device initialization context manager.
39 |
40 | A context manager under which models are initialized with all parameters
41 | on the specified device.
42 |
43 | Args:
44 | device (`torch.device`): Device to initialize all parameters on.
45 | include_buffers (`bool`, *optional*, defaults to `False`): Whether or
46 | not to also put all buffers on the meta device while initializing.
47 |
48 | Example:
49 | ```python
50 | import torch.nn as nn
51 |
52 | with init_on_device(device=torch.device("cuda")):
53 | tst = nn.Linear(100, 100) # on `cuda` device
54 | ```
55 | """
56 | old_register_parameter = nn.Module.register_parameter
57 | if include_buffers:
58 | old_register_buffer = nn.Module.register_buffer
59 |
60 | def register_empty_parameter(module, name, param):
61 | old_register_parameter(module, name, param)
62 | if param is not None:
63 | param_cls = type(module._parameters[name])
64 | kwargs = module._parameters[name].__dict__
65 | module._parameters[name] = param_cls(module._parameters[name].to(device), **kwargs)
66 |
67 | def register_empty_buffer(module, name, buffer):
68 | old_register_buffer(module, name, buffer)
69 | if buffer is not None:
70 | module._buffers[name] = module._buffers[name].to(device)
71 | if include_buffers:
72 | tensor_constructors_to_patch = {torch_function_name: getattr(torch, torch_function_name) for torch_function_name in ['empty', 'zeros', 'ones', 'full']}
73 | else:
74 | tensor_constructors_to_patch = {}
75 |
76 | def patch_tensor_constructor(fn):
77 |
78 | def wrapper(*args, **kwargs):
79 | kwargs['device'] = device
80 | return fn(*args, **kwargs)
81 | return wrapper
82 | try:
83 | nn.Module.register_parameter = register_empty_parameter
84 | if include_buffers:
85 | nn.Module.register_buffer = register_empty_buffer
86 | for torch_function_name in tensor_constructors_to_patch.keys():
87 | setattr(torch, torch_function_name, patch_tensor_constructor(getattr(torch, torch_function_name)))
88 | yield
89 | finally:
90 | nn.Module.register_parameter = old_register_parameter
91 | if include_buffers:
92 | nn.Module.register_buffer = old_register_buffer
93 | for (torch_function_name, old_torch_function) in tensor_constructors_to_patch.items():
94 | setattr(torch, torch_function_name, old_torch_function)
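A minimal sketch of the context manager in use, assuming the package is importable: parameters created inside it land on the meta device and take no memory.

```python
import torch.nn as nn
from llava.model.mpt.meta_init_context import init_empty_weights  # assumes the package is installed

with init_empty_weights():
    tst = nn.Sequential(nn.Linear(512, 512), nn.Linear(512, 512))

print(next(tst.parameters()).device)   # meta
print(next(tst.parameters()).is_meta)  # True
```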
--------------------------------------------------------------------------------
/llava/model/mpt/norm.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 | def _cast_if_autocast_enabled(tensor):
4 | if torch.is_autocast_enabled():
5 | if tensor.device.type == 'cuda':
6 | dtype = torch.get_autocast_gpu_dtype()
7 | elif tensor.device.type == 'cpu':
8 | dtype = torch.get_autocast_cpu_dtype()
9 | else:
10 | raise NotImplementedError()
11 | return tensor.to(dtype=dtype)
12 | return tensor
13 |
14 | class LPLayerNorm(torch.nn.LayerNorm):
15 |
16 | def __init__(self, normalized_shape, eps=1e-05, elementwise_affine=True, device=None, dtype=None):
17 | super().__init__(normalized_shape=normalized_shape, eps=eps, elementwise_affine=elementwise_affine, device=device, dtype=dtype)
18 |
19 | def forward(self, x):
20 | module_device = x.device
21 | downcast_x = _cast_if_autocast_enabled(x)
22 | downcast_weight = _cast_if_autocast_enabled(self.weight) if self.weight is not None else self.weight
23 | downcast_bias = _cast_if_autocast_enabled(self.bias) if self.bias is not None else self.bias
24 | with torch.autocast(enabled=False, device_type=module_device.type):
25 | return torch.nn.functional.layer_norm(downcast_x, self.normalized_shape, downcast_weight, downcast_bias, self.eps)
26 |
27 | def rms_norm(x, weight=None, eps=1e-05):
28 | output = x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + eps)
29 | if weight is not None:
30 | return output * weight
31 | return output
32 |
33 | class RMSNorm(torch.nn.Module):
34 |
35 | def __init__(self, normalized_shape, eps=1e-05, weight=True, dtype=None, device=None):
36 | super().__init__()
37 | self.eps = eps
38 | if weight:
39 | self.weight = torch.nn.Parameter(torch.ones(normalized_shape, dtype=dtype, device=device))
40 | else:
41 | self.register_parameter('weight', None)
42 |
43 | def forward(self, x):
44 | return rms_norm(x.float(), self.weight, self.eps).to(dtype=x.dtype)
45 |
46 | class LPRMSNorm(RMSNorm):
47 |
48 | def __init__(self, normalized_shape, eps=1e-05, weight=True, dtype=None, device=None):
49 | super().__init__(normalized_shape=normalized_shape, eps=eps, weight=weight, dtype=dtype, device=device)
50 |
51 | def forward(self, x):
52 | downcast_x = _cast_if_autocast_enabled(x)
53 | downcast_weight = _cast_if_autocast_enabled(self.weight) if self.weight is not None else self.weight
54 | with torch.autocast(enabled=False, device_type=x.device.type):
55 | return rms_norm(downcast_x, downcast_weight, self.eps).to(dtype=x.dtype)
56 | NORM_CLASS_REGISTRY = {'layernorm': torch.nn.LayerNorm, 'low_precision_layernorm': LPLayerNorm, 'rmsnorm': RMSNorm, 'low_precision_rmsnorm': LPRMSNorm}
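A small sanity check of `rms_norm` against the closed form x / sqrt(mean(x^2) + eps), using the definitions above; the import path assumes the package is installed.

```python
import torch
from llava.model.mpt.norm import rms_norm, RMSNorm  # assumes the package is installed

x = torch.randn(2, 8)
manual = x / torch.sqrt(x.pow(2).mean(-1, keepdim=True) + 1e-5)

assert torch.allclose(rms_norm(x), manual)

norm = RMSNorm(8)                        # weight initialized to ones
assert torch.allclose(norm(x), manual, atol=1e-6)
```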
--------------------------------------------------------------------------------
/llava/model/utils.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from llava.model import *
3 | from transformers import AutoConfig, StoppingCriteria
4 |
5 |
6 | def auto_upgrade(config):
7 | cfg = AutoConfig.from_pretrained(config)
8 | if 'llava' in config and 'llava' not in cfg.model_type:
9 | assert cfg.model_type == 'llama'
10 | print("You are using newer LLaVA code base, while the checkpoint of v0 is from older code base.")
11 | print("You must upgrade the checkpoint to the new code base (this can be done automatically).")
12 | confirm = input("Please confirm that you want to upgrade the checkpoint. [Y/N]")
13 | if confirm.lower() in ["y", "yes"]:
14 | print("Upgrading checkpoint...")
15 | assert len(cfg.architectures) == 1
16 | setattr(cfg.__class__, "model_type", "llava")
17 | cfg.architectures[0] = 'LlavaLlamaForCausalLM'
18 | cfg.save_pretrained(config)
19 | print("Checkpoint upgraded.")
20 | else:
21 | print("Checkpoint upgrade aborted.")
22 | exit(1)
23 |
24 |
25 |
26 | class KeywordsStoppingCriteria(StoppingCriteria):
27 | def __init__(self, keywords, tokenizer, input_ids):
28 | self.keywords = keywords
29 | self.keyword_ids = [tokenizer(keyword).input_ids for keyword in keywords]
30 | self.keyword_ids = [keyword_id[0] for keyword_id in self.keyword_ids if type(keyword_id) is list and len(keyword_id) == 1]
31 | self.tokenizer = tokenizer
32 | self.start_len = None
33 | self.input_ids = input_ids
34 |
35 | def __call__(self, output_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
36 | if self.start_len is None:
37 | self.start_len = self.input_ids.shape[1]
38 | else:
39 | for keyword_id in self.keyword_ids:
40 | if output_ids[0, -1] == keyword_id:
41 | return True
42 | outputs = self.tokenizer.batch_decode(output_ids[:, self.start_len:], skip_special_tokens=True)[0]
43 | for keyword in self.keywords:
44 | if keyword in outputs:
45 | return True
46 | return False
47 |
--------------------------------------------------------------------------------
/llava/pyproject.toml:
--------------------------------------------------------------------------------
1 | [build-system]
2 | requires = ["setuptools>=61.0"]
3 | build-backend = "setuptools.build_meta"
4 |
5 | [project]
6 | name = "llava"
7 | version = "1.0.1"
8 | description = "Towards GPT-4 like large language and visual assistant."
9 | readme = "README.md"
10 | requires-python = ">=3.8"
11 | classifiers = [
12 | "Programming Language :: Python :: 3",
13 | "License :: OSI Approved :: Apache Software License",
14 | ]
15 | dependencies = [
16 | "einops", "fastapi", "gradio==3.35.2", "markdown2[all]", "numpy",
17 | "requests", "sentencepiece", "tokenizers>=0.12.1",
18 | "torch", "torchvision", "uvicorn", "wandb",
19 | "shortuuid", "httpx==0.24.0",
20 | "deepspeed==0.9.5",
21 | "peft==0.4.0",
22 | "transformers==4.31.0",
23 | "accelerate==0.21.0",
24 | "bitsandbytes==0.41.0",
25 | "scikit-learn==1.2.2",
26 | "sentencepiece==0.1.99",
27 | "einops==0.6.1", "einops-exts==0.0.4", "timm==0.6.13",
28 | "gradio_client==0.2.9",
29 | "ipykernel" # for jupyter notebook
30 | ]
31 |
32 | [project.urls]
33 | "Homepage" = "https://llava-vl.github.io"
34 | "Bug Tracker" = "https://github.com/haotian-liu/LLaVA/issues"
35 |
36 | [tool.setuptools.packages.find]
37 | exclude = ["assets*", "benchmark*", "docs", "dist*", "playground*", "scripts*", "tests*"]
38 |
39 | [tool.wheel]
40 | exclude = ["assets*", "benchmark*", "docs", "dist*", "playground*", "scripts*", "tests*"]
41 |
--------------------------------------------------------------------------------
/llava/serve/.DS_Store:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ebagdasa/multimodal_injection/be8f118c06d7783e4f458046947d1271e7b435d6/llava/serve/.DS_Store
--------------------------------------------------------------------------------
/llava/serve/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ebagdasa/multimodal_injection/be8f118c06d7783e4f458046947d1271e7b435d6/llava/serve/__init__.py
--------------------------------------------------------------------------------
/llava/serve/cli.py:
--------------------------------------------------------------------------------
1 | """
2 | Usage:
3 | python3 -m llava.serve.cli --model-name ~/model_weights/llama-7b
4 | """
5 | import argparse
6 | import time
7 |
8 | import torch
9 | from transformers import AutoTokenizer, AutoModelForCausalLM
10 |
11 | from llava.conversation import conv_templates, SeparatorStyle
12 |
13 |
14 | @torch.inference_mode()
15 | def generate_stream(tokenizer, model, params, device,
16 | context_len=2048, stream_interval=2):
17 | """Adapted from fastchat/serve/model_worker.py::generate_stream"""
18 |
19 | prompt = params["prompt"]
20 | l_prompt = len(prompt)
21 | temperature = float(params.get("temperature", 1.0))
22 | max_new_tokens = int(params.get("max_new_tokens", 256))
23 | stop_str = params.get("stop", None)
24 |
25 | input_ids = tokenizer(prompt).input_ids
26 | output_ids = list(input_ids)
27 |
28 | max_src_len = context_len - max_new_tokens - 8
29 | input_ids = input_ids[-max_src_len:]
30 |
31 | for i in range(max_new_tokens):
32 | if i == 0:
33 | out = model(
34 | torch.as_tensor([input_ids], device=device), use_cache=True)
35 | logits = out.logits
36 | past_key_values = out.past_key_values
37 | else:
38 | attention_mask = torch.ones(
39 | 1, past_key_values[0][0].shape[-2] + 1, device=device)
40 | out = model(input_ids=torch.as_tensor([[token]], device=device),
41 | use_cache=True,
42 | attention_mask=attention_mask,
43 | past_key_values=past_key_values)
44 | logits = out.logits
45 | past_key_values = out.past_key_values
46 |
47 | last_token_logits = logits[0][-1]
48 | if temperature < 1e-4:
49 | token = int(torch.argmax(last_token_logits))
50 | else:
51 | probs = torch.softmax(last_token_logits / temperature, dim=-1)
52 | token = int(torch.multinomial(probs, num_samples=1))
53 |
54 | output_ids.append(token)
55 |
56 | if token == tokenizer.eos_token_id:
57 | stopped = True
58 | else:
59 | stopped = False
60 |
61 | if i % stream_interval == 0 or i == max_new_tokens - 1 or stopped:
62 | output = tokenizer.decode(output_ids, skip_special_tokens=True)
63 | pos = output.rfind(stop_str, l_prompt)
64 | if pos != -1:
65 | output = output[:pos]
66 | stopped = True
67 | yield output
68 |
69 | if stopped:
70 | break
71 |
72 | del past_key_values
73 |
74 |
75 | def main(args):
76 | model_name = args.model_name
77 | num_gpus = args.num_gpus
78 |
79 | # Model
80 | if args.device == "cuda":
81 | kwargs = {"torch_dtype": torch.float16}
82 | if num_gpus == "auto":
83 | kwargs["device_map"] = "auto"
84 | else:
85 | num_gpus = int(num_gpus)
86 | if num_gpus != 1:
87 | kwargs.update({
88 | "device_map": "auto",
89 | "max_memory": {i: "13GiB" for i in range(num_gpus)},
90 | })
91 | elif args.device == "cpu":
92 | kwargs = {}
93 | else:
94 | raise ValueError(f"Invalid device: {args.device}")
95 |
96 | tokenizer = AutoTokenizer.from_pretrained(model_name)
97 | model = AutoModelForCausalLM.from_pretrained(model_name,
98 | low_cpu_mem_usage=True, **kwargs)
99 |
100 | if args.device == "cuda" and num_gpus == 1:
101 | model.cuda()
102 |
103 | # Chat
104 | conv = conv_templates[args.conv_template].copy()
105 | while True:
106 | try:
107 | inp = input(f"{conv.roles[0]}: ")
108 | except EOFError:
109 | inp = ""
110 | if not inp:
111 | print("exit...")
112 | break
113 |
114 | conv.append_message(conv.roles[0], inp)
115 | conv.append_message(conv.roles[1], None)
116 | prompt = conv.get_prompt()
117 |
118 | params = {
119 | "model": model_name,
120 | "prompt": prompt,
121 | "temperature": args.temperature,
122 | "max_new_tokens": args.max_new_tokens,
123 | "stop": conv.sep if conv.sep_style == SeparatorStyle.SINGLE else conv.sep2,
124 | }
125 |
126 | print(f"{conv.roles[1]}: ", end="", flush=True)
127 | pre = 0
128 | for outputs in generate_stream(tokenizer, model, params, args.device):
129 | outputs = outputs[len(prompt) + 1:].strip()
130 | outputs = outputs.split(" ")
131 | now = len(outputs)
132 | if now - 1 > pre:
133 | print(" ".join(outputs[pre:now-1]), end=" ", flush=True)
134 | pre = now - 1
135 | print(" ".join(outputs[pre:]), flush=True)
136 |
137 | conv.messages[-1][-1] = " ".join(outputs)
138 |
139 | if args.debug:
140 | print("\n", {"prompt": prompt, "outputs": outputs}, "\n")
141 |
142 |
143 | if __name__ == "__main__":
144 | parser = argparse.ArgumentParser()
145 | parser.add_argument("--model-name", type=str, default="facebook/opt-350m")
146 | parser.add_argument("--num-gpus", type=str, default="1")
147 | parser.add_argument("--device", type=str, choices=["cuda", "cpu"], default="cuda")
148 | parser.add_argument("--conv-template", type=str, default="v1")
149 | parser.add_argument("--temperature", type=float, default=0.7)
150 | parser.add_argument("--max-new-tokens", type=int, default=512)
151 | parser.add_argument("--debug", action="store_true")
152 | args = parser.parse_args()
153 | main(args)
154 |
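The decoding loop above picks the next token greedily when the temperature is effectively zero and otherwise samples from a softmax over the scaled logits. A minimal sketch of that step on a dummy logit vector.

```python
import torch

last_token_logits = torch.tensor([2.0, 0.5, -1.0])   # dummy logits for a 3-token vocabulary
temperature = 0.7

if temperature < 1e-4:
    token = int(torch.argmax(last_token_logits))       # greedy decoding
else:
    probs = torch.softmax(last_token_logits / temperature, dim=-1)
    token = int(torch.multinomial(probs, num_samples=1))  # stochastic sampling
print(token)
```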
--------------------------------------------------------------------------------
/llava/serve/examples/CornellTech.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ebagdasa/multimodal_injection/be8f118c06d7783e4f458046947d1271e7b435d6/llava/serve/examples/CornellTech.png
--------------------------------------------------------------------------------
/llava/serve/examples/crying_boy.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ebagdasa/multimodal_injection/be8f118c06d7783e4f458046947d1271e7b435d6/llava/serve/examples/crying_boy.jpg
--------------------------------------------------------------------------------
/llava/serve/examples/extreme_ironing.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ebagdasa/multimodal_injection/be8f118c06d7783e4f458046947d1271e7b435d6/llava/serve/examples/extreme_ironing.jpg
--------------------------------------------------------------------------------
/llava/serve/examples/london.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ebagdasa/multimodal_injection/be8f118c06d7783e4f458046947d1271e7b435d6/llava/serve/examples/london.jpg
--------------------------------------------------------------------------------
/llava/serve/examples/math.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ebagdasa/multimodal_injection/be8f118c06d7783e4f458046947d1271e7b435d6/llava/serve/examples/math.jpg
--------------------------------------------------------------------------------
/llava/serve/examples/math.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ebagdasa/multimodal_injection/be8f118c06d7783e4f458046947d1271e7b435d6/llava/serve/examples/math.png
--------------------------------------------------------------------------------
/llava/serve/examples/poo.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ebagdasa/multimodal_injection/be8f118c06d7783e4f458046947d1271e7b435d6/llava/serve/examples/poo.jpg
--------------------------------------------------------------------------------
/llava/serve/examples/porsche911.jpeg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ebagdasa/multimodal_injection/be8f118c06d7783e4f458046947d1271e7b435d6/llava/serve/examples/porsche911.jpeg
--------------------------------------------------------------------------------
/llava/serve/examples/porsche911.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ebagdasa/multimodal_injection/be8f118c06d7783e4f458046947d1271e7b435d6/llava/serve/examples/porsche911.png
--------------------------------------------------------------------------------
/llava/serve/examples/steak.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ebagdasa/multimodal_injection/be8f118c06d7783e4f458046947d1271e7b435d6/llava/serve/examples/steak.jpg
--------------------------------------------------------------------------------
/llava/serve/examples/thief.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ebagdasa/multimodal_injection/be8f118c06d7783e4f458046947d1271e7b435d6/llava/serve/examples/thief.jpg
--------------------------------------------------------------------------------
/llava/serve/examples/waterview.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ebagdasa/multimodal_injection/be8f118c06d7783e4f458046947d1271e7b435d6/llava/serve/examples/waterview.jpg
--------------------------------------------------------------------------------
/llava/serve/gateway/README.md:
--------------------------------------------------------------------------------
1 | # FastChat Nginx Gateway
2 |
3 | ## Purpose of the Gateway
4 |
5 | The Nginx gateway serves the following purposes:
6 |
7 | 1. Protects Gradio servers by acting as a firewall.
8 | 2. Facilitates dynamic mounting and unmounting of Gradio servers.
9 | 3. Provides load balancing for Gradio servers.
10 | 4. Offers additional security features, such as a total connection limit.
11 | 5. Reduces attack surface by requiring only a single public port to be exposed for serving.
12 |
13 | ## Deployment and Updating of the Gateway
14 |
15 | ### Installing Nginx
16 |
17 | On Debian-based distributions (e.g., Ubuntu):
18 |
19 | ```bash
20 | sudo apt update
21 | sudo apt install nginx
22 | ```
23 | On Red Hat-based distributions (e.g., CentOS, Fedora):
24 |
25 | ```bash
26 | sudo yum install epel-release
27 | sudo yum install nginx
28 | ```
29 |
30 | ### Deployment
31 |
32 | Copy `nginx.conf` to `/etc/nginx/nginx.conf` (requires sudo).
33 |
34 | Replace the port number 7860 in `server localhost:7860` with the port where you deploy the Gradio web server.
35 |
36 | Modify the `upstream websocket` block to list the Gradio servers behind the gateway.
37 |
38 | Lastly, reload Nginx so the new configuration takes effect; a minimal sketch of these steps follows.
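A minimal sketch of the deployment steps above, assuming the repository's `nginx.conf` is in the current working directory and has already been edited as described:

```bash
sudo cp nginx.conf /etc/nginx/nginx.conf  # copy the gateway config into place
sudo nginx -t                             # sanity-check the configuration
sudo systemctl reload nginx               # reload Nginx so the changes take effect
```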
39 |
40 |
41 | ### HTTPS Deployment with a Public Domain URL
42 |
43 | Make sure you have obtained the HTTPS certificate and the private key that corresponds to it.
44 |
45 | Fill in the paths to your certificate and private key in the `[PATH_TO_SSL_CERT]` and `[PATH_TO_PRIVATE_KEY]` fields.
46 |
47 | If you have your own domain to serve the chatbot, replace the chat.lmsys.org URL with your own domain, as in the sketch below.
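A minimal sketch of filling in the placeholders, assuming hypothetical paths `/etc/ssl/certs/chatbot.pem` and `/etc/ssl/private/chatbot.key` and the hypothetical domain `chat.example.com`; substitute your own values:

```bash
# Replace the SSL placeholders and the default domain in the deployed config.
sudo sed -i \
  -e 's|\[PATH_TO_SSL_CERT\]|/etc/ssl/certs/chatbot.pem|' \
  -e 's|\[PATH_TO_PRIVATE_KEY\]|/etc/ssl/private/chatbot.key|' \
  -e 's|chat\.lmsys\.org|chat.example.com|g' \
  /etc/nginx/nginx.conf
```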
48 |
49 | ### Updating
50 |
51 | Every time `/etc/nginx/nginx.conf` is modified, you need to reload the Nginx service:
52 |
53 | ```bash
54 | sudo nginx -t # check `/etc/nginx/nginx.conf`
55 | sudo systemctl reload nginx  # reload the Nginx service to load the new config
56 | sudo systemctl status nginx # check the status of the Nginx service. It should be active (running).
57 | ```
58 |
--------------------------------------------------------------------------------
/llava/serve/gateway/nginx.conf:
--------------------------------------------------------------------------------
1 | user www-data;
2 | worker_processes auto;
3 | pid /run/nginx.pid;
4 | include /etc/nginx/modules-enabled/*.conf;
5 |
6 | events {
7 | worker_connections 1024; # maximum number of connections that a worker process can handle concurrently
8 | # multi_accept on; # allow a worker process to accept all new connections at once; may improve performance under high load
9 |
10 | }
11 |
12 | http {
13 | ##
14 | # Basic Settings
15 | ##
16 |
17 | sendfile on; # enable sendfile for performance optimization
18 | tcp_nopush on; # enable TCP no-pushing
19 | tcp_nodelay on; # enable TCP no-delay
20 | keepalive_timeout 65; # sets the timeout for keep-alive connections
21 | types_hash_max_size 2048; # maximum size of the types hash table
22 | # server_tokens off; # disable server token (i.e., server signature) in response headers to improve security
23 |
24 | # server_names_hash_bucket_size 64;
25 | # server_name_in_redirect off;
26 |
27 | include /etc/nginx/mime.types; # include MIME types file
28 | default_type application/octet-stream; # default MIME type for unknown file types
29 |
30 | ##
31 | # SSL Settings
32 | ##
33 |
34 | ssl_protocols TLSv1.2; # specify SSL/TLS protocols to use
35 | ssl_prefer_server_ciphers on; # prefer server ciphers over client ciphers
36 |
37 | ##
38 | # Logging Settings
39 | ##
40 |
41 | access_log /var/log/nginx/access.log; # path to access log file
42 | error_log /var/log/nginx/error.log; # path to error log file
43 |
44 | ##
45 | # Gzip Settings
46 | ##
47 | gzip on; # enable Gzip compression
48 |
49 | ##
50 | # Virtual Host Configs
51 | ##
52 |
53 | include /etc/nginx/conf.d/*.conf; # include all configuration files in conf.d directory
54 | include /etc/nginx/sites-enabled/*; # include all enabled sites configuration files
55 |
56 | # WebSocket Proxy: https://www.nginx.com/blog/websocket-nginx/
57 | map $http_upgrade $connection_upgrade {
58 | default upgrade;
59 | '' close;
60 | }
61 |
62 | upstream websocket {
63 | ip_hash; # load balancing by IP to guarantee session persistence
64 | server localhost:7860; # The port should be the gradio web server port
65 | # server localhost:7861; # extra gradio server if more than one
66 | }
67 |
68 | limit_conn_status 429;
69 | limit_conn_zone $binary_remote_addr zone=perip:10m; # limit number of connections per IP
70 | limit_conn_zone $server_name zone=perserver:10m; # limit number of connections per server
71 |
72 | server {
73 | listen 443 ssl; # the listening port of our server
74 | ssl_certificate [PATH_TO_SSL_CERT];
75 | ssl_certificate_key [PATH_TO_PRIVATE_KEY];
76 | server_name chat.lmsys.org; # replace the url with your own domain url
77 | limit_conn perserver 1024; # connections per server
78 | location / {
79 | proxy_pass http://websocket; # proxy all requests to the defined upstream server
80 | limit_conn perip 5; # connections per IP
81 | proxy_set_header Host $host; # set the Host header for the upstream server
82 | proxy_set_header X-Real-IP $remote_addr; # set the client IP address as the real IP for the upstream server
83 | proxy_set_header X-Forwarded-For $proxy_add_x_forwarded_for; # set the client IP addresses in the X-Forwarded-For header
84 | proxy_http_version 1.1; # use HTTP version 1.1 for upstream communication
85 | proxy_set_header Upgrade $http_upgrade;
86 | proxy_set_header Connection "Upgrade"; # set the Connection header to Upgrade to enable WebSocket communication
87 | }
88 | }
89 |
90 | # the following block routes all HTTP traffic to HTTPS via nginx
91 | server {
92 | listen 80;
93 | server_name chat.lmsys.org;
94 | return 301 https://chat.lmsys.org$request_uri;
95 | }
96 |
97 | }
98 |
--------------------------------------------------------------------------------
/llava/serve/gradio_css.py:
--------------------------------------------------------------------------------
1 | code_highlight_css = (
2 | """
3 | #chatbot .hll { background-color: #ffffcc }
4 | #chatbot .c { color: #408080; font-style: italic }
5 | #chatbot .err { border: 1px solid #FF0000 }
6 | #chatbot .k { color: #008000; font-weight: bold }
7 | #chatbot .o { color: #666666 }
8 | #chatbot .ch { color: #408080; font-style: italic }
9 | #chatbot .cm { color: #408080; font-style: italic }
10 | #chatbot .cp { color: #BC7A00 }
11 | #chatbot .cpf { color: #408080; font-style: italic }
12 | #chatbot .c1 { color: #408080; font-style: italic }
13 | #chatbot .cs { color: #408080; font-style: italic }
14 | #chatbot .gd { color: #A00000 }
15 | #chatbot .ge { font-style: italic }
16 | #chatbot .gr { color: #FF0000 }
17 | #chatbot .gh { color: #000080; font-weight: bold }
18 | #chatbot .gi { color: #00A000 }
19 | #chatbot .go { color: #888888 }
20 | #chatbot .gp { color: #000080; font-weight: bold }
21 | #chatbot .gs { font-weight: bold }
22 | #chatbot .gu { color: #800080; font-weight: bold }
23 | #chatbot .gt { color: #0044DD }
24 | #chatbot .kc { color: #008000; font-weight: bold }
25 | #chatbot .kd { color: #008000; font-weight: bold }
26 | #chatbot .kn { color: #008000; font-weight: bold }
27 | #chatbot .kp { color: #008000 }
28 | #chatbot .kr { color: #008000; font-weight: bold }
29 | #chatbot .kt { color: #B00040 }
30 | #chatbot .m { color: #666666 }
31 | #chatbot .s { color: #BA2121 }
32 | #chatbot .na { color: #7D9029 }
33 | #chatbot .nb { color: #008000 }
34 | #chatbot .nc { color: #0000FF; font-weight: bold }
35 | #chatbot .no { color: #880000 }
36 | #chatbot .nd { color: #AA22FF }
37 | #chatbot .ni { color: #999999; font-weight: bold }
38 | #chatbot .ne { color: #D2413A; font-weight: bold }
39 | #chatbot .nf { color: #0000FF }
40 | #chatbot .nl { color: #A0A000 }
41 | #chatbot .nn { color: #0000FF; font-weight: bold }
42 | #chatbot .nt { color: #008000; font-weight: bold }
43 | #chatbot .nv { color: #19177C }
44 | #chatbot .ow { color: #AA22FF; font-weight: bold }
45 | #chatbot .w { color: #bbbbbb }
46 | #chatbot .mb { color: #666666 }
47 | #chatbot .mf { color: #666666 }
48 | #chatbot .mh { color: #666666 }
49 | #chatbot .mi { color: #666666 }
50 | #chatbot .mo { color: #666666 }
51 | #chatbot .sa { color: #BA2121 }
52 | #chatbot .sb { color: #BA2121 }
53 | #chatbot .sc { color: #BA2121 }
54 | #chatbot .dl { color: #BA2121 }
55 | #chatbot .sd { color: #BA2121; font-style: italic }
56 | #chatbot .s2 { color: #BA2121 }
57 | #chatbot .se { color: #BB6622; font-weight: bold }
58 | #chatbot .sh { color: #BA2121 }
59 | #chatbot .si { color: #BB6688; font-weight: bold }
60 | #chatbot .sx { color: #008000 }
61 | #chatbot .sr { color: #BB6688 }
62 | #chatbot .s1 { color: #BA2121 }
63 | #chatbot .ss { color: #19177C }
64 | #chatbot .bp { color: #008000 }
65 | #chatbot .fm { color: #0000FF }
66 | #chatbot .vc { color: #19177C }
67 | #chatbot .vg { color: #19177C }
68 | #chatbot .vi { color: #19177C }
69 | #chatbot .vm { color: #19177C }
70 | #chatbot .il { color: #666666 }
71 | """)
72 | #.highlight { background: #f8f8f8; }
73 |
74 |
--------------------------------------------------------------------------------
/llava/serve/gradio_patch.py:
--------------------------------------------------------------------------------
1 | """
2 | Adapted from https://github.com/gradio-app/gradio/blob/main/gradio/components.py
3 | Fix a markdown render problem.
4 | """
5 | from __future__ import annotations
6 |
7 | from gradio.components import *
8 | from markdown2 import Markdown
9 |
10 |
11 | class _Keywords(Enum):
12 | NO_VALUE = "NO_VALUE"  # Used as a sentinel to determine if nothing is provided as an argument for `value` in `Component.update()`
13 | FINISHED_ITERATING = "FINISHED_ITERATING" # Used to skip processing of a component's value (needed for generators + state)
14 |
15 |
16 | @document("style")
17 | class Chatbot(Changeable, Selectable, IOComponent, JSONSerializable):
18 | """
19 | Displays a chatbot output showing both user submitted messages and responses. Supports a subset of Markdown including bold, italics, code, and images.
20 | Preprocessing: this component does *not* accept input.
21 | Postprocessing: expects function to return a {List[Tuple[str | None | Tuple, str | None | Tuple]]}, a list of tuples with user message and response messages. Messages should be strings, tuples, or Nones. If the message is a string, it can include Markdown. If it is a tuple, it should consist of (string filepath to image/video/audio, [optional string alt text]). Messages that are `None` are not displayed.
22 |
23 | Demos: chatbot_simple, chatbot_multimodal
24 | """
25 |
26 | def __init__(
27 | self,
28 | value: List[Tuple[str | None, str | None]] | Callable | None = None,
29 | color_map: Dict[str, str] | None = None, # Parameter moved to Chatbot.style()
30 | *,
31 | label: str | None = None,
32 | every: float | None = None,
33 | show_label: bool = True,
34 | visible: bool = True,
35 | elem_id: str | None = None,
36 | elem_classes: List[str] | str | None = None,
37 | **kwargs,
38 | ):
39 | """
40 | Parameters:
41 | value: Default value to show in chatbot. If callable, the function will be called whenever the app loads to set the initial value of the component.
42 | label: component name in interface.
43 | every: If `value` is a callable, run the function 'every' number of seconds while the client connection is open. Has no effect otherwise. Queue must be enabled. The event can be accessed (e.g. to cancel it) via this component's .load_event attribute.
44 | show_label: if True, will display label.
45 | visible: If False, component will be hidden.
46 | elem_id: An optional string that is assigned as the id of this component in the HTML DOM. Can be used for targeting CSS styles.
47 | elem_classes: An optional list of strings that are assigned as the classes of this component in the HTML DOM. Can be used for targeting CSS styles.
48 | """
49 | if color_map is not None:
50 | warnings.warn(
51 | "The 'color_map' parameter has been deprecated.",
52 | )
53 | #self.md = utils.get_markdown_parser()
54 | self.md = Markdown(extras=["fenced-code-blocks", "tables", "break-on-newline"])
55 | self.select: EventListenerMethod
56 | """
57 | Event listener for when the user selects message from Chatbot.
58 | Uses event data gradio.SelectData to carry `value` referring to text of selected message, and `index` tuple to refer to [message, participant] index.
59 | See EventData documentation on how to use this event data.
60 | """
61 |
62 | IOComponent.__init__(
63 | self,
64 | label=label,
65 | every=every,
66 | show_label=show_label,
67 | visible=visible,
68 | elem_id=elem_id,
69 | elem_classes=elem_classes,
70 | value=value,
71 | **kwargs,
72 | )
73 |
74 | def get_config(self):
75 | return {
76 | "value": self.value,
77 | "selectable": self.selectable,
78 | **IOComponent.get_config(self),
79 | }
80 |
81 | @staticmethod
82 | def update(
83 | value: Any | Literal[_Keywords.NO_VALUE] | None = _Keywords.NO_VALUE,
84 | label: str | None = None,
85 | show_label: bool | None = None,
86 | visible: bool | None = None,
87 | ):
88 | updated_config = {
89 | "label": label,
90 | "show_label": show_label,
91 | "visible": visible,
92 | "value": value,
93 | "__type__": "update",
94 | }
95 | return updated_config
96 |
97 | def _process_chat_messages(
98 | self, chat_message: str | Tuple | List | Dict | None
99 | ) -> str | Dict | None:
100 | if chat_message is None:
101 | return None
102 | elif isinstance(chat_message, (tuple, list)):
103 | mime_type = processing_utils.get_mimetype(chat_message[0])
104 | return {
105 | "name": chat_message[0],
106 | "mime_type": mime_type,
107 | "alt_text": chat_message[1] if len(chat_message) > 1 else None,
108 | "data": None, # These last two fields are filled in by the frontend
109 | "is_file": True,
110 | }
111 | elif isinstance(
112 | chat_message, dict
113 | ): # This happens for previously processed messages
114 | return chat_message
115 | elif isinstance(chat_message, str):
116 | #return self.md.render(chat_message)
117 | return str(self.md.convert(chat_message))
118 | else:
119 | raise ValueError(f"Invalid message for Chatbot component: {chat_message}")
120 |
121 | def postprocess(
122 | self,
123 | y: List[
124 | Tuple[str | Tuple | List | Dict | None, str | Tuple | List | Dict | None]
125 | ],
126 | ) -> List[Tuple[str | Dict | None, str | Dict | None]]:
127 | """
128 | Parameters:
129 | y: List of tuples representing the message and response pairs. Each message and response should be a string, which may be in Markdown format. It can also be a tuple whose first element is a string filepath or URL to an image/video/audio, and second (optional) element is the alt text, in which case the media file is displayed. It can also be None, in which case that message is not displayed.
130 | Returns:
131 | List of tuples representing the message and response. Each message and response will be a string of HTML, or a dictionary with media information.
132 | """
133 | if y is None:
134 | return []
135 | processed_messages = []
136 | for message_pair in y:
137 | assert isinstance(
138 | message_pair, (tuple, list)
139 | ), f"Expected a list of lists or list of tuples. Received: {message_pair}"
140 | assert (
141 | len(message_pair) == 2
142 | ), f"Expected a list of lists of length 2 or list of tuples of length 2. Received: {message_pair}"
143 | processed_messages.append(
144 | (
145 | #self._process_chat_messages(message_pair[0]),
146 | '<pre style="font-family: var(--font)">' +
147 | message_pair[0] + "</pre>",
148 | self._process_chat_messages(message_pair[1]),
149 | )
150 | )
151 | return processed_messages
152 |
153 | def style(self, height: int | None = None, **kwargs):
154 | """
155 | This method can be used to change the appearance of the Chatbot component.
156 | """
157 | if height is not None:
158 | self._style["height"] = height
159 | if kwargs.get("color_map") is not None:
160 | warnings.warn("The 'color_map' parameter has been deprecated.")
161 |
162 | Component.style(
163 | self,
164 | **kwargs,
165 | )
166 | return self
167 |
168 |
169 |
--------------------------------------------------------------------------------
/llava/serve/register_worker.py:
--------------------------------------------------------------------------------
1 | """
2 | Manually register workers.
3 |
4 | Usage:
5 | python3 -m fastchat.serve.register_worker --controller http://localhost:21001 --worker-name http://localhost:21002
6 | """
7 |
8 | import argparse
9 |
10 | import requests
11 |
12 | if __name__ == "__main__":
13 | parser = argparse.ArgumentParser()
14 | parser.add_argument("--controller-address", type=str)
15 | parser.add_argument("--worker-name", type=str)
16 | parser.add_argument("--check-heart-beat", action="store_true")
17 | args = parser.parse_args()
18 |
19 | url = args.controller_address + "/register_worker"
20 | data = {
21 | "worker_name": args.worker_name,
22 | "check_heart_beat": args.check_heart_beat,
23 | "worker_status": None,
24 | }
25 | r = requests.post(url, json=data)
26 | assert r.status_code == 200
27 |
--------------------------------------------------------------------------------
/llava/serve/test_message.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import json
3 |
4 | import requests
5 |
6 | from llava.conversation import default_conversation
7 |
8 |
9 | def main():
10 | if args.worker_address:
11 | worker_addr = args.worker_address
12 | else:
13 | controller_addr = args.controller_address
14 | ret = requests.post(controller_addr + "/refresh_all_workers")
15 | ret = requests.post(controller_addr + "/list_models")
16 | models = ret.json()["models"]
17 | models.sort()
18 | print(f"Models: {models}")
19 |
20 | ret = requests.post(controller_addr + "/get_worker_address",
21 | json={"model": args.model_name})
22 | worker_addr = ret.json()["address"]
23 | print(f"worker_addr: {worker_addr}")
24 |
25 | if worker_addr == "":
26 | return
27 |
28 | conv = default_conversation.copy()
29 | conv.append_message(conv.roles[0], args.message)
30 | prompt = conv.get_prompt()
31 |
32 | headers = {"User-Agent": "LLaVA Client"}
33 | pload = {
34 | "model": args.model_name,
35 | "prompt": prompt,
36 | "max_new_tokens": args.max_new_tokens,
37 | "temperature": 0.7,
38 | "stop": conv.sep,
39 | }
40 | response = requests.post(worker_addr + "/worker_generate_stream", headers=headers,
41 | json=pload, stream=True)
42 |
43 | print(prompt.replace(conv.sep, "\n"), end="")
44 | for chunk in response.iter_lines(chunk_size=8192, decode_unicode=False, delimiter=b"\0"):
45 | if chunk:
46 | data = json.loads(chunk.decode("utf-8"))
47 | output = data["text"].split(conv.sep)[-1]
48 | print(output, end="\r")
49 | print("")
50 |
51 |
52 | if __name__ == "__main__":
53 | parser = argparse.ArgumentParser()
54 | parser.add_argument("--controller-address", type=str, default="http://localhost:21001")
55 | parser.add_argument("--worker-address", type=str)
56 | parser.add_argument("--model-name", type=str, default="facebook/opt-350m")
57 | parser.add_argument("--max-new-tokens", type=int, default=32)
58 | parser.add_argument("--message", type=str, default=
59 | "Tell me a story with more than 1000 words.")
60 | args = parser.parse_args()
61 |
62 | main()
63 |
--------------------------------------------------------------------------------
/llava/train/llama_flash_attn_monkey_patch.py:
--------------------------------------------------------------------------------
1 | # Adapted from https://github.com/lm-sys/FastChat. Below is the original copyright:
2 | from typing import List, Optional, Tuple
3 |
4 | import torch
5 | from torch import nn
6 |
7 | import transformers
8 | from transformers.models.llama.modeling_llama import apply_rotary_pos_emb
9 |
10 | from einops import rearrange
11 |
12 | from flash_attn.flash_attn_interface import flash_attn_unpadded_qkvpacked_func
13 | from flash_attn.bert_padding import unpad_input, pad_input
14 |
15 | def forward(
16 | self,
17 | hidden_states: torch.Tensor,
18 | past_key_value: Optional[Tuple[torch.Tensor]] = None,
19 | attention_mask: Optional[torch.Tensor] = None,
20 | output_attentions: bool = False,
21 | use_cache: bool = False,
22 | ) -> Tuple[torch.Tensor, Optional[torch.Tensor],
23 | Optional[Tuple[torch.Tensor]]]:
24 | """Input shape: Batch x Time x Channel
25 |
26 | attention_mask: [bsz, q_len]
27 | """
28 | bsz, q_len, _ = hidden_states.size()
29 |
30 | query_states = self.q_proj(hidden_states).view(
31 | bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
32 | key_states = self.k_proj(hidden_states).view(
33 | bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
34 | value_states = self.v_proj(hidden_states).view(
35 | bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
36 | # [bsz, q_len, nh, hd]
37 | # [bsz, nh, q_len, hd]
38 |
39 | kv_seq_len = key_states.shape[-2]
40 | offset = 0
41 | if past_key_value is not None:
42 | offset = past_key_value[0].shape[-2]
43 | kv_seq_len += offset
44 | cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
45 | query_states, key_states = apply_rotary_pos_emb(query_states,
46 | key_states,
47 | cos,
48 | sin,
49 | offset=offset)
50 | # [bsz, nh, t, hd]
51 | assert not output_attentions, "output_attentions is not supported"
52 | assert not use_cache, "use_cache is not supported"
53 | assert past_key_value is None, "past_key_value is not supported"
54 |
55 | # Flash attention codes from
56 | # https://github.com/HazyResearch/flash-attention/blob/main/flash_attn/flash_attention.py
57 |
58 | # transform the data into the format required by flash attention
59 | qkv = torch.stack([query_states, key_states, value_states], dim=2) # [bsz, nh, 3, q_len, hd]
60 | qkv = qkv.transpose(1, 3) # [bsz, q_len, 3, nh, hd]
61 | # We have disabled _prepare_decoder_attention_mask in LlamaModel
62 | # the attention_mask should be the same as the key_padding_mask
63 | key_padding_mask = attention_mask
64 |
65 |
66 | if key_padding_mask is None:
67 | qkv = rearrange(qkv, 'b s ... -> (b s) ...')
68 | max_s = q_len
69 | cu_q_lens = torch.arange(0, (bsz + 1) * q_len, step=q_len, dtype=torch.int32,
70 | device=qkv.device)
71 | output = flash_attn_unpadded_qkvpacked_func(
72 | qkv, cu_q_lens, max_s, 0.0,
73 | softmax_scale=None, causal=True
74 | )
75 | output = rearrange(output, '(b s) ... -> b s ...', b=bsz)
76 | else:
77 | nheads = qkv.shape[-2]
78 | x = rearrange(qkv, 'b s three h d -> b s (three h d)')
79 | x_unpad, indices, cu_q_lens, max_s = unpad_input(x, key_padding_mask)
80 | x_unpad = rearrange(x_unpad, 'nnz (three h d) -> nnz three h d', three=3, h=nheads)
81 | output_unpad = flash_attn_unpadded_qkvpacked_func(
82 | x_unpad, cu_q_lens, max_s, 0.0,
83 | softmax_scale=None, causal=True
84 | )
85 | output = rearrange(pad_input(rearrange(output_unpad, 'nnz h d -> nnz (h d)'),
86 | indices, bsz, q_len),
87 | 'b s (h d) -> b s h d', h=nheads)
88 | return self.o_proj(rearrange(output,
89 | 'b s h d -> b s (h d)')), None, None
90 |
91 |
92 | # Disable the transformation of the attention mask in LlamaModel as the flash attention
93 | # requires the attention mask to be the same as the key_padding_mask
94 | def _prepare_decoder_attention_mask(self, attention_mask, input_shape,
95 | inputs_embeds, past_key_values_length):
96 | # [bsz, seq_len]
97 | return attention_mask
98 |
99 |
100 | def replace_llama_attn_with_flash_attn():
101 | transformers.models.llama.modeling_llama.LlamaModel._prepare_decoder_attention_mask = _prepare_decoder_attention_mask
102 | transformers.models.llama.modeling_llama.LlamaAttention.forward = forward
103 |
--------------------------------------------------------------------------------
/llava/train/llava_trainer.py:
--------------------------------------------------------------------------------
1 | import os
2 | import torch
3 | import torch.nn as nn
4 |
5 | from transformers import Trainer
6 | from typing import Dict, Optional, Sequence
7 |
8 |
9 | def unwrap_model(model: nn.Module) -> nn.Module:
10 | """
11 | Recursively unwraps a model from potential containers (as used in distributed training).
12 |
13 | Args:
14 | model (`torch.nn.Module`): The model to unwrap.
15 | """
16 | # since there could be multiple levels of wrapping, unwrap recursively
17 | if hasattr(model, "module"):
18 | return unwrap_model(model.module)
19 | else:
20 | return model
21 |
22 |
23 | class LLaVATrainer(Trainer):
24 |
25 | def _save(self, output_dir: Optional[str] = None, state_dict=None):
26 | if getattr(self.args, 'tune_mm_mlp_adapter', False):
27 | # Save the model
28 | _state_dict = state_dict
29 | if _state_dict is None:
30 | # Only save the model itself if we are using distributed training
31 | model_to_save = unwrap_model(self.model)
32 | _state_dict = model_to_save.state_dict()
33 |
34 | weight_to_save = {}
35 | keys_to_match = ['mm_projector', 'embed_tokens', 'embed_in']
36 | for k, v in _state_dict.items():
37 | if any(key_match in k for key_match in keys_to_match):
38 | weight_to_save[k] = v
39 |
40 | current_folder = output_dir.split('/')[-1]
41 | parent_folder = os.path.dirname(output_dir)
42 | if current_folder.startswith('checkpoint-'):
43 | mm_projector_folder = os.path.join(parent_folder, "mm_projector")
44 | os.makedirs(mm_projector_folder, exist_ok=True)
45 | torch.save(weight_to_save, os.path.join(mm_projector_folder, f'{current_folder}.bin'))
46 | else:
47 | torch.save(weight_to_save, os.path.join(output_dir, f'mm_projector.bin'))
48 |
49 | super(LLaVATrainer, self)._save(output_dir, state_dict)
50 |
--------------------------------------------------------------------------------
/llava/train/train_mem.py:
--------------------------------------------------------------------------------
1 | # Adapted from https://github.com/lm-sys/FastChat. Below is the original copyright:
2 | # Adapted from tatsu-lab@stanford_alpaca. Below is the original copyright:
3 | # Make it more memory efficient by monkey patching the LLaMA model with FlashAttn.
4 |
5 | # Need to call this before importing transformers.
6 | from llava.train.llama_flash_attn_monkey_patch import replace_llama_attn_with_flash_attn
7 |
8 | replace_llama_attn_with_flash_attn()
9 |
10 | from llava.train.train import train
11 |
12 | if __name__ == "__main__":
13 | train()
14 |
--------------------------------------------------------------------------------
/llava/utils.py:
--------------------------------------------------------------------------------
1 | import datetime
2 | import logging
3 | import logging.handlers
4 | import os
5 | import sys
6 |
7 | import requests
8 |
9 | from llava.constants import LOGDIR
10 |
11 | server_error_msg = "**NETWORK ERROR DUE TO HIGH TRAFFIC. PLEASE REGENERATE OR REFRESH THIS PAGE.**"
12 | moderation_msg = "YOUR INPUT VIOLATES OUR CONTENT MODERATION GUIDELINES. PLEASE TRY AGAIN."
13 |
14 | handler = None
15 |
16 |
17 | def build_logger(logger_name, logger_filename):
18 | global handler
19 |
20 | formatter = logging.Formatter(
21 | fmt="%(asctime)s | %(levelname)s | %(name)s | %(message)s",
22 | datefmt="%Y-%m-%d %H:%M:%S",
23 | )
24 |
25 | # Set the format of root handlers
26 | if not logging.getLogger().handlers:
27 | logging.basicConfig(level=logging.INFO)
28 | logging.getLogger().handlers[0].setFormatter(formatter)
29 |
30 | # Redirect stdout and stderr to loggers
31 | stdout_logger = logging.getLogger("stdout")
32 | stdout_logger.setLevel(logging.INFO)
33 | sl = StreamToLogger(stdout_logger, logging.INFO)
34 | sys.stdout = sl
35 |
36 | stderr_logger = logging.getLogger("stderr")
37 | stderr_logger.setLevel(logging.ERROR)
38 | sl = StreamToLogger(stderr_logger, logging.ERROR)
39 | sys.stderr = sl
40 |
41 | # Get logger
42 | logger = logging.getLogger(logger_name)
43 | logger.setLevel(logging.INFO)
44 |
45 | # Add a file handler for all loggers
46 | if handler is None:
47 | os.makedirs(LOGDIR, exist_ok=True)
48 | filename = os.path.join(LOGDIR, logger_filename)
49 | handler = logging.handlers.TimedRotatingFileHandler(
50 | filename, when='D', utc=True)
51 | handler.setFormatter(formatter)
52 |
53 | for name, item in logging.root.manager.loggerDict.items():
54 | if isinstance(item, logging.Logger):
55 | item.addHandler(handler)
56 |
57 | return logger
58 |
59 |
60 | class StreamToLogger(object):
61 | """
62 | Fake file-like stream object that redirects writes to a logger instance.
63 | """
64 | def __init__(self, logger, log_level=logging.INFO):
65 | self.terminal = sys.stdout
66 | self.logger = logger
67 | self.log_level = log_level
68 | self.linebuf = ''
69 |
70 | def __getattr__(self, attr):
71 | return getattr(self.terminal, attr)
72 |
73 | def write(self, buf):
74 | temp_linebuf = self.linebuf + buf
75 | self.linebuf = ''
76 | for line in temp_linebuf.splitlines(True):
77 | # From the io.TextIOWrapper docs:
78 | # On output, if newline is None, any '\n' characters written
79 | # are translated to the system default line separator.
80 | # By default sys.stdout.write() expects '\n' newlines and then
81 | # translates them so this is still cross platform.
82 | if line[-1] == '\n':
83 | self.logger.log(self.log_level, line.rstrip())
84 | else:
85 | self.linebuf += line
86 |
87 | def flush(self):
88 | if self.linebuf != '':
89 | self.logger.log(self.log_level, self.linebuf.rstrip())
90 | self.linebuf = ''
91 |
92 |
93 | def disable_torch_init():
94 | """
95 | Disable the redundant torch default initialization to accelerate model creation.
96 | """
97 | import torch
98 | setattr(torch.nn.Linear, "reset_parameters", lambda self: None)
99 | setattr(torch.nn.LayerNorm, "reset_parameters", lambda self: None)
100 |
101 |
102 | def violates_moderation(text):
103 | """
104 | Check whether the text violates OpenAI moderation API.
105 | """
106 | url = "https://api.openai.com/v1/moderations"
107 | headers = {"Content-Type": "application/json",
108 | "Authorization": "Bearer " + os.environ["OPENAI_API_KEY"]}
109 | text = text.replace("\n", "")
110 | data = "{" + '"input": ' + f'"{text}"' + "}"
111 | data = data.encode("utf-8")
112 | try:
113 | ret = requests.post(url, headers=headers, data=data, timeout=5)
114 | flagged = ret.json()["results"][0]["flagged"]
115 | except requests.exceptions.RequestException as e:
116 | flagged = False
117 | except KeyError as e:
118 | flagged = False
119 |
120 | return flagged
121 |
122 |
123 | def pretty_print_semaphore(semaphore):
124 | if semaphore is None:
125 | return "None"
126 | return f"Semaphore(value={semaphore._value}, locked={semaphore.locked()})"
127 |
--------------------------------------------------------------------------------
/models/.gitkeep:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ebagdasa/multimodal_injection/be8f118c06d7783e4f458046947d1271e7b435d6/models/.gitkeep
--------------------------------------------------------------------------------
/original_images/CT.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ebagdasa/multimodal_injection/be8f118c06d7783e4f458046947d1271e7b435d6/original_images/CT.png
--------------------------------------------------------------------------------
/original_images/waterview.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ebagdasa/multimodal_injection/be8f118c06d7783e4f458046947d1271e7b435d6/original_images/waterview.jpg
--------------------------------------------------------------------------------
/pandagpt/.DS_Store:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ebagdasa/multimodal_injection/be8f118c06d7783e4f458046947d1271e7b435d6/pandagpt/.DS_Store
--------------------------------------------------------------------------------
/pandagpt/.gitkeep:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ebagdasa/multimodal_injection/be8f118c06d7783e4f458046947d1271e7b435d6/pandagpt/.gitkeep
--------------------------------------------------------------------------------
/pandagpt/code/assets/audios/bird_audio.wav:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ebagdasa/multimodal_injection/be8f118c06d7783e4f458046947d1271e7b435d6/pandagpt/code/assets/audios/bird_audio.wav
--------------------------------------------------------------------------------
/pandagpt/code/assets/audios/car_audio.wav:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ebagdasa/multimodal_injection/be8f118c06d7783e4f458046947d1271e7b435d6/pandagpt/code/assets/audios/car_audio.wav
--------------------------------------------------------------------------------
/pandagpt/code/assets/audios/dog_audio.wav:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ebagdasa/multimodal_injection/be8f118c06d7783e4f458046947d1271e7b435d6/pandagpt/code/assets/audios/dog_audio.wav
--------------------------------------------------------------------------------
/pandagpt/code/assets/images/bird_image.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ebagdasa/multimodal_injection/be8f118c06d7783e4f458046947d1271e7b435d6/pandagpt/code/assets/images/bird_image.jpg
--------------------------------------------------------------------------------
/pandagpt/code/assets/images/car_image.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ebagdasa/multimodal_injection/be8f118c06d7783e4f458046947d1271e7b435d6/pandagpt/code/assets/images/car_image.jpg
--------------------------------------------------------------------------------
/pandagpt/code/assets/images/dog_image.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ebagdasa/multimodal_injection/be8f118c06d7783e4f458046947d1271e7b435d6/pandagpt/code/assets/images/dog_image.jpg
--------------------------------------------------------------------------------
/pandagpt/code/assets/thermals/190662.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ebagdasa/multimodal_injection/be8f118c06d7783e4f458046947d1271e7b435d6/pandagpt/code/assets/thermals/190662.jpg
--------------------------------------------------------------------------------
/pandagpt/code/assets/thermals/210009.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ebagdasa/multimodal_injection/be8f118c06d7783e4f458046947d1271e7b435d6/pandagpt/code/assets/thermals/210009.jpg
--------------------------------------------------------------------------------
/pandagpt/code/assets/videos/a.mp4:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ebagdasa/multimodal_injection/be8f118c06d7783e4f458046947d1271e7b435d6/pandagpt/code/assets/videos/a.mp4
--------------------------------------------------------------------------------
/pandagpt/code/assets/videos/world.mp4:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ebagdasa/multimodal_injection/be8f118c06d7783e4f458046947d1271e7b435d6/pandagpt/code/assets/videos/world.mp4
--------------------------------------------------------------------------------
/pandagpt/code/config/__init__.py:
--------------------------------------------------------------------------------
1 | import yaml
2 |
3 | def load_model_config(model, mode):
4 | # load special config for each model
5 | config_path = f'config/{model}.yaml'
6 | print(f'[!] load configuration from {config_path}')
7 | with open(config_path) as f:
8 | configuration = yaml.load(f, Loader=yaml.FullLoader)
9 | new_config = {}
10 | for key, value in configuration.items():
11 | if key in ['train', 'test', 'validation']:
12 | if mode == key:
13 | new_config.update(value)
14 | else:
15 | new_config[key] = value
16 | configuration = new_config
17 | return configuration
18 |
19 | def load_config(args):
20 | '''the configuration of each model can override the base configuration'''
21 | # base config
22 | base_configuration = load_base_config()
23 |
24 | # load one model config
25 | configuration = load_model_config(args['model'], args['mode'])
26 |
27 | # update and append the special config for base config
28 | base_configuration.update(configuration)
29 | configuration = base_configuration
30 | return configuration
31 |
32 | def load_base_config():
33 | config_path = f'config/base.yaml'
34 | with open(config_path) as f:
35 | configuration = yaml.load(f, Loader=yaml.FullLoader)
36 | print(f'[!] load base configuration: {config_path}')
37 | return configuration
38 |
--------------------------------------------------------------------------------
/pandagpt/code/config/base.yaml:
--------------------------------------------------------------------------------
1 | models:
2 | openllama:
3 | model_name: OpenLLAMAModel
4 | agent_name: DeepSpeedAgent
5 | stage1_train_dataset: SupervisedDataset
6 | test_dataset: SelfInstructTestDataset
7 | openllama_peft:
8 | model_name: OpenLLAMAPEFTModel
9 | agent_name: DeepSpeedAgent
10 | stage1_train_dataset: SupervisedDataset
11 | test_dataset: SelfInstructTestDataset
12 |
13 | # ========= Global configuration ========== #
14 | logging_step: 5
15 | # ========= Global configuration ========== #
16 |
--------------------------------------------------------------------------------
/pandagpt/code/config/openllama_peft.yaml:
--------------------------------------------------------------------------------
1 | # generation hyper-parameters
2 | max_len: 512
3 | penalty_alpha: 0.6
4 | top_k: 10
5 | top_p: 0.7
6 | random_prefix_len: 5
7 | sample_num: 2
8 | decoding_method: sampling
9 | generate_len: 512
10 |
11 | # lora hyper-parameters
12 | lora_r: 32
13 | lora_alpha: 32
14 | lora_dropout: 0.1
15 |
16 | # some train configuration, more can be found under dsconfig folder
17 | train:
18 | seed: 0
19 | warmup_rate: 0.1
20 | epochs: 2
21 | max_length: 1024
22 | max_shard_size: 10GB
23 |
--------------------------------------------------------------------------------
/pandagpt/code/datasets/__init__.py:
--------------------------------------------------------------------------------
1 | from header import *
2 | from .samplers import DistributedBatchSampler
3 | from .sft_dataset import *
4 |
5 | '''
6 | def get_tokenizer(model):
7 | tokenizer = LlamaTokenizer.from_pretrained(model)
8 | tokenizer.bos_token_id, tokenizer.eos_token_id = 1, 2
9 | tokenizer.pad_token = tokenizer.eos_token
10 | return tokenizer
11 | '''
12 |
13 | def load_sft_dataset(args):
14 | '''
15 | tokenizer = get_tokenizer(args['model_path'])
16 | dataset_name = args['models'][args['model']]['stage1_train_dataset'] # SupervisedDataset, str
17 | data_path = args["data_path"]
18 | data = globals()[dataset_name](data_path, tokenizer, args['max_length']) #SupervisedDataset
19 | '''
20 | data = SupervisedDataset(args['data_path'], args['image_root_path'])
21 |
22 | sampler = torch.utils.data.RandomSampler(data)
23 | world_size = torch.distributed.get_world_size()
24 | rank = torch.distributed.get_rank()
25 | batch_size = args['world_size'] * args['dschf'].config['train_micro_batch_size_per_gpu']
26 | batch_sampler = DistributedBatchSampler(
27 | sampler,
28 | batch_size,
29 | True,
30 | rank,
31 | world_size
32 | )
33 | iter_ = DataLoader(
34 | data,
35 | batch_sampler=batch_sampler,
36 | num_workers=1,
37 | collate_fn=data.collate,
38 | pin_memory=True
39 | )
40 | return data, iter_, sampler
41 |
--------------------------------------------------------------------------------
/pandagpt/code/datasets/samplers.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | """batch samplers that work with either random or sequential data samplers"""
16 | import math
17 | import os
18 | import sys
19 |
20 | import torch
21 | from torch.utils import data
22 | import numpy as np
23 |
24 |
25 | class RandomSampler(data.sampler.Sampler):
26 | r"""
27 | Based off of pytorch RandomSampler and DistributedSampler. Essentially a RandomSampler,
28 | but this class lets the user set an epoch like DistributedSampler
29 | Samples elements randomly. If without replacement, then sample from a shuffled dataset.
30 | If with replacement, then user can specify ``num_samples`` to draw.
31 | Arguments:
32 | data_source (Dataset): dataset to sample from
33 | num_samples (int): number of samples to draw, default=len(dataset)
34 | replacement (bool): samples are drawn with replacement if ``True``, default=False
35 | """
36 |
37 | def __init__(self, data_source, replacement=False, num_samples=None):
38 | super(RandomSampler, self).__init__(data_source)
39 | self.data_source = data_source
40 | self.replacement = replacement
41 | self._num_samples = num_samples
42 | self.epoch = -1
43 |
44 | if self._num_samples is not None and replacement is False:
45 | raise ValueError("With replacement=False, num_samples should not be specified, "
46 | "since a random permute will be performed.")
47 |
48 | if not isinstance(self.num_samples, int) or self.num_samples <= 0:
49 | raise ValueError("num_samples should be a positive integer "
50 | "value, but got num_samples={}".format(self.num_samples))
51 | if not isinstance(self.replacement, bool):
52 | raise ValueError("replacement should be a boolean value, but got "
53 | "replacement={}".format(self.replacement))
54 |
55 | @property
56 | def num_samples(self):
57 | # dataset size might change at runtime
58 | if self._num_samples is None:
59 | return len(self.data_source)
60 | return self._num_samples
61 |
62 | def __iter__(self):
63 | n = len(self.data_source)
64 | g = torch.Generator()
65 | if self.epoch >= 0:
66 | g.manual_seed(self.epoch)
67 | if self.replacement:
68 | for _ in range(self.num_samples // 32):
69 | yield from torch.randint(high=n, size=(32,), dtype=torch.int64, generator=g).tolist()
70 | yield from torch.randint(high=n, size=(self.num_samples % 32,), dtype=torch.int64,
71 | generator=g).tolist()
72 | else:
73 | yield from torch.randperm(n, generator=g).tolist()
74 |
75 | def __len__(self):
76 | return self.num_samples
77 |
78 | def set_epoch(self, epoch):
79 | self.epoch = epoch
80 |
81 |
82 | class DistributedSequentialSampler(data.sampler.Sampler):
83 | def __init__(self, num_samples, train_iters, batch_size, rank=-1, world_size=2):
84 | super().__init__(num_samples)
85 | if rank == -1:
86 | rank = 0
87 | world_size = 1
88 | self.num_samples = num_samples
89 | self.rank = rank
90 | self.world_size = world_size
91 | self.start_iter = 0
92 | self.train_iters = train_iters
93 | self.batch_size = batch_size
94 | self.batch_bias = [i * (num_samples // batch_size) for i in range(batch_size)]
95 |
96 | def __iter__(self):
97 | for idx in range(self.start_iter, self.train_iters * 10):
98 | batch = [(idx + bias) % self.num_samples for bias in self.batch_bias]
99 | tbatch = self._batch(batch)
100 | yield tbatch
101 |
102 | def __len__(self):
103 | return self.train_iters
104 |
105 | def _batch(self, batch):
106 | """extracts samples only pertaining to this worker's batch"""
107 | start = self.rank*self.batch_size//self.world_size
108 | end = (self.rank+1)*self.batch_size//self.world_size
109 | return batch[start:end]
110 |
111 |
112 | class DistributedBatchSampler(data.sampler.BatchSampler):
113 | """
114 | similar to normal implementation of distributed sampler, except implementation is at the
115 | batch sampler level, instead of just the sampler level. This allows wrapping of arbitrary
116 | data samplers (sequential, random, WeightedRandomSampler, etc.) with this batch sampler.
117 | """
118 | def __init__(self, sampler, batch_size, drop_last, rank=-1, world_size=2, wrap_last=False, gradient_accumulation_steps=None):
119 | super(DistributedBatchSampler, self).__init__(sampler, batch_size, drop_last)
120 | if rank == -1:
121 | assert False, 'should not be here'
122 | self.rank = rank
123 | self.world_size = world_size
124 | self.sampler.wrap_around = 0
125 | self.wrap_around = 0
126 | self.wrap_last = wrap_last
127 | self.start_iter = 0
128 | self.effective_batch_size = batch_size if gradient_accumulation_steps is None else batch_size * gradient_accumulation_steps
129 |
130 | def __iter__(self):
131 | batch = []
132 | i = 0
133 | for idx in self.data_iterator(self.sampler, wrap_around=False):
134 | batch.append(idx)
135 | if len(batch) == self.batch_size:
136 | tbatch = self._batch(batch)
137 | if i >= self.start_iter * self.effective_batch_size:
138 | yield tbatch
139 | self.start_iter = 0
140 | i += len(batch)
141 | batch = []
142 | batch_len = len(batch)
143 | if batch_len > 0 and not self.drop_last:
144 | if self.wrap_last:
145 | self.sampler.wrap_around -= (self.batch_size)
146 | self.wrap_around += (len(batch))
147 | self.wrap_around %= self.batch_size
148 | yield self._batch(batch)
149 | if self.wrap_last:
150 | self.sampler.wrap_around += self.batch_size
151 |
152 | def data_iterator(self, _iter, wrap_around=False):
153 | """iterates through data and handles wrap around"""
154 | for i, idx in enumerate(_iter):
155 | if i < self.wrap_around%self.batch_size:
156 | continue
157 | if wrap_around:
158 | self.wrap_around += 1
159 | self.wrap_around %= self.batch_size
160 | yield idx
161 |
162 | def _batch(self, batch):
163 | """extracts samples only pertaining to this worker's batch"""
164 | start = self.rank*self.batch_size//self.world_size
165 | end = (self.rank+1)*self.batch_size//self.world_size
166 | return batch[start:end]
167 |
--------------------------------------------------------------------------------
/pandagpt/code/datasets/sft_dataset.py:
--------------------------------------------------------------------------------
1 | # Copyright 2023 Rohan Taori, Ishaan Gulrajani, Tianyi Zhang, Yann Dubois, Xuechen Li
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | import copy
16 | import os
17 | import json
18 | from tqdm import tqdm
19 | import ipdb
20 | import random
21 | from torch.nn.utils.rnn import pad_sequence
22 | from dataclasses import dataclass, field
23 | from typing import Callable, Dict, Sequence
24 |
25 | import torch
26 | import torch.distributed as dist
27 | import transformers
28 | from torch.utils.data import Dataset
29 | from tqdm import tqdm
30 |
31 | class SupervisedDataset(Dataset):
32 | """Dataset for supervised fine-tuning."""
33 |
34 | def __init__(self, data_path: str, image_root_path: str):
35 | super(SupervisedDataset, self).__init__()
36 |
37 | with open(data_path, 'r') as f:
38 | json_data = json.load(f)
39 | # for debug:
40 | #json_data = json_data[:100000]
41 |
42 | self.image_path_list, self.caption_list = [], []
43 | for item in json_data:
44 | one_image_name, one_caption = item["image_name"], item["conversation"]
45 | # TODO: stage 2 dataset format is invalid
46 | if not one_image_name.endswith('.jpg'):
47 | one_image_name += '.jpg'
48 | one_image_path = image_root_path + '/{}'.format(one_image_name)
49 | self.image_path_list.append(one_image_path)
50 | self.caption_list.append(one_caption)
51 | print(f'[!] collect {len(self.image_path_list)} samples for training')
52 |
53 | def __len__(self): # number of instances
54 | return len(self.image_path_list)
55 |
56 | #def __getitem__(self, i) -> Dict[str, torch.Tensor]: # how to get one sample
57 | def __getitem__(self, i):
58 | return dict(image_paths=self.image_path_list[i], output_texts=self.caption_list[i])
59 |
60 | def collate(self, instances):
61 | image_paths, output_texts = tuple([instance[key] for instance in instances] for key in ("image_paths", "output_texts"))
62 | return dict(
63 | image_paths=image_paths,
64 | output_texts=output_texts
65 | )
66 |
--------------------------------------------------------------------------------
/pandagpt/code/dsconfig/openllama_peft_stage_1.json:
--------------------------------------------------------------------------------
1 | {
2 | "train_batch_size": 64,
3 | "train_micro_batch_size_per_gpu": 1,
4 | "gradient_accumulation_steps": 8,
5 | "steps_per_print": 1,
6 | "gradient_clipping": 1.0,
7 | "zero_optimization": {
8 | "stage": 2,
9 | "offload_optimizer": {
10 | "device": "cpu"
11 | },
12 | "contiguous_gradients": true,
13 | "allgather_bucket_size": 500000000,
14 | "allgather_partitions": true
15 | },
16 | "fp16": {
17 | "enabled": true,
18 | "opt_level": "O2",
19 | "min_loss_scale": 1
20 | },
21 | "bf16": {
22 | "enable": true
23 | },
24 | "optimizer": {
25 | "type": "Adam",
26 | "params": {
27 | "lr": 0.0005,
28 | "betas": [
29 | 0.9,
30 | 0.95
31 | ],
32 | "eps": 1e-8,
33 | "weight_decay": 0.001
34 | }
35 | },
36 | "scheduler": {
37 | "type": "WarmupDecayLR",
38 | "params": {
39 | "warmup_min_lr": 0,
40 | "warmup_max_lr": 0.0005,
41 | "warmup_num_steps": 10,
42 | "total_num_steps": 10000
43 | }
44 | },
45 | "activation_checkpointing": {
46 | "partition_activations": true,
47 | "cpu_checkpointing": true,
48 | "contiguous_memory_optimization": false,
49 | "number_checkpoints": null,
50 | "synchronize_checkpoint_boundary": false,
51 | "profile": false
52 | }
53 |
54 | }
--------------------------------------------------------------------------------
/pandagpt/code/header.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import datetime
3 | import types
4 | import deepspeed
5 | from transformers.deepspeed import HfDeepSpeedConfig
6 | import transformers
7 | import numpy as np
8 | from collections import OrderedDict
9 | from torch.utils.data import Dataset, DataLoader
10 | from torch.nn.utils import clip_grad_norm_
11 | from torch.cuda.amp import autocast, GradScaler
12 | from torch.nn import DataParallel
13 | from torch.optim import lr_scheduler
14 | import torch.optim as optim
15 | import torch.nn as nn
16 | import torch.nn.functional as F
17 | from tqdm import tqdm
18 | import os
19 | import re
20 | import math
21 | import random
22 | import json
23 | import time
24 | import logging
25 | from copy import deepcopy
26 | import ipdb
27 | import argparse
28 | import data
29 | from transformers import LlamaTokenizer, LlamaForCausalLM, LlamaConfig
30 | from torch.nn.utils.rnn import pad_sequence
31 | from peft import LoraConfig, TaskType, get_peft_model
32 |
33 | logging.getLogger("transformers").setLevel(logging.WARNING)
34 | logging.getLogger("transformers.tokenization_utils").setLevel(logging.ERROR)
35 | os.environ['TOKENIZERS_PARALLELISM'] = 'false'
36 |
--------------------------------------------------------------------------------
/pandagpt/code/model/ImageBind/CODE_OF_CONDUCT.md:
--------------------------------------------------------------------------------
1 | # Code of Conduct
2 |
3 | ## Our Pledge
4 |
5 | In the interest of fostering an open and welcoming environment, we as
6 | contributors and maintainers pledge to make participation in our project and
7 | our community a harassment-free experience for everyone, regardless of age, body
8 | size, disability, ethnicity, sex characteristics, gender identity and expression,
9 | level of experience, education, socio-economic status, nationality, personal
10 | appearance, race, religion, or sexual identity and orientation.
11 |
12 | ## Our Standards
13 |
14 | Examples of behavior that contributes to creating a positive environment
15 | include:
16 |
17 | * Using welcoming and inclusive language
18 | * Being respectful of differing viewpoints and experiences
19 | * Gracefully accepting constructive criticism
20 | * Focusing on what is best for the community
21 | * Showing empathy towards other community members
22 |
23 | Examples of unacceptable behavior by participants include:
24 |
25 | * The use of sexualized language or imagery and unwelcome sexual attention or
26 | advances
27 | * Trolling, insulting/derogatory comments, and personal or political attacks
28 | * Public or private harassment
29 | * Publishing others' private information, such as a physical or electronic
30 | address, without explicit permission
31 | * Other conduct which could reasonably be considered inappropriate in a
32 | professional setting
33 |
34 | ## Our Responsibilities
35 |
36 | Project maintainers are responsible for clarifying the standards of acceptable
37 | behavior and are expected to take appropriate and fair corrective action in
38 | response to any instances of unacceptable behavior.
39 |
40 | Project maintainers have the right and responsibility to remove, edit, or
41 | reject comments, commits, code, wiki edits, issues, and other contributions
42 | that are not aligned to this Code of Conduct, or to ban temporarily or
43 | permanently any contributor for other behaviors that they deem inappropriate,
44 | threatening, offensive, or harmful.
45 |
46 | ## Scope
47 |
48 | This Code of Conduct applies within all project spaces, and it also applies when
49 | an individual is representing the project or its community in public spaces.
50 | Examples of representing a project or community include using an official
51 | project e-mail address, posting via an official social media account, or acting
52 | as an appointed representative at an online or offline event. Representation of
53 | a project may be further defined and clarified by project maintainers.
54 |
55 | This Code of Conduct also applies outside the project spaces when there is a
56 | reasonable belief that an individual's behavior may have a negative impact on
57 | the project or its community.
58 |
59 | ## Enforcement
60 |
61 | Instances of abusive, harassing, or otherwise unacceptable behavior may be
62 | reported by contacting the project team at . All
63 | complaints will be reviewed and investigated and will result in a response that
64 | is deemed necessary and appropriate to the circumstances. The project team is
65 | obligated to maintain confidentiality with regard to the reporter of an incident.
66 | Further details of specific enforcement policies may be posted separately.
67 |
68 | Project maintainers who do not follow or enforce the Code of Conduct in good
69 | faith may face temporary or permanent repercussions as determined by other
70 | members of the project's leadership.
71 |
72 | ## Attribution
73 |
74 | This Code of Conduct is adapted from the [Contributor Covenant][homepage], version 1.4,
75 | available at https://www.contributor-covenant.org/version/1/4/code-of-conduct.html
76 |
77 | [homepage]: https://www.contributor-covenant.org
78 |
79 | For answers to common questions about this code of conduct, see
80 | https://www.contributor-covenant.org/faq
--------------------------------------------------------------------------------
/pandagpt/code/model/ImageBind/CONTRIBUTING.md:
--------------------------------------------------------------------------------
1 | # Contributing to ImageBind
2 | We want to make contributing to this project as easy and transparent as
3 | possible.
4 |
5 | ## Pull Requests
6 | We actively welcome your pull requests.
7 |
8 | 1. Fork the repo and create your branch from `main`.
9 | 2. If you've added code that should be tested, add tests.
10 | 3. If you've changed APIs, update the documentation.
11 | 4. Ensure the test suite passes.
12 | 5. Make sure your code lints.
13 | 6. If you haven't already, complete the Contributor License Agreement ("CLA").
14 |
15 | ## Contributor License Agreement ("CLA")
16 | In order to accept your pull request, we need you to submit a CLA. You only need
17 | to do this once to work on any of Meta's open source projects.
18 |
19 | Complete your CLA here:
20 |
21 | ## Issues
22 | We use GitHub issues to track public bugs. Please ensure your description is
23 | clear and has sufficient instructions to be able to reproduce the issue.
24 |
25 | Meta has a [bounty program](https://www.facebook.com/whitehat/) for the safe
26 | disclosure of security bugs. In those cases, please go through the process
27 | outlined on that page and do not file a public issue.
28 |
29 | ## License
30 | By contributing to ImageBind, you agree that your contributions will be licensed
31 | under the [LICENSE](LICENSE) file in the root directory of this source tree.
32 |
--------------------------------------------------------------------------------
/pandagpt/code/model/ImageBind/README.md:
--------------------------------------------------------------------------------
1 | # ImageBind: One Embedding Space To Bind Them All
2 |
3 | **[FAIR, Meta AI](https://ai.facebook.com/research/)**
4 |
5 | Rohit Girdhar*,
6 | Alaaeldin El-Nouby*,
7 | Zhuang Liu,
8 | Mannat Singh,
9 | Kalyan Vasudev Alwala,
10 | Armand Joulin,
11 | Ishan Misra*
12 |
13 | To appear at CVPR 2023 (*Highlighted paper*)
14 |
15 | [[`Paper`](https://facebookresearch.github.io/ImageBind/paper)] [[`Blog`](https://ai.facebook.com/blog/imagebind-six-modalities-binding-ai/)] [[`Demo`](https://imagebind.metademolab.com/)] [[`Supplementary Video`](https://dl.fbaipublicfiles.com/imagebind/imagebind_video.mp4)] [[`BibTex`](#citing-imagebind)]
16 |
17 | PyTorch implementation and pretrained models for ImageBind. For details, see the paper: **[ImageBind: One Embedding Space To Bind Them All](https://facebookresearch.github.io/ImageBind/paper)**.
18 |
19 | ImageBind learns a joint embedding across six different modalities - images, text, audio, depth, thermal, and IMU data. It enables novel emergent applications ‘out-of-the-box’ including cross-modal retrieval, composing modalities with arithmetic, cross-modal detection and generation.
20 |
21 |
22 |
23 | 
24 |
25 | ## ImageBind model
26 |
27 | Emergent zero-shot classification performance.
28 |
29 |
30 |
31 | | Model | IN1k | K400 | NYU-D | ESC | LLVIP | Ego4D | download |
32 | |:---|:---:|:---:|:---:|:---:|:---:|:---:|:---:|
33 | | imagebind_huge | 77.7 | 50.0 | 54.0 | 66.9 | 63.4 | 25.0 | checkpoint |
49 |
50 |
51 |
52 |
53 | ## Usage
54 |
55 | Install pytorch 1.13+ and other 3rd party dependencies.
56 |
57 | ```shell
58 | conda create --name imagebind python=3.8 -y
59 | conda activate imagebind
60 |
61 | pip install -r requirements.txt
62 | ```
63 |
64 | For Windows users, you might need to install `soundfile` for reading/writing audio files. (Thanks @congyue1977)
65 |
66 | ```shell
67 | pip install soundfile
68 | ```
69 |
70 |
71 | Extract and compare features across modalities (e.g. Image, Text and Audio).
72 |
73 | ```python
74 | import data
75 | import torch
76 | from models import imagebind_model
77 | from models.imagebind_model import ModalityType
78 |
79 | text_list=["A dog.", "A car", "A bird"]
80 | image_paths=[".assets/dog_image.jpg", ".assets/car_image.jpg", ".assets/bird_image.jpg"]
81 | audio_paths=[".assets/dog_audio.wav", ".assets/car_audio.wav", ".assets/bird_audio.wav"]
82 |
83 | device = "cuda:0" if torch.cuda.is_available() else "cpu"
84 |
85 | # Instantiate model
86 | model = imagebind_model.imagebind_huge(pretrained=True)
87 | model.eval()
88 | model.to(device)
89 |
90 | # Load data
91 | inputs = {
92 | ModalityType.TEXT: data.load_and_transform_text(text_list, device),
93 | ModalityType.VISION: data.load_and_transform_vision_data(image_paths, device),
94 | ModalityType.AUDIO: data.load_and_transform_audio_data(audio_paths, device),
95 | }
96 |
97 | with torch.no_grad():
98 | embeddings = model(inputs)
99 |
100 | print(
101 | "Vision x Text: ",
102 | torch.softmax(embeddings[ModalityType.VISION] @ embeddings[ModalityType.TEXT].T, dim=-1),
103 | )
104 | print(
105 | "Audio x Text: ",
106 | torch.softmax(embeddings[ModalityType.AUDIO] @ embeddings[ModalityType.TEXT].T, dim=-1),
107 | )
108 | print(
109 | "Vision x Audio: ",
110 | torch.softmax(embeddings[ModalityType.VISION] @ embeddings[ModalityType.AUDIO].T, dim=-1),
111 | )
112 |
113 | # Expected output:
114 | #
115 | # Vision x Text:
116 | # tensor([[9.9761e-01, 2.3694e-03, 1.8612e-05],
117 | # [3.3836e-05, 9.9994e-01, 2.4118e-05],
118 | # [4.7997e-05, 1.3496e-02, 9.8646e-01]])
119 | #
120 | # Audio x Text:
121 | # tensor([[1., 0., 0.],
122 | # [0., 1., 0.],
123 | # [0., 0., 1.]])
124 | #
125 | # Vision x Audio:
126 | # tensor([[0.8070, 0.1088, 0.0842],
127 | # [0.1036, 0.7884, 0.1079],
128 | # [0.0018, 0.0022, 0.9960]])
129 |
130 | ```
131 |
132 | ## Model card
133 | Please see the [model card](model_card.md) for details.
134 |
135 | ## License
136 |
137 | ImageBind code and model weights are released under the CC-BY-NC 4.0 license. See [LICENSE](LICENSE) for additional details.
138 |
139 | ## Contributing
140 |
141 | See [contributing](CONTRIBUTING.md) and the [code of conduct](CODE_OF_CONDUCT.md).
142 |
143 | ## Citing ImageBind
144 |
145 | If you find this repository useful, please consider giving it a star :star: and a citation.
146 |
147 | ```
148 | @inproceedings{girdhar2023imagebind,
149 | title={ImageBind: One Embedding Space To Bind Them All},
150 | author={Girdhar, Rohit and El-Nouby, Alaaeldin and Liu, Zhuang
151 | and Singh, Mannat and Alwala, Kalyan Vasudev and Joulin, Armand and Misra, Ishan},
152 | booktitle={CVPR},
153 | year={2023}
154 | }
155 | ```
156 |
--------------------------------------------------------------------------------
/pandagpt/code/model/ImageBind/__init__.py:
--------------------------------------------------------------------------------
1 | from .models import imagebind_model
2 | from .models.imagebind_model import ModalityType
3 |
--------------------------------------------------------------------------------
/pandagpt/code/model/ImageBind/bpe/bpe_simple_vocab_16e6.txt.gz:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ebagdasa/multimodal_injection/be8f118c06d7783e4f458046947d1271e7b435d6/pandagpt/code/model/ImageBind/bpe/bpe_simple_vocab_16e6.txt.gz
--------------------------------------------------------------------------------
/pandagpt/code/model/ImageBind/model_card.md:
--------------------------------------------------------------------------------
1 | # Model Card for ImageBind
2 |
3 | Multimodal joint embedding model for image/video, text, audio, depth, IMU, and thermal images.
4 | Input any of the six modalities and get the same sized embedding that can be used for cross-modal and multimodal tasks.
5 |
6 | # Model Details
7 |
8 | ## Model Description
9 |
10 |
11 | Multimodal joint embedding model for image/video, text, audio, depth, IMU, and thermal images
12 |
13 | - **Developed by:** Meta AI
14 | - **Model type:** Multimodal model
15 | - **Language(s) (NLP):** en
16 | - **License:** CC BY-NC-SA 4.0
17 | - **Resources for more information:**
18 | - [GitHub Repo](https://github.com/facebookresearch/ImageBind)
19 |
20 |
21 | # Uses
22 |
23 |
24 | This model is intended only for research purposes. It provides a joint embedding space for different modalities -- image/video, text, audio, depth, IMU and thermal images.
25 | We hope that these joint embeddings can be used for a variety of different cross-modal research, e.g., cross-modal retrieval and combining embeddings from different modalities.
26 |
27 | ## Out-of-Scope Use
28 |
29 |
30 |
31 |
32 | This model is *NOT* intended to be used in any real world application -- commercial or otherwise.
33 | It may produce harmful associations with different inputs.
34 | The model needs to be investigated and likely re-trained on specific data for any such application.
35 | The model is expected to work better on web-based visual data since it was trained on such data.
36 | The text encoder is likely to work only on English language text because of the underlying training datasets.
37 |
38 | # Bias, Risks, and Limitations
39 |
40 |
41 | Open-domain joint embedding models are prone to producing specific biases; see, e.g., the bias and fairness study from [CLIP](https://github.com/openai/CLIP/blob/main/model-card.md#bias-and-fairness).
42 | Since our model uses such models as initialization, it will exhibit such biases too.
43 | Moreover, for learning joint embeddings for other modalities such as audio, thermal, depth, and IMU we leverage datasets that are relatively small. These joint embeddings are thus limited to the concepts present in the datasets. For example, the thermal datasets we used are limited to outdoor street scenes, while the depth datasets are limited to indoor scenes.
44 |
45 |
46 |
47 | # Training Details
48 |
49 | ## Training Data
50 |
51 |
52 |
53 | ImageBind uses image-paired data for training -- (image, X) where X is one of text, audio, depth, IMU or thermal data.
54 | In particular, we initialize and freeze the image and text encoders using an OpenCLIP ViT-H encoder.
55 | We train audio embeddings using Audioset, depth embeddings using the SUN RGB-D dataset, IMU using the Ego4D dataset and thermal embeddings using the LLVIP dataset.
56 | We provide the exact training data details in the paper.
57 |
58 |
59 | ## Training Procedure
60 |
61 |
62 | Please refer to the research paper and github repo for exact details on this.
63 |
64 | # Evaluation
65 |
66 | ## Testing Data, Factors & Metrics
67 |
68 | We evaluate the model on a variety of different classification benchmarks for each modality.
69 | The evaluation details are presented in the paper.
70 | The model's performance is measured using standard classification metrics such as accuracy and mAP.
71 |
72 | # Citation
73 |
74 |
75 |
76 | **BibTeX:**
77 | ```
78 | @inproceedings{girdhar2023imagebind,
79 | title={ImageBind: One Embedding Space To Bind Them All},
80 | author={Girdhar, Rohit and El-Nouby, Alaaeldin and Liu, Zhuang
81 | and Singh, Mannat and Alwala, Kalyan Vasudev and Joulin, Armand and Misra, Ishan},
82 | booktitle={CVPR},
83 | year={2023}
84 | }
85 | ```
86 |
87 |
88 | # Model Card Contact
89 |
90 | Please reach out to the authors at: rgirdhar@meta.com imisra@meta.com alaaelnouby@gmail.com
91 |
92 | # How to Get Started with the Model
93 |
94 | Our github repo provides a simple example to extract embeddings from images, audio etc.
95 |
--------------------------------------------------------------------------------
/pandagpt/code/model/ImageBind/models/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ebagdasa/multimodal_injection/be8f118c06d7783e4f458046947d1271e7b435d6/pandagpt/code/model/ImageBind/models/__init__.py
--------------------------------------------------------------------------------
/pandagpt/code/model/ImageBind/models/helpers.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # Portions Copyright (c) Meta Platforms, Inc. and affiliates.
3 | # All rights reserved.
4 |
5 | # This source code is licensed under the license found in the
6 | # LICENSE file in the root directory of this source tree.
7 |
8 | import math
9 |
10 | import einops
11 | import numpy as np
12 | import torch
13 |
14 | import torch.nn as nn
15 |
16 |
17 | class Normalize(nn.Module):
18 | def __init__(self, dim: int) -> None:
19 | super().__init__()
20 | self.dim = dim
21 |
22 | def forward(self, x):
23 | return torch.nn.functional.normalize(x, dim=self.dim, p=2)
24 |
25 |
26 | class LearnableLogitScaling(nn.Module):
27 | def __init__(
28 | self,
29 | logit_scale_init: float = 1 / 0.07,
30 | learnable: bool = True,
31 | max_logit_scale: float = 100,
32 | ) -> None:
33 | super().__init__()
34 | self.max_logit_scale = max_logit_scale
35 | self.logit_scale_init = logit_scale_init
36 | self.learnable = learnable
37 | log_logit_scale = torch.ones([]) * np.log(self.logit_scale_init)
38 | if learnable:
39 | self.log_logit_scale = nn.Parameter(log_logit_scale)
40 | else:
41 | self.register_buffer("log_logit_scale", log_logit_scale)
42 |
43 | def forward(self, x):
44 | return torch.clip(self.log_logit_scale.exp(), max=self.max_logit_scale) * x
45 |
46 | def extra_repr(self):
47 | st = f"logit_scale_init={self.logit_scale_init},learnable={self.learnable}, max_logit_scale={self.max_logit_scale}"
48 | return st
49 |
50 |
51 | class EinOpsRearrange(nn.Module):
52 | def __init__(self, rearrange_expr: str, **kwargs) -> None:
53 | super().__init__()
54 | self.rearrange_expr = rearrange_expr
55 | self.kwargs = kwargs
56 |
57 | def forward(self, x):
58 | assert isinstance(x, torch.Tensor)
59 | return einops.rearrange(x, self.rearrange_expr, **self.kwargs)
60 |
61 |
62 | class VerboseNNModule(nn.Module):
63 | """
64 | Wrapper around nn.Module that prints registered buffers and parameter names.
65 | """
66 |
67 | @staticmethod
68 | def get_readable_tensor_repr(name: str, tensor: torch.Tensor) -> str:
69 | st = (
70 | "("
71 | + name
72 | + "): "
73 | + "tensor("
74 | + str(tuple(tensor[1].shape))
75 | + ", requires_grad="
76 | + str(tensor[1].requires_grad)
77 | + ")\n"
78 | )
79 | return st
80 |
81 | def extra_repr(self) -> str:
82 | named_modules = set()
83 | for p in self.named_modules():
84 | named_modules.update([p[0]])
85 | named_modules = list(named_modules)
86 |
87 | string_repr = ""
88 | for p in self.named_parameters():
89 | name = p[0].split(".")[0]
90 | if name not in named_modules:
91 | string_repr += self.get_readable_tensor_repr(name, p)
92 |
93 | for p in self.named_buffers():
94 | name = p[0].split(".")[0]
95 | string_repr += self.get_readable_tensor_repr(name, p)
96 |
97 | return string_repr
98 |
99 |
100 | def cast_if_src_dtype(
101 | tensor: torch.Tensor, src_dtype: torch.dtype, tgt_dtype: torch.dtype
102 | ):
103 | updated = False
104 | if tensor.dtype == src_dtype:
105 | tensor = tensor.to(dtype=tgt_dtype)
106 | updated = True
107 | return tensor, updated
108 |
109 |
110 | class QuickGELU(nn.Module):
111 | # From https://github.com/openai/CLIP/blob/d50d76daa670286dd6cacf3bcd80b5e4823fc8e1/clip/model.py#L166
112 | def forward(self, x: torch.Tensor):
113 | return x * torch.sigmoid(1.702 * x)
114 |
115 |
116 | class SelectElement(nn.Module):
117 | def __init__(self, index) -> None:
118 | super().__init__()
119 | self.index = index
120 |
121 | def forward(self, x):
122 | assert x.ndim >= 3
123 | return x[:, self.index, ...]
124 |
125 |
126 | class SelectEOSAndProject(nn.Module):
127 | """
128 | Text Pooling used in OpenCLIP
129 | """
130 |
131 | def __init__(self, proj: nn.Module) -> None:
132 | super().__init__()
133 | self.proj = proj
134 |
135 | def forward(self, x, seq_len):
136 | assert x.ndim == 3
137 | # x is of shape B x L x D
138 | # take features from the eot embedding (eot_token is the highest number in each sequence)
139 | x = x[torch.arange(x.shape[0]), seq_len]
140 | x = self.proj(x)
141 | return x
142 |
--------------------------------------------------------------------------------
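The helpers above are small building blocks used by `imagebind_model.py`. As a quick illustration, `LearnableLogitScaling` multiplies its input by a learnable temperature `exp(log_logit_scale)` (initialised to 1/0.07, CLIP-style), clipped at `max_logit_scale`. A tiny usage sketch, assuming it is run from the ImageBind code root so the `models.helpers` import resolves:

```python
# Usage sketch for LearnableLogitScaling; assumes the ImageBind code root is on sys.path.
import torch
from models.helpers import LearnableLogitScaling

scale = LearnableLogitScaling(logit_scale_init=1 / 0.07, learnable=True)
sims = torch.randn(4, 4)   # placeholder cosine similarities
logits = scale(sims)       # sims * exp(log_logit_scale), clipped at max_logit_scale=100
print(logits.shape)
```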
/pandagpt/code/model/ImageBind/requirements.txt:
--------------------------------------------------------------------------------
1 | --extra-index-url https://download.pytorch.org/whl/cu113
2 | torchvision==0.14.0
3 | torchaudio==0.13.0
4 | pytorchvideo @ git+https://github.com/facebookresearch/pytorchvideo.git@28fe037d212663c6a24f373b94cc5d478c8c1a1d
5 | timm==0.6.7
6 | ftfy
7 | regex
8 | einops
9 | fvcore
10 | decord==0.6.0
11 |
--------------------------------------------------------------------------------
/pandagpt/code/model/__init__.py:
--------------------------------------------------------------------------------
1 | from .agent import DeepSpeedAgent
2 | from .openllama import OpenLLAMAPEFTModel
3 |
4 | def load_model(args):
5 | agent_name = args['models'][args['model']]['agent_name']
6 | model_name = args['models'][args['model']]['model_name']
7 | model = globals()[model_name](**args)
8 | agent = globals()[agent_name](model, args)
9 | return agent
10 |
--------------------------------------------------------------------------------
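`load_model` resolves the agent and model classes by name from the parsed config and instantiates them via `globals()`. An illustrative sketch of the expected shape of `args` (placeholder values; the real entries come from the YAML config files and the command line):

```python
# Illustrative shape of `args` for load_model(); values are placeholders, the
# real ones come from config/*.yaml and the train_sft.py command line.
args = {
    "model": "openllama_peft",
    "models": {
        "openllama_peft": {
            "model_name": "OpenLLAMAPEFTModel",  # class imported above
            "agent_name": "DeepSpeedAgent",      # class imported above
        }
    },
    # ...plus whatever OpenLLAMAPEFTModel / DeepSpeedAgent expect, e.g.
    # stage, ds_config_path, imagebind_ckpt_path, vicuna_ckpt_path,
    # max_tgt_len, total_steps, warmup_rate, ...
}
# agent = load_model(args)  # looks up both classes and wires them together
```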
/pandagpt/code/model/agent.py:
--------------------------------------------------------------------------------
1 | from header import *
2 |
3 | class DeepSpeedAgent:
4 |
5 | def __init__(self, model, args):
6 | super(DeepSpeedAgent, self).__init__()
7 | self.args = args
8 | self.model = model
9 | if args['stage'] == 2:
10 | self.load_stage_1_parameters(args["delta_ckpt_path"])
11 | print(f'[!] load stage 1 checkpoint from {args["delta_ckpt_path"]}')
12 |
13 | # load config parameters of deepspeed
14 | ds_params = json.load(open(self.args['ds_config_path']))
15 | ds_params['scheduler']['params']['total_num_steps'] = self.args['total_steps']
16 | ds_params['scheduler']['params']['warmup_num_steps'] = max(10, int(self.args['total_steps'] * self.args['warmup_rate']))
17 | self.ds_engine, self.optimizer, _ , _ = deepspeed.initialize(
18 | model=self.model,
19 | model_parameters=self.model.parameters(),
20 | config_params=ds_params,
21 | dist_init_required=True,
22 | args=types.SimpleNamespace(**args)
23 | )
24 |
25 | @torch.no_grad()
26 | def predict(self, batch):
27 | self.model.eval()
28 | string = self.model.generate_one_sample(batch)
29 | return string
30 |
31 | def train_model(self, batch, current_step=0, pbar=None):
32 | self.ds_engine.module.train()
33 | loss, mle_acc = self.ds_engine(batch)
34 |
35 | self.ds_engine.backward(loss)
36 | self.ds_engine.step()
37 | pbar.set_description(f'[!] loss: {round(loss.item(), 4)}; token_acc: {round(mle_acc*100, 2)}')
38 | pbar.update(1)
39 | if self.args['local_rank'] == 0 and self.args['log_path'] and current_step % self.args['logging_step'] == 0:
40 | elapsed = pbar.format_dict['elapsed']
41 | rate = pbar.format_dict['rate']
42 | remaining = (pbar.total - pbar.n) / rate if rate and pbar.total else 0
43 | remaining = str(datetime.timedelta(seconds=remaining))
44 | logging.info(f'[!] progress: {round(pbar.n/pbar.total, 5)}; remaining time: {remaining}; loss: {round(loss.item(), 4)}; token_acc: {round(mle_acc*100, 2)}')
45 |
46 | mle_acc *= 100
47 | return mle_acc
48 |
49 | def save_model(self, path, current_step):
50 | # only save trainable model parameters
51 | param_grad_dic = {
52 | k: v.requires_grad for (k, v) in self.ds_engine.module.named_parameters()
53 | }
54 | state_dict = self.ds_engine.module.state_dict()
55 | checkpoint = OrderedDict()
56 | for k, v in self.ds_engine.module.named_parameters():
57 | if v.requires_grad:
58 | checkpoint[k] = v
59 | torch.save(checkpoint, f'{path}/pytorch_model.pt')
60 | # save tokenizer
61 | self.model.llama_tokenizer.save_pretrained(path)
62 | # save configuration
63 | self.model.llama_model.config.save_pretrained(path)
64 | print(f'[!] save model into {path}')
65 |
66 | def load_stage_1_parameters(self, path):
67 | delta_ckpt = torch.load(path, map_location=torch.device('cpu'))
68 | self.model.load_state_dict(delta_ckpt, strict=False)
69 |
--------------------------------------------------------------------------------
/pandagpt/code/scripts/train.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | deepspeed --include localhost:0,1,2,3,4,5,6,7 --master_addr 127.0.0.1 --master_port 28457 train_sft.py \
4 | --model openllama_peft \
5 | --stage 1\
6 | --data_path ../data/pandagpt4_visual_instruction_data.json\
7 | --image_root_path ../data/images/\
8 | --imagebind_ckpt_path ../pretrained_ckpt/imagebind_ckpt/\
9 | --vicuna_ckpt_path ../pretrained_ckpt/vicuna_ckpt/13b_v0/\
10 | --max_tgt_len 400\
11 | --save_path ./ckpt/pandagpt_13b_v0_peft/\
12 | --log_path ./ckpt/pandagpt_13b_v0_peft/log_rest/
13 |
--------------------------------------------------------------------------------
/pandagpt/code/train_sft.py:
--------------------------------------------------------------------------------
1 | from header import *
2 | from datasets import *
3 | from model import *
4 | from config import *
5 |
6 | def parser_args():
7 | parser = argparse.ArgumentParser(description='train parameters')
8 | parser.add_argument('--model', type=str)
9 | parser.add_argument('--data_path', type=str)
10 | parser.add_argument('--local_rank', default=0, type=int)
11 | parser.add_argument('--save_path', type=str)
12 | parser.add_argument('--log_path', type=str)
13 | # model configurations
14 | parser.add_argument('--image_root_path', type=str) # the directory that stores all images
15 | parser.add_argument('--imagebind_ckpt_path', type=str) # the path that stores the imagebind checkpoint
16 | parser.add_argument('--vicuna_ckpt_path', type=str) # the path that stores the vicuna checkpoint
17 | parser.add_argument('--delta_ckpt_path', type=str) # the delta parameters trained in stage 1
18 | parser.add_argument('--max_tgt_len', type=int) # the maximum sequence length
19 | parser.add_argument('--stage', type=int) # the training stage (1 or 2)
20 | return parser.parse_args()
21 |
22 | def initialize_distributed(args):
23 | args['master_ip'] = os.getenv('MASTER_ADDR', 'localhost')
24 | args['master_port'] = os.getenv('MASTER_PORT', '6000')
25 | args['world_size'] = int(os.getenv('WORLD_SIZE', '1'))
26 | args['local_rank'] = int(os.getenv('RANK', '0')) % torch.cuda.device_count()
27 | device = args['local_rank'] % torch.cuda.device_count()
28 | torch.cuda.set_device(device)
29 | deepspeed.init_distributed(dist_backend='nccl')
30 |
31 | def set_random_seed(seed):
32 | if seed is not None and seed > 0:
33 | random.seed(seed)
34 | np.random.seed(seed)
35 | torch.manual_seed(seed)
36 | torch.random.manual_seed(seed)
37 | torch.cuda.manual_seed(seed)
38 | torch.cuda.manual_seed_all(seed)
39 |
40 | def config_env(args):
41 | args['root_dir'] = '../'
42 | args['mode'] = 'train'
43 | config = load_config(args)
44 | args.update(config)
45 | initialize_distributed(args)
46 | set_random_seed(args['seed'])
47 |
48 | def build_directory(path):
49 | if os.path.exists(path):
50 | pass
51 | else: # recursively construct directory
52 | os.makedirs(path, exist_ok=True)
53 |
54 | def main(**args):
55 | config_env(args)
56 | args['ds_config_path'] = f'dsconfig/{args["model"]}_stage_{args["stage"]}.json'
57 | dschf = HfDeepSpeedConfig(args['ds_config_path'])
58 | args['dschf'] = dschf
59 |
60 | build_directory(args['save_path'])
61 | build_directory(args['log_path'])
62 |
63 | if args['log_path']:
64 | logging.basicConfig(
65 | format='%(asctime)s - %(pathname)s[line:%(lineno)d] - %(levelname)s: %(message)s',
66 | level=logging.DEBUG,
67 | filename=f'{args["log_path"]}/train_{time.asctime()}.log',
68 | filemode='w'
69 | )
70 |
71 | train_data, train_iter, sampler = load_sft_dataset(args)
72 |
73 | length = args['epochs'] * len(train_data) // args['world_size'] // dschf.config['train_micro_batch_size_per_gpu']
74 | total_steps = args['epochs'] * len(train_data) // dschf.config['train_batch_size']
75 | args['total_steps'] = total_steps
76 | agent = load_model(args)
77 | torch.distributed.barrier()
78 |
79 | # begin to train
80 | pbar = tqdm(total=length) # maximum total number
81 | current_step = 0
82 | for epoch_i in tqdm(range(args['epochs'])):
83 | for batch in train_iter:
84 | agent.train_model(
85 | batch,
86 | current_step=current_step,
87 | pbar=pbar
88 | )
89 | current_step += 1
90 | # save at the end of the training
91 | torch.distributed.barrier()
92 | agent.save_model(args['save_path'], 0)
93 |
94 | if __name__ == "__main__":
95 | args = parser_args()
96 | args = vars(args)
97 | main(**args)
98 |
--------------------------------------------------------------------------------
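The two step counts computed in `main()` above are derived from the dataset size and the batch-size settings in the DeepSpeed config. A small worked example with made-up numbers:

```python
# Worked example (made-up numbers) of the step bookkeeping in train_sft.py's main().
epochs = 2
dataset_size = 160_000   # len(train_data), placeholder
world_size = 8           # number of GPUs
micro_bsz = 1            # train_micro_batch_size_per_gpu from the DeepSpeed config
train_batch_size = 64    # train_batch_size from the DeepSpeed config

length = epochs * dataset_size // world_size // micro_bsz   # 40000: pbar total (per-GPU micro-batches)
total_steps = epochs * dataset_size // train_batch_size     # 5000: optimizer steps for the LR schedule
print(length, total_steps)
```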
/pandagpt/data/empty.txt:
--------------------------------------------------------------------------------
1 | empty text
--------------------------------------------------------------------------------
/pandagpt/pretrained_ckpt/README.md:
--------------------------------------------------------------------------------
1 | # 1. Prepare Vicuna Checkpoint:
2 |
3 | The language decoder of PandaGPT is based on Vicuna version 0. Given the distribution license of LLaMA, you need to restore the weights of Vicuna manually by following the instructions below. We show how to restore the 7B version of Vicuna v0; the 13B version can be obtained with the same procedure.
4 |
5 | ## 1.1. Obtain LLaMA Weights:
6 | * Request the weights of LLaMA from Meta using [this form](https://docs.google.com/forms/d/e/1FAIpQLSfqNECQnMkycAp2jP4Z9TFX0cGR4uf7b_fBxjY_OjhJILlKGA/viewform).
7 | * After obtaining the weights of a specific LLaMA (e.g. 7B, 13B), follow the [instructions](https://huggingface.co/docs/transformers/main/model_doc/llama) provided by Huggingface to convert them into the Huggingface format.
8 |
9 | > **Note:** After conversion, the directory should look like:
10 |
11 | .
12 | └── ./{path_to_llama_weights}/
13 | ├── config.json
14 | ├── generation_config.json
15 | ├── pytorch_model-00001-of-00002.bin
16 | ├── pytorch_model-00002-of-00002.bin
17 | ├── pytorch_model.bin.index.json
18 | ├── special_tokens_map.json
19 | ├── tokenizer.model
20 | └── tokenizer_config.json
21 |
22 | `{path_to_llama_weights}` is where you store the checkpoints.
23 |
24 |
25 | ## 1.2. Obtain the Delta Weights of Vicuna:
26 |
27 | Then, you should download the delta weights of Vicuna provided by the original authors. You can find the corresponding links to 7B/13B Vicuna models in the table below.
28 |
29 | |**Model Size**|**Delta Weights Address**|**Version**|
30 | |:-------------:|:-------------:|:-------------:|
31 | |7B|[[Link]](https://huggingface.co/lmsys/vicuna-7b-delta-v0)|0|
32 | |13B|[[Link]](https://huggingface.co/lmsys/vicuna-13b-delta-v0)|0|
33 |
34 |
35 |
36 | > **Note:** After downloading, the directory should look like:
37 |
38 | .
39 | └── ./{path_to_delta_vicuna_weights}/
40 | ├── config.json
41 | ├── generation_config.json
42 | ├── pytorch_model-00001-of-00002.bin
43 | ├── pytorch_model-00002-of-00002.bin
44 | ├── pytorch_model.bin.index.json
45 | ├── special_tokens_map.json
46 | ├── tokenizer.model
47 | └── tokenizer_config.json
48 |
49 | `{path_to_delta_vicuna_weights}` is where you store the delta weights of Vicuna.
50 |
51 | ## 1.3. Combine the Weights:
52 |
53 | When the two sets of weights are ready, you can combine them using tools from the Vicuna team.
54 |
55 | First, install the required library.
56 | ```shell
57 | pip install git+https://github.com/lm-sys/FastChat.git@v0.1.10
58 | ```
59 |
60 | Then, run the following command.
61 | ```shell
62 | python -m fastchat.model.apply_delta --base {path_to_llama_weights} --target ./vicuna_ckpt/7b_v0/ --delta {path_to_delta_vicuna_weights}
63 | ```
64 |
65 | > **Note:** Now, the final weights are ready as:
66 |
67 | .
68 | └── ./vicuna_ckpt/7b_v0/
69 | ├── config.json
70 | ├── generation_config.json
71 | ├── pytorch_model-00001-of-00002.bin
72 | ├── pytorch_model-00002-of-00002.bin
73 | ├── pytorch_model.bin.index.json
74 | ├── special_tokens_map.json
75 | ├── tokenizer.model
76 | └── tokenizer_config.json
77 |
78 |
79 |
--------------------------------------------------------------------------------
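Once `apply_delta` has written the combined weights, an optional sanity check (not part of the repo) is to load them back with the `transformers` classes already imported in `code/header.py`:

```python
# Optional sanity check: verify the combined Vicuna checkpoint loads cleanly.
from transformers import LlamaForCausalLM, LlamaTokenizer

ckpt_dir = "./vicuna_ckpt/7b_v0/"  # directory produced by the apply_delta command above
tokenizer = LlamaTokenizer.from_pretrained(ckpt_dir)
model = LlamaForCausalLM.from_pretrained(ckpt_dir)  # loads on CPU by default
print(model.config.hidden_size, len(tokenizer))
```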
/pandagpt/pretrained_ckpt/imagebind_ckpt/empty.txt:
--------------------------------------------------------------------------------
1 | empty placeholder
--------------------------------------------------------------------------------
/pandagpt/pretrained_ckpt/pandagpt_ckpt/13b/empty.txt:
--------------------------------------------------------------------------------
1 | empty placeholder
--------------------------------------------------------------------------------
/pandagpt/pretrained_ckpt/pandagpt_ckpt/7b/empty.txt:
--------------------------------------------------------------------------------
1 | empty placeholder
--------------------------------------------------------------------------------
/pandagpt/pretrained_ckpt/vicuna_ckpt/13b_v0/empty.txt:
--------------------------------------------------------------------------------
1 | empty placeholder
--------------------------------------------------------------------------------
/pandagpt/pretrained_ckpt/vicuna_ckpt/7b_v0/empty.txt:
--------------------------------------------------------------------------------
1 | empty placeholder
--------------------------------------------------------------------------------
/pandagpt/requirements.txt:
--------------------------------------------------------------------------------
1 | timm==0.6.7
2 | deepspeed==0.9.2
3 | data
4 | einops==0.6.1
5 | ftfy==6.1.1
6 | iopath==0.1.10
7 | ipdb==0.13.13
8 | numpy==1.24.3
9 | peft==0.3.0
10 | Pillow==9.5.0
11 | PyYAML==6.0
12 | regex==2022.10.31
13 | torchvision==0.14.1
14 | torchaudio==0.13.1
15 | pytorchvideo
16 | fvcore
17 | decord==0.6.0
18 | tqdm==4.64.1
19 | transformers==4.29.1
20 | jupyter
21 | sentencepiece
--------------------------------------------------------------------------------
/result_audios/.gitkeep:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ebagdasa/multimodal_injection/be8f118c06d7783e4f458046947d1271e7b435d6/result_audios/.gitkeep
--------------------------------------------------------------------------------
/result_audios/bird_malicious.pt:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ebagdasa/multimodal_injection/be8f118c06d7783e4f458046947d1271e7b435d6/result_audios/bird_malicious.pt
--------------------------------------------------------------------------------
/result_audios/bird_malicious.wav:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ebagdasa/multimodal_injection/be8f118c06d7783e4f458046947d1271e7b435d6/result_audios/bird_malicious.wav
--------------------------------------------------------------------------------
/result_audios/panda-italy-baseline.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ebagdasa/multimodal_injection/be8f118c06d7783e4f458046947d1271e7b435d6/result_audios/panda-italy-baseline.png
--------------------------------------------------------------------------------
/result_audios/panda-italy.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ebagdasa/multimodal_injection/be8f118c06d7783e4f458046947d1271e7b435d6/result_audios/panda-italy.png
--------------------------------------------------------------------------------
/result_images/.DS_Store:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ebagdasa/multimodal_injection/be8f118c06d7783e4f458046947d1271e7b435d6/result_images/.DS_Store
--------------------------------------------------------------------------------
/result_images/llava-baby-baseline.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ebagdasa/multimodal_injection/be8f118c06d7783e4f458046947d1271e7b435d6/result_images/llava-baby-baseline.png
--------------------------------------------------------------------------------
/result_images/llava-crying-baby.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ebagdasa/multimodal_injection/be8f118c06d7783e4f458046947d1271e7b435d6/result_images/llava-crying-baby.png
--------------------------------------------------------------------------------
/result_images/llava-pirate.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ebagdasa/multimodal_injection/be8f118c06d7783e4f458046947d1271e7b435d6/result_images/llava-pirate.png
--------------------------------------------------------------------------------
/result_images/llava-potter.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ebagdasa/multimodal_injection/be8f118c06d7783e4f458046947d1271e7b435d6/result_images/llava-potter.png
--------------------------------------------------------------------------------
/result_images/llava/harrypotter_partial.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ebagdasa/multimodal_injection/be8f118c06d7783e4f458046947d1271e7b435d6/result_images/llava/harrypotter_partial.png
--------------------------------------------------------------------------------
/result_images/llava/harrypotter_partial.pt:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ebagdasa/multimodal_injection/be8f118c06d7783e4f458046947d1271e7b435d6/result_images/llava/harrypotter_partial.pt
--------------------------------------------------------------------------------
/result_images/llava/perturb_full_X.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ebagdasa/multimodal_injection/be8f118c06d7783e4f458046947d1271e7b435d6/result_images/llava/perturb_full_X.jpg
--------------------------------------------------------------------------------
/result_images/llava/perturb_full_X.pt:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ebagdasa/multimodal_injection/be8f118c06d7783e4f458046947d1271e7b435d6/result_images/llava/perturb_full_X.pt
--------------------------------------------------------------------------------
/result_images/llava/perturb_partial_X.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ebagdasa/multimodal_injection/be8f118c06d7783e4f458046947d1271e7b435d6/result_images/llava/perturb_partial_X.jpg
--------------------------------------------------------------------------------
/result_images/llava/perturb_partial_X.pt:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ebagdasa/multimodal_injection/be8f118c06d7783e4f458046947d1271e7b435d6/result_images/llava/perturb_partial_X.pt
--------------------------------------------------------------------------------
/result_images/panda-audio-phishing.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ebagdasa/multimodal_injection/be8f118c06d7783e4f458046947d1271e7b435d6/result_images/panda-audio-phishing.png
--------------------------------------------------------------------------------
/result_images/pandagpt/panda_cow_full.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ebagdasa/multimodal_injection/be8f118c06d7783e4f458046947d1271e7b435d6/result_images/pandagpt/panda_cow_full.jpg
--------------------------------------------------------------------------------
/result_images/pandagpt/panda_cow_full.pt:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ebagdasa/multimodal_injection/be8f118c06d7783e4f458046947d1271e7b435d6/result_images/pandagpt/panda_cow_full.pt
--------------------------------------------------------------------------------
/result_images/pandagpt/panda_cow_partial.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ebagdasa/multimodal_injection/be8f118c06d7783e4f458046947d1271e7b435d6/result_images/pandagpt/panda_cow_partial.jpg
--------------------------------------------------------------------------------
/result_images/pandagpt/panda_cow_partial.pt:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ebagdasa/multimodal_injection/be8f118c06d7783e4f458046947d1271e7b435d6/result_images/pandagpt/panda_cow_partial.pt
--------------------------------------------------------------------------------