├── .DS_Store ├── .gitignore ├── LICENSE ├── README.md ├── assets ├── .DS_Store ├── audios │ └── bird_audio.wav └── images │ ├── CT.png │ ├── bird_image.jpg │ └── waterview.jpg ├── header.py ├── llava ├── .DS_Store ├── __init__.py ├── constants.py ├── conversation.py ├── data │ ├── __init__.py │ ├── alpaca-converter.py │ ├── clean_sharegpt.py │ ├── inspect.py │ ├── optional_clean.py │ ├── pretty_json.py │ └── split_long_conversation.py ├── eval │ ├── .DS_Store │ ├── eval_gpt_review.py │ ├── eval_gpt_review_visual.py │ ├── eval_science_qa.py │ ├── eval_science_qa_gpt4.py │ ├── eval_science_qa_gpt4_requery.py │ ├── generate_webpage_data_from_table.py │ ├── model_qa.py │ ├── model_vqa.py │ ├── model_vqa_science.py │ ├── qa_baseline_gpt35.py │ ├── run_llava.py │ ├── summarize_gpt_review.py │ ├── table │ │ ├── answer │ │ │ ├── answer_alpaca-13b.jsonl │ │ │ ├── answer_bard.jsonl │ │ │ ├── answer_gpt35.jsonl │ │ │ ├── answer_llama-13b.jsonl │ │ │ └── answer_vicuna-13b.jsonl │ │ ├── caps_boxes_coco2014_val_80.jsonl │ │ ├── model.jsonl │ │ ├── prompt.jsonl │ │ ├── question.jsonl │ │ ├── results │ │ │ └── test_sqa_llava_13b_v0.json │ │ ├── review │ │ │ ├── review_alpaca-13b_vicuna-13b.jsonl │ │ │ ├── review_bard_vicuna-13b.jsonl │ │ │ ├── review_gpt35_vicuna-13b.jsonl │ │ │ └── review_llama-13b_vicuna-13b.jsonl │ │ ├── reviewer.jsonl │ │ └── rule.json │ └── webpage │ │ ├── figures │ │ ├── alpaca.png │ │ ├── bard.jpg │ │ ├── chatgpt.svg │ │ ├── llama.jpg │ │ ├── swords_FILL0_wght300_GRAD0_opsz48.svg │ │ └── vicuna.jpeg │ │ ├── index.html │ │ ├── script.js │ │ └── styles.css ├── llava_injection.py ├── model │ ├── .DS_Store │ ├── __init__.py │ ├── apply_delta.py │ ├── consolidate.py │ ├── llava.py │ ├── llava_mpt.py │ ├── make_delta.py │ ├── mpt │ │ ├── adapt_tokenizer.py │ │ ├── attention.py │ │ ├── blocks.py │ │ ├── configuration_mpt.py │ │ ├── hf_prefixlm_converter.py │ │ ├── meta_init_context.py │ │ ├── modeling_mpt.py │ │ ├── norm.py │ │ └── param_init_fns.py │ └── utils.py ├── pyproject.toml ├── serve │ ├── .DS_Store │ ├── __init__.py │ ├── cli.py │ ├── controller.py │ ├── examples │ │ ├── CornellTech.png │ │ ├── crying_boy.jpg │ │ ├── extreme_ironing.jpg │ │ ├── london.jpg │ │ ├── math.jpg │ │ ├── math.png │ │ ├── poo.jpg │ │ ├── porsche911.jpeg │ │ ├── porsche911.png │ │ ├── steak.jpg │ │ ├── thief.jpg │ │ └── waterview.jpg │ ├── gateway │ │ ├── README.md │ │ └── nginx.conf │ ├── gradio_css.py │ ├── gradio_patch.py │ ├── gradio_web_server.py │ ├── model_worker.py │ ├── register_worker.py │ └── test_message.py ├── train │ ├── llama_flash_attn_monkey_patch.py │ ├── llava_trainer.py │ ├── train.py │ └── train_mem.py └── utils.py ├── models └── .gitkeep ├── original_images ├── CT.png └── waterview.jpg ├── pandagpt ├── .DS_Store ├── .gitkeep ├── code │ ├── assets │ │ ├── audios │ │ │ ├── bird_audio.wav │ │ │ ├── car_audio.wav │ │ │ └── dog_audio.wav │ │ ├── images │ │ │ ├── bird_image.jpg │ │ │ ├── car_image.jpg │ │ │ └── dog_image.jpg │ │ ├── thermals │ │ │ ├── 190662.jpg │ │ │ └── 210009.jpg │ │ └── videos │ │ │ ├── a.mp4 │ │ │ └── world.mp4 │ ├── config │ │ ├── __init__.py │ │ ├── base.yaml │ │ └── openllama_peft.yaml │ ├── datasets │ │ ├── __init__.py │ │ ├── samplers.py │ │ └── sft_dataset.py │ ├── dsconfig │ │ └── openllama_peft_stage_1.json │ ├── header.py │ ├── model │ │ ├── ImageBind │ │ │ ├── CODE_OF_CONDUCT.md │ │ │ ├── CONTRIBUTING.md │ │ │ ├── LICENSE │ │ │ ├── README.md │ │ │ ├── __init__.py │ │ │ ├── bpe │ │ │ │ └── bpe_simple_vocab_16e6.txt.gz │ │ │ ├── data.py │ │ │ ├── 
model_card.md │ │ │ ├── models │ │ │ │ ├── __init__.py │ │ │ │ ├── helpers.py │ │ │ │ ├── imagebind_model.py │ │ │ │ ├── multimodal_preprocessors.py │ │ │ │ └── transformer.py │ │ │ └── requirements.txt │ │ ├── __init__.py │ │ ├── agent.py │ │ ├── modeling_llama.py │ │ └── openllama.py │ ├── scripts │ │ └── train.sh │ ├── train_sft.py │ └── web_demo.py ├── data │ └── empty.txt ├── pandagpt_injection.py ├── pretrained_ckpt │ ├── README.md │ ├── imagebind_ckpt │ │ └── empty.txt │ ├── pandagpt_ckpt │ │ ├── 13b │ │ │ └── empty.txt │ │ └── 7b │ │ │ └── empty.txt │ └── vicuna_ckpt │ │ ├── 13b_v0 │ │ └── empty.txt │ │ └── 7b_v0 │ │ └── empty.txt └── requirements.txt ├── result_audios ├── .gitkeep ├── bird_malicious.pt ├── bird_malicious.wav ├── panda-italy-baseline.png └── panda-italy.png ├── result_images ├── .DS_Store ├── llava-baby-baseline.png ├── llava-crying-baby.png ├── llava-pirate.png ├── llava-potter.png ├── llava │ ├── harrypotter_partial.png │ ├── harrypotter_partial.pt │ ├── perturb_full_X.jpg │ ├── perturb_full_X.pt │ ├── perturb_partial_X.jpg │ └── perturb_partial_X.pt ├── panda-audio-phishing.png └── pandagpt │ ├── panda_cow_full.jpg │ ├── panda_cow_full.pt │ ├── panda_cow_partial.jpg │ └── panda_cow_partial.pt ├── run_llava_injection.ipynb └── run_pandagpt_injection.ipynb /.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ebagdasa/multimodal_injection/be8f118c06d7783e4f458046947d1271e7b435d6/.DS_Store -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | share/python-wheels/ 24 | *.egg-info/ 25 | .installed.cfg 26 | *.egg 27 | MANIFEST 28 | 29 | # PyInstaller 30 | # Usually these files are written by a python script from a template 31 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 32 | *.manifest 33 | *.spec 34 | 35 | # Installer logs 36 | pip-log.txt 37 | pip-delete-this-directory.txt 38 | 39 | # Unit test / coverage reports 40 | htmlcov/ 41 | .tox/ 42 | .nox/ 43 | .coverage 44 | .coverage.* 45 | .cache 46 | nosetests.xml 47 | coverage.xml 48 | *.cover 49 | *.py,cover 50 | .hypothesis/ 51 | .pytest_cache/ 52 | cover/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | .pybuilder/ 76 | target/ 77 | 78 | # Jupyter Notebook 79 | .ipynb_checkpoints 80 | 81 | # IPython 82 | profile_default/ 83 | ipython_config.py 84 | 85 | # pyenv 86 | # For a library or package, you might want to ignore these files since the code is 87 | # intended to run in multiple environments; otherwise, check them in: 88 | # .python-version 89 | 90 | # pipenv 91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 
92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 94 | # install all needed dependencies. 95 | #Pipfile.lock 96 | 97 | # poetry 98 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. 99 | # This is especially recommended for binary packages to ensure reproducibility, and is more 100 | # commonly ignored for libraries. 101 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control 102 | #poetry.lock 103 | 104 | # pdm 105 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. 106 | #pdm.lock 107 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it 108 | # in version control. 109 | # https://pdm.fming.dev/#use-with-ide 110 | .pdm.toml 111 | 112 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm 113 | __pypackages__/ 114 | 115 | # Celery stuff 116 | celerybeat-schedule 117 | celerybeat.pid 118 | 119 | # SageMath parsed files 120 | *.sage.py 121 | 122 | # Environments 123 | .env 124 | .venv 125 | env/ 126 | venv/ 127 | ENV/ 128 | env.bak/ 129 | venv.bak/ 130 | 131 | # Spyder project settings 132 | .spyderproject 133 | .spyproject 134 | 135 | # Rope project settings 136 | .ropeproject 137 | 138 | # mkdocs documentation 139 | /site 140 | 141 | # mypy 142 | .mypy_cache/ 143 | .dmypy.json 144 | dmypy.json 145 | 146 | # Pyre type checker 147 | .pyre/ 148 | 149 | # pytype static type analyzer 150 | .pytype/ 151 | 152 | # Cython debug symbols 153 | cython_debug/ 154 | 155 | # PyCharm 156 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can 157 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore 158 | # and can be added to the global gitignore or merged into this file. For a more nuclear 159 | # option (not recommended) you can uncomment the following to ignore the entire idea folder. 160 | #.idea/ 161 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2023 Eugene Bagdasaryan 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
22 | -------------------------------------------------------------------------------- /assets/.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ebagdasa/multimodal_injection/be8f118c06d7783e4f458046947d1271e7b435d6/assets/.DS_Store -------------------------------------------------------------------------------- /assets/audios/bird_audio.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ebagdasa/multimodal_injection/be8f118c06d7783e4f458046947d1271e7b435d6/assets/audios/bird_audio.wav -------------------------------------------------------------------------------- /assets/images/CT.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ebagdasa/multimodal_injection/be8f118c06d7783e4f458046947d1271e7b435d6/assets/images/CT.png -------------------------------------------------------------------------------- /assets/images/bird_image.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ebagdasa/multimodal_injection/be8f118c06d7783e4f458046947d1271e7b435d6/assets/images/bird_image.jpg -------------------------------------------------------------------------------- /assets/images/waterview.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ebagdasa/multimodal_injection/be8f118c06d7783e4f458046947d1271e7b435d6/assets/images/waterview.jpg -------------------------------------------------------------------------------- /header.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import datetime 3 | import types 4 | import deepspeed 5 | from transformers.deepspeed import HfDeepSpeedConfig 6 | import transformers 7 | import numpy as np 8 | from collections import OrderedDict 9 | from torch.utils.data import Dataset, DataLoader 10 | from torch.nn.utils import clip_grad_norm_ 11 | from torch.cuda.amp import autocast, GradScaler 12 | from torch.nn import DataParallel 13 | from torch.optim import lr_scheduler 14 | import torch.optim as optim 15 | import torch.nn as nn 16 | import torch.nn.functional as F 17 | from tqdm import tqdm 18 | import os 19 | import re 20 | import math 21 | import random 22 | import json 23 | import time 24 | import logging 25 | from copy import deepcopy 26 | import ipdb 27 | import argparse 28 | import data 29 | from transformers import LlamaTokenizer, LlamaForCausalLM, LlamaConfig 30 | from torch.nn.utils.rnn import pad_sequence 31 | from peft import LoraConfig, TaskType, get_peft_model 32 | 33 | logging.getLogger("transformers").setLevel(logging.WARNING) 34 | logging.getLogger("transformers.tokenization_utils").setLevel(logging.ERROR) 35 | os.environ['TOKENIZERS_PARALLELISM'] = 'false' 36 | -------------------------------------------------------------------------------- /llava/.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ebagdasa/multimodal_injection/be8f118c06d7783e4f458046947d1271e7b435d6/llava/.DS_Store -------------------------------------------------------------------------------- /llava/__init__.py: -------------------------------------------------------------------------------- 1 | from .model import LlavaLlamaForCausalLM 2 | 
-------------------------------------------------------------------------------- /llava/constants.py: -------------------------------------------------------------------------------- 1 | CONTROLLER_HEART_BEAT_EXPIRATION = 30 2 | WORKER_HEART_BEAT_INTERVAL = 15 3 | 4 | LOGDIR = "." 5 | -------------------------------------------------------------------------------- /llava/data/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ebagdasa/multimodal_injection/be8f118c06d7783e4f458046947d1271e7b435d6/llava/data/__init__.py -------------------------------------------------------------------------------- /llava/data/alpaca-converter.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import json 3 | import pathlib 4 | 5 | # Prompt from stanford alpaca's training script 6 | PROMPT_DICT = { 7 | "prompt_input": ( 8 | "Below is an instruction that describes a task, paired with an input that provides further context. " 9 | "Write a response that appropriately completes the request.\n\n" 10 | "### Instruction:\n{instruction}\n\n### Input:\n{input}\n\n### Response:" 11 | ), 12 | "prompt_no_input": ( 13 | "Below is an instruction that describes a task. " 14 | "Write a response that appropriately completes the request.\n\n" 15 | "### Instruction:\n{instruction}\n\n### Response:" 16 | ), 17 | } 18 | 19 | 20 | def main(args): 21 | data_path = pathlib.Path(args.data_path) 22 | with data_path.open() as f: 23 | data = json.load(f) 24 | 25 | prompt_input, prompt_no_input = PROMPT_DICT["prompt_input"], PROMPT_DICT["prompt_no_input"] 26 | sources = [ 27 | prompt_input.format_map(example) if example.get("input", "") != "" else prompt_no_input.format_map(example) 28 | for example in data 29 | ] 30 | targets = [example['output'] for example in data] 31 | 32 | new_data = [] 33 | cnt = 1 34 | for s, t in zip(sources, targets): 35 | new_data.append({ 36 | 'id': str(cnt), 37 | 'conversations': [ 38 | { 39 | 'from': 'human', 40 | 'value': s, 41 | }, 42 | { 43 | 'from': 'gpt', 44 | 'value': t, 45 | } 46 | ] 47 | }) 48 | cnt += 1 49 | 50 | json.dump(new_data, open(args.output_path, 'w'), indent=2) 51 | 52 | if __name__ == '__main__': 53 | parser = argparse.ArgumentParser() 54 | parser.add_argument('--data_path', type=str, default='alpaca-data.json') 55 | parser.add_argument('--output_path', type=str, default='alpaca-data-conversation.json') 56 | args = parser.parse_args() 57 | main(args) 58 | 59 | -------------------------------------------------------------------------------- /llava/data/clean_sharegpt.py: -------------------------------------------------------------------------------- 1 | """ 2 | - Convert html to markdown with basic data cleaning. 3 | - Deduplication. 
4 | 5 | Usage: 6 | python3 -m fastchat.data.clean_sharegpt --in sharegpt_html.json --out sharegpt_clean.json 7 | """ 8 | import argparse 9 | from concurrent.futures import ProcessPoolExecutor 10 | import json 11 | import logging 12 | import re 13 | from typing import Dict, Union 14 | 15 | import bs4 16 | import markdownify # == 0.11.6 17 | from tqdm import tqdm 18 | 19 | 20 | div_pattern = re.compile("<div.*?>") 21 | span_pattern = re.compile("<span.*?>") 22 | code_lang_pattern = re.compile( 23 | "```\s*" + "(.*?)" + "(?:Copy code)+" + "(.+?)" + "\s*?```", re.DOTALL 24 | ) 25 | code_lang_format = "```\g<1>\n\g<2>\n```" 26 | regenerate_pattern = re.compile("\d+ / \d+") 27 | copy_chars_pattern = re.compile("Copy\d+ chars / \d+ words") 28 | copy_code_pattern = re.compile("```(.*?)Copy code\s*```") 29 | 30 | 31 | def reformat_code(val: str) -> str: 32 | # Input code format is: 33 | # ``` 34 | # $<language>Copy code$<exact_code_here> 35 | # 36 | # ``` 37 | # This function converts it into the correct markdown format 38 | return re.sub(code_lang_pattern, code_lang_format, val) 39 | 40 | 41 | def html_to_markdown(val: str) -> str: 42 | # Remove all <div>
. This is required to make indent work in code blocks. 43 | val = re.sub(div_pattern, "", val) 44 | # Remove all <span>. This is required to make underscores work in code blocks. 45 | val = re.sub(span_pattern, "", val) 46 | # Convert HTML to markdown 47 | val = markdownify.markdownify(val).strip() 48 | # Reformat code 49 | val = reformat_code(val) 50 | 51 | # Remove noisy "[number] / [number]" at the beginning 52 | noise = re.search(regenerate_pattern, val) 53 | if noise and noise.start() == 0: 54 | val = val[noise.end() :] 55 | # Remove noisy "Copy[number] chars / [number] words" 56 | val = re.sub(copy_chars_pattern, "", val) 57 | # Remove empty code block ```\nCopy code\n``` 58 | val = re.sub(copy_code_pattern, "", val) 59 | 60 | # Strip 61 | val = val.replace("\n\n\n", "\n").strip() 62 | 63 | return val 64 | 65 | 66 | def contain_blocked_words(val: str) -> bool: 67 | blocked_words = ["openai", "chatgpt"] 68 | for w in blocked_words: 69 | if w in val.lower(): 70 | return True 71 | return False 72 | 73 | 74 | def clean_html_one_sample(sample): 75 | roles = ["human", "gpt"] 76 | 77 | if len(sample["conversations"]) <= 1: 78 | return (sample, 1) 79 | 80 | # Adjust the offset for cases like https://sharegpt.com/c/VyaZlh4 81 | if sample["conversations"][0]["from"] != "human": 82 | sample["conversations"] = sample["conversations"][1:] 83 | if len(sample["conversations"]) <= 1: 84 | return (sample, 1) 85 | 86 | if sample["conversations"][-1]["from"] == "human": 87 | sample["conversations"] = sample["conversations"][:-1] 88 | if len(sample["conversations"]) <= 1: 89 | return (sample, 1) 90 | 91 | for i, c in enumerate(sample["conversations"]): 92 | if c["from"] != roles[i % 2]: 93 | return (sample, 2) 94 | 95 | if contain_blocked_words(c["value"]): 96 | return (sample, 3) 97 | 98 | try: 99 | new_val = html_to_markdown(c["value"]) 100 | except (bs4.builder.ParserRejectedMarkup, AssertionError): 101 | return (sample, 4) 102 | 103 | c["value"] = new_val 104 | 105 | return (sample, 0) 106 | 107 | 108 | def clean_html_all(content, begin, end): 109 | """ 110 | Clean the source html files.
111 | """ 112 | cnt_skip = 0 113 | cnt_blocked_words = 0 114 | cnt_wrong_format = 0 115 | cnt_parser_error = 0 116 | cnt_too_short = 0 117 | cnt_id_duplication = 0 118 | cnt_value_duplication = 0 119 | cnt_tag = 0 120 | 121 | content = content[begin:end] 122 | processed = [] 123 | with ProcessPoolExecutor() as executor: 124 | for result in tqdm( 125 | executor.map(clean_html_one_sample, content), total=len(content) 126 | ): 127 | processed.append(result) 128 | 129 | visited = {} 130 | new_content = [] 131 | for sample, error_code in tqdm(processed): 132 | cid = sample["id"] 133 | skipped = True 134 | 135 | if error_code != 0: 136 | if error_code == 1: 137 | print(f"id {cid} is too short") 138 | cnt_too_short += 1 139 | elif error_code == 2: 140 | print(f"id {cid} has a wrong format") 141 | cnt_wrong_format += 1 142 | elif error_code == 3: 143 | print(f"id {cid} contains blocked words") 144 | cnt_blocked_words += 1 145 | elif error_code == 4: 146 | print(f"id {cid} contains parser errors") 147 | cnt_parser_error += 1 148 | else: 149 | raise ValueError(f"Invalid error_code: {error_code}") 150 | elif cid in visited: 151 | print(f"id {cid} is an id duplication of {visited[cid]}") 152 | cnt_id_duplication += 1 153 | elif ( 154 | sample["conversations"][1]["value"], 155 | len(sample["conversations"]), 156 | ) in visited: 157 | key = (sample["conversations"][1]["value"], len(sample["conversations"])) 158 | print(f"id {cid} is a value duplication of {visited[key]}") 159 | cnt_value_duplication += 1 160 | else: 161 | key = (sample["conversations"][1]["value"], len(sample["conversations"])) 162 | visited[cid] = visited[key] = cid 163 | skipped = False 164 | 165 | if not skipped: 166 | new_content.append(sample) 167 | else: 168 | cnt_skip += 1 169 | 170 | print( 171 | f"total: {len(content)}, skip: {cnt_skip}, new: {len(new_content)}, " 172 | f"cnt_blocked_words: {cnt_blocked_words}, cnt_parser_error: {cnt_parser_error}, " 173 | f"cnt_wrong_format: {cnt_wrong_format}, " 174 | f"cnt_too_short: {cnt_too_short}, cnt_id_duplication: {cnt_id_duplication}, " 175 | f"cnt_value_duplication: {cnt_value_duplication}, " 176 | ) 177 | 178 | return new_content 179 | 180 | 181 | def main(args): 182 | content = json.load(open(args["in_file"], "r")) 183 | content = clean_html_all(content, args["begin"], args["end"]) 184 | json.dump(content, open(args["out_file"], "w"), indent=2) 185 | 186 | 187 | if __name__ == "__main__": 188 | parser = argparse.ArgumentParser() 189 | parser.add_argument("--in-file", type=str, required=True) 190 | parser.add_argument("--out-file", type=str, default="sharegpt_clean.json") 191 | parser.add_argument("--begin", type=int) 192 | parser.add_argument("--end", type=int) 193 | parser.add_argument("--debug", action="store_true") 194 | args = parser.parse_args() 195 | main(vars(args)) 196 | -------------------------------------------------------------------------------- /llava/data/inspect.py: -------------------------------------------------------------------------------- 1 | """ 2 | Usage: 3 | python3 -m fastchat.data.inspect --in sharegpt_20230322_clean_lang_split.json 4 | """ 5 | import argparse 6 | import json 7 | 8 | import tqdm 9 | 10 | 11 | if __name__ == "__main__": 12 | parser = argparse.ArgumentParser() 13 | parser.add_argument("--in-file", type=str, required=True) 14 | parser.add_argument("--begin", type=int) 15 | args = parser.parse_args() 16 | 17 | content = json.load(open(args.in_file, "r")) 18 | for sample in tqdm.tqdm(content[args.begin:]): 19 | print(f"id: {sample['id']}") 20 
| for conv in sample["conversations"]: 21 | print(conv["from"] + ": ") 22 | print(conv["value"]) 23 | input() 24 | -------------------------------------------------------------------------------- /llava/data/optional_clean.py: -------------------------------------------------------------------------------- 1 | """ 2 | Usage: 3 | python3 -m fastchat.data.optional_clean --lang en --reduce-rep --in sharegpt_clean.json --out output.json 4 | python3 -m fastchat.data.optional_clean --skip-lang en --reduce-rep --in sharegpt_clean.json --out output.json 5 | """ 6 | import argparse 7 | import json 8 | import re 9 | 10 | import polyglot 11 | from polyglot.detect import Detector 12 | import pycld2 13 | from tqdm import tqdm 14 | 15 | 16 | def skip(conv, args): 17 | # Remove certain languages 18 | if args.lang != "all" or args.skip_lang is not None: 19 | text = "\n".join([x["value"] for x in conv["conversations"]]) 20 | try: 21 | lang_code = Detector(text).language.code 22 | except (pycld2.error, polyglot.detect.base.UnknownLanguage): 23 | lang_code = "unknown" 24 | 25 | if args.lang != "all" and lang_code != args.lang: 26 | return True 27 | 28 | if lang_code == args.skip_lang: 29 | return True 30 | 31 | # Remove repetitive numbers 32 | if args.reduce_rep: 33 | for sentence in conv["conversations"]: 34 | val = sentence["value"] 35 | sub = re.search(r"(\d)\1{8}", val) 36 | if sub is not None: 37 | return True 38 | 39 | return False 40 | 41 | 42 | if __name__ == "__main__": 43 | parser = argparse.ArgumentParser() 44 | parser.add_argument("--in-file", type=str, required=True) 45 | parser.add_argument("--out-file", type=str, default="") 46 | parser.add_argument("--lang", type=str, default="all", 47 | choices=["all", "en"]) 48 | parser.add_argument("--skip-lang", type=str) 49 | # NOTE: Be careful about reduce_rep which may remove some good data. 
50 | # For example, addresses could have long consecutive 0's 51 | parser.add_argument("--reduce-rep", action="store_true") 52 | args = parser.parse_args() 53 | 54 | in_file = args.in_file 55 | out_file = args.out_file 56 | lang = args.lang 57 | skip_lang = args.skip_lang 58 | reduce_rep = args.reduce_rep 59 | assert (lang == "all" or skip_lang is None) 60 | 61 | if out_file == "": 62 | out_file = "sharegpt_clean" 63 | if lang != "all": 64 | out_file += "_" + lang 65 | if skip_lang is not None: 66 | out_file += "_skip_" + skip_lang 67 | if reduce_rep: 68 | out_file += "_reduce_rep" 69 | out_file += ".json" 70 | 71 | content = json.load(open(in_file, "r")) 72 | num_conv = len(content) 73 | 74 | new_content = [] 75 | for conv in tqdm(content): 76 | if not skip(conv, args): 77 | new_content.append(conv) 78 | 79 | print(f"return {len(new_content)} out of {len(content)}, start dump ...") 80 | json.dump(new_content, open(out_file, "w"), indent=2) 81 | -------------------------------------------------------------------------------- /llava/data/pretty_json.py: -------------------------------------------------------------------------------- 1 | """ 2 | Usage: 3 | python3 pretty_json.py --in in.json --out out.json 4 | """ 5 | 6 | import argparse 7 | import json 8 | 9 | 10 | if __name__ == "__main__": 11 | parser = argparse.ArgumentParser() 12 | parser.add_argument("--in-file", type=str, required=True) 13 | parser.add_argument("--out-file", type=str, required=True) 14 | args = parser.parse_args() 15 | 16 | with open(args.in_file, "r") as fin: 17 | data = json.load(fin) 18 | 19 | with open(args.out_file, "w") as fout: 20 | json.dump(data, fout, indent=2) 21 | -------------------------------------------------------------------------------- /llava/data/split_long_conversation.py: -------------------------------------------------------------------------------- 1 | """ 2 | Split long conversations based on certain max length. 3 | 4 | Usage: python3 -m fastchat.data.split_long_conversation \ 5 | --in sharegpt_clean.json \ 6 | --out sharegpt_split.json \ 7 | --model-name-or-path $ 8 | """ 9 | import argparse 10 | import json 11 | from typing import Dict, Sequence, Optional 12 | 13 | import transformers 14 | import tqdm 15 | 16 | from llava import conversation as conversation_lib 17 | 18 | DEFAULT_PAD_TOKEN = "[PAD]" 19 | BEGIN_SIGNAL = "### " 20 | END_SIGNAL = "\n" 21 | 22 | 23 | def split_sample(sample, start_idx, end_idx): 24 | # only ends in the bot because otherwise the last human part is useless. 
25 | end_speaker = sample["conversations"][end_idx]["from"] 26 | end_idx = end_idx + 1 if end_speaker != "human" else end_idx 27 | return { 28 | "id": sample["id"] + "_" + str(start_idx), 29 | "conversations": sample["conversations"][start_idx:end_idx] 30 | } 31 | 32 | 33 | def split_contents(content, begin, end, tokenizer, max_length): 34 | """ 35 | Keep the maximum round of conversations within the max token length constraint 36 | """ 37 | content = content[begin:end] 38 | new_content = [] 39 | 40 | for sample in tqdm.tqdm(content): 41 | tokenized_lens = [] 42 | 43 | for c in sample["conversations"]: 44 | from_str = c["from"] 45 | if from_str.lower() == "human": 46 | from_str = conversation_lib.default_conversation.roles[0] 47 | elif from_str.lower() == "gpt": 48 | from_str = conversation_lib.default_conversation.roles[1] 49 | else: 50 | from_str = 'unknown' 51 | 52 | sentence = (BEGIN_SIGNAL + from_str + ": " + c["value"] + 53 | END_SIGNAL) 54 | length = tokenizer(sentence, return_tensors="pt", padding="longest" 55 | ).input_ids.ne(tokenizer.pad_token_id).sum().item() 56 | tokenized_lens.append(length) 57 | 58 | num_tokens = 0 59 | start_idx = 0 60 | for idx, l in enumerate(tokenized_lens): 61 | # TODO: shall we also only starts from a specific speaker? 62 | if num_tokens + l > max_length: 63 | new_content.append(split_sample(sample, start_idx, idx)) 64 | start_idx = idx 65 | num_tokens = l 66 | else: 67 | num_tokens += l 68 | if idx == len(tokenized_lens) - 1: 69 | new_content.append(split_sample(sample, start_idx, idx)) 70 | 71 | print(f"total: {len(content)}, new: {len(new_content)}") 72 | return new_content 73 | 74 | 75 | def main(args): 76 | content = json.load(open(args.in_file, "r")) 77 | tokenizer = transformers.AutoTokenizer.from_pretrained( 78 | args.model_name_or_path, 79 | model_max_length=args.max_length, 80 | padding_side="right", 81 | use_fast=False, 82 | ) 83 | if tokenizer.pad_token is None: 84 | tokenizer.add_special_tokens(dict(pad_token=DEFAULT_PAD_TOKEN)) 85 | content = split_contents(content, args.begin, args.end, 86 | tokenizer, args.max_length) 87 | json.dump(content, open(args.out_file, "w"), indent=2) 88 | 89 | 90 | if __name__ == "__main__": 91 | parser = argparse.ArgumentParser() 92 | parser.add_argument("--in-file", type=str, required=True) 93 | parser.add_argument("--out-file", type=str, default="sharegpt_split.json") 94 | parser.add_argument("--begin", type=int) 95 | parser.add_argument("--end", type=int) 96 | parser.add_argument("--model-name-or-path", type=str, required=True) 97 | parser.add_argument("--max-length", type=int, default=2304) 98 | args = parser.parse_args() 99 | main(args) 100 | -------------------------------------------------------------------------------- /llava/eval/.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ebagdasa/multimodal_injection/be8f118c06d7783e4f458046947d1271e7b435d6/llava/eval/.DS_Store -------------------------------------------------------------------------------- /llava/eval/eval_gpt_review.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import json 3 | import os 4 | 5 | import openai 6 | import tqdm 7 | import ray 8 | import time 9 | 10 | @ray.remote(num_cpus=4) 11 | def get_eval(content: str, max_tokens: int): 12 | while True: 13 | try: 14 | response = openai.ChatCompletion.create( 15 | model='gpt-4', 16 | messages=[{ 17 | 'role': 'system', 18 | 'content': 'You are a helpful and 
precise assistant for checking the quality of the answer.' 19 | }, { 20 | 'role': 'user', 21 | 'content': content, 22 | }], 23 | temperature=0.2, # TODO: figure out which temperature is best for evaluation 24 | max_tokens=max_tokens, 25 | ) 26 | break 27 | except openai.error.RateLimitError: 28 | pass 29 | except Exception as e: 30 | print(e) 31 | time.sleep(1) 32 | 33 | print('success!') 34 | return response['choices'][0]['message']['content'] 35 | 36 | 37 | def parse_score(review): 38 | try: 39 | score_pair = review.split('\n')[0] 40 | score_pair = score_pair.replace(',', ' ') 41 | sp = score_pair.split(' ') 42 | if len(sp) == 2: 43 | return [float(sp[0]), float(sp[1])] 44 | else: 45 | print('error', review) 46 | return [-1, -1] 47 | except Exception as e: 48 | print(e) 49 | print('error', review) 50 | return [-1, -1] 51 | 52 | 53 | if __name__ == '__main__': 54 | parser = argparse.ArgumentParser(description='ChatGPT-based QA evaluation.') 55 | parser.add_argument('-q', '--question') 56 | # parser.add_argument('-a', '--answer') 57 | parser.add_argument('-a', '--answer-list', nargs='+', default=[]) 58 | parser.add_argument('-r', '--rule') 59 | parser.add_argument('-o', '--output') 60 | parser.add_argument('--max-tokens', type=int, default=1024, help='maximum number of tokens produced in the output') 61 | args = parser.parse_args() 62 | 63 | ray.init() 64 | 65 | f_q = open(os.path.expanduser(args.question)) 66 | f_ans1 = open(os.path.expanduser(args.answer_list[0])) 67 | f_ans2 = open(os.path.expanduser(args.answer_list[1])) 68 | rule_dict = json.load(open(os.path.expanduser(args.rule), 'r')) 69 | 70 | review_file = open(f'{args.output}', 'w') 71 | 72 | js_list = [] 73 | handles = [] 74 | idx = 0 75 | for ques_js, ans1_js, ans2_js in zip(f_q, f_ans1, f_ans2): 76 | # if idx == 1: 77 | # break 78 | 79 | ques = json.loads(ques_js) 80 | ans1 = json.loads(ans1_js) 81 | ans2 = json.loads(ans2_js) 82 | 83 | category = json.loads(ques_js)['category'] 84 | if category in rule_dict: 85 | rule = rule_dict[category] 86 | else: 87 | rule = rule_dict['default'] 88 | prompt = rule['prompt'] 89 | role = rule['role'] 90 | content = (f'[Question]\n{ques["text"]}\n\n' 91 | f'[{role} 1]\n{ans1["text"]}\n\n[End of {role} 1]\n\n' 92 | f'[{role} 2]\n{ans2["text"]}\n\n[End of {role} 2]\n\n' 93 | f'[System]\n{prompt}\n\n') 94 | js_list.append({ 95 | 'id': idx+1, 96 | 'question_id': ques['question_id'], 97 | 'answer1_id': ans1['answer_id'], 98 | 'answer2_id': ans2['answer_id'], 99 | 'category': category}) 100 | idx += 1 101 | handles.append(get_eval.remote(content, args.max_tokens)) 102 | # To avoid the rate limit set by OpenAI 103 | time.sleep(1) 104 | 105 | reviews = ray.get(handles) 106 | for idx, review in enumerate(reviews): 107 | scores = parse_score(review) 108 | js_list[idx]['content'] = review 109 | js_list[idx]['tuple'] = scores 110 | review_file.write(json.dumps(js_list[idx]) + '\n') 111 | review_file.close() 112 | -------------------------------------------------------------------------------- /llava/eval/eval_gpt_review_visual.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import json 3 | import os 4 | 5 | import openai 6 | import tqdm 7 | import ray 8 | import time 9 | 10 | @ray.remote(num_cpus=4) 11 | def get_eval(content: str, max_tokens: int): 12 | while True: 13 | try: 14 | response = openai.ChatCompletion.create( 15 | model='gpt-4', 16 | messages=[{ 17 | 'role': 'system', 18 | 'content': 'You are a helpful and precise assistant for 
checking the quality of the answer.' 19 | }, { 20 | 'role': 'user', 21 | 'content': content, 22 | }], 23 | temperature=0.2, # TODO: figure out which temperature is best for evaluation 24 | max_tokens=max_tokens, 25 | ) 26 | break 27 | except openai.error.RateLimitError: 28 | pass 29 | except Exception as e: 30 | print(e) 31 | time.sleep(1) 32 | 33 | print('success!') 34 | return response['choices'][0]['message']['content'] 35 | 36 | 37 | def parse_score(review): 38 | try: 39 | score_pair = review.split('\n')[0] 40 | score_pair = score_pair.replace(',', ' ') 41 | sp = score_pair.split(' ') 42 | if len(sp) == 2: 43 | return [float(sp[0]), float(sp[1])] 44 | else: 45 | print('error', review) 46 | return [-1, -1] 47 | except Exception as e: 48 | print(e) 49 | print('error', review) 50 | return [-1, -1] 51 | 52 | 53 | if __name__ == '__main__': 54 | parser = argparse.ArgumentParser(description='ChatGPT-based QA evaluation.') 55 | parser.add_argument('-q', '--question') 56 | parser.add_argument('-c', '--context') 57 | parser.add_argument('-a', '--answer-list', nargs='+', default=[]) 58 | parser.add_argument('-r', '--rule') 59 | parser.add_argument('-o', '--output') 60 | parser.add_argument('--max-tokens', type=int, default=1024, help='maximum number of tokens produced in the output') 61 | args = parser.parse_args() 62 | 63 | ray.init() 64 | 65 | f_q = open(os.path.expanduser(args.question)) 66 | f_ans1 = open(os.path.expanduser(args.answer_list[0])) 67 | f_ans2 = open(os.path.expanduser(args.answer_list[1])) 68 | rule_dict = json.load(open(os.path.expanduser(args.rule), 'r')) 69 | 70 | review_file = open(f'{args.output}', 'w') 71 | 72 | context_list = [json.loads(line) for line in open(os.path.expanduser(args.context))] 73 | image_to_context = {context['image']: context for context in context_list} 74 | 75 | js_list = [] 76 | handles = [] 77 | idx = 0 78 | for ques_js, ans1_js, ans2_js in zip(f_q, f_ans1, f_ans2): 79 | ques = json.loads(ques_js) 80 | ans1 = json.loads(ans1_js) 81 | ans2 = json.loads(ans2_js) 82 | 83 | inst = image_to_context[ques['image']] 84 | cap_str = '\n'.join(inst['captions']) 85 | box_str = '\n'.join([f'{instance["category"]}: {instance["bbox"]}' for instance in inst['instances']]) 86 | 87 | category = json.loads(ques_js)['category'] 88 | if category in rule_dict: 89 | rule = rule_dict[category] 90 | else: 91 | assert False, f"Visual QA category not found in rule file: {category}." 
92 | prompt = rule['prompt'] 93 | role = rule['role'] 94 | content = (f'[Context]\n{cap_str}\n\n{box_str}\n\n' 95 | f'[Question]\n{ques["text"]}\n\n' 96 | f'[{role} 1]\n{ans1["text"]}\n\n[End of {role} 1]\n\n' 97 | f'[{role} 2]\n{ans2["text"]}\n\n[End of {role} 2]\n\n' 98 | f'[System]\n{prompt}\n\n') 99 | js_list.append({ 100 | 'id': idx+1, 101 | 'question_id': ques['question_id'], 102 | 'answer1_id': ans1.get('answer_id', ans1['question_id']), 103 | 'answer2_id': ans2.get('answer_id', ans2['answer_id']), 104 | 'category': category}) 105 | idx += 1 106 | handles.append(get_eval.remote(content, args.max_tokens)) 107 | # To avoid the rate limit set by OpenAI 108 | time.sleep(1) 109 | 110 | reviews = ray.get(handles) 111 | for idx, review in enumerate(reviews): 112 | scores = parse_score(review) 113 | js_list[idx]['content'] = review 114 | js_list[idx]['tuple'] = scores 115 | review_file.write(json.dumps(js_list[idx]) + '\n') 116 | review_file.close() 117 | -------------------------------------------------------------------------------- /llava/eval/eval_science_qa.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import json 3 | import os 4 | import re 5 | import random 6 | 7 | 8 | def get_args(): 9 | parser = argparse.ArgumentParser() 10 | parser.add_argument('--base-dir', type=str) 11 | parser.add_argument('--result-file', type=str) 12 | parser.add_argument('--output-file', type=str) 13 | parser.add_argument('--output-result', type=str) 14 | parser.add_argument('--split', type=str, default='test') 15 | parser.add_argument('--options', type=list, default=["A", "B", "C", "D", "E"]) 16 | return parser.parse_args() 17 | 18 | 19 | def convert_caps(results): 20 | fakecaps = [] 21 | for result in results: 22 | image_id = result['question_id'] 23 | caption = result['text'] 24 | fakecaps.append({"image_id": int(image_id), "caption": caption}) 25 | return fakecaps 26 | 27 | 28 | def get_pred_idx(prediction, choices, options): 29 | """ 30 | Get the index (e.g. 2) from the prediction (e.g. 'C') 31 | """ 32 | if prediction in options[:len(choices)]: 33 | return options.index(prediction) 34 | else: 35 | return random.choice(range(len(choices))) 36 | 37 | 38 | if __name__ == "__main__": 39 | args = get_args() 40 | 41 | base_dir = args.base_dir 42 | split_indices = json.load(open(os.path.join(base_dir, "pid_splits.json")))[args.split] 43 | problems = json.load(open(os.path.join(base_dir, "problems.json"))) 44 | predictions = [json.loads(line) for line in open(args.result_file)] 45 | predictions = {pred['question_id']: pred for pred in predictions} 46 | split_problems = {idx: problems[idx] for idx in split_indices} 47 | 48 | results = {'correct': [], 'incorrect': []} 49 | sqa_results = {} 50 | sqa_results['acc'] = None 51 | sqa_results['correct'] = None 52 | sqa_results['count'] = None 53 | sqa_results['results'] = {} 54 | sqa_results['outputs'] = {} 55 | 56 | for prob_id, prob in split_problems.items(): 57 | if prob_id not in predictions: 58 | continue 59 | pred = predictions[prob_id] 60 | pred_text = pred['text'] 61 | 62 | pattern = re.compile(r'The answer is ([A-Z]).') 63 | res = pattern.findall(pred_text) 64 | if len(res) == 1: 65 | answer = res[0] # 'A', 'B', ... 
66 | else: 67 | answer = "FAILED" 68 | 69 | pred_idx = get_pred_idx(answer, prob['choices'], args.options) 70 | 71 | analysis = { 72 | 'question_id': prob_id, 73 | 'parsed_ans': answer, 74 | 'ground_truth': args.options[prob['answer']], 75 | 'question': pred['prompt'], 76 | 'pred': pred_text, 77 | 'is_multimodal': '' in pred['prompt'], 78 | } 79 | 80 | sqa_results['results'][prob_id] = get_pred_idx(answer, prob['choices'], args.options) 81 | sqa_results['outputs'][prob_id] = pred_text 82 | 83 | if pred_idx == prob['answer']: 84 | results['correct'].append(analysis) 85 | else: 86 | results['incorrect'].append(analysis) 87 | 88 | correct = len(results['correct']) 89 | total = len(results['correct']) + len(results['incorrect']) 90 | print(f'Total: {total}, Correct: {correct}, Accuracy: {correct / total * 100:.2f}%') 91 | 92 | sqa_results['acc'] = correct / total * 100 93 | sqa_results['correct'] = correct 94 | sqa_results['count'] = total 95 | 96 | with open(args.output_file, 'w') as f: 97 | json.dump(results, f, indent=2) 98 | with open(args.output_result, 'w') as f: 99 | json.dump(sqa_results, f, indent=2) 100 | -------------------------------------------------------------------------------- /llava/eval/eval_science_qa_gpt4.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import json 3 | import os 4 | import re 5 | import random 6 | from collections import defaultdict 7 | 8 | 9 | def get_args(): 10 | parser = argparse.ArgumentParser() 11 | parser.add_argument('--base-dir', type=str) 12 | parser.add_argument('--gpt4-result', type=str) 13 | parser.add_argument('--our-result', type=str) 14 | parser.add_argument('--split', type=str, default='test') 15 | parser.add_argument('--options', type=list, default=["A", "B", "C", "D", "E"]) 16 | return parser.parse_args() 17 | 18 | 19 | def convert_caps(results): 20 | fakecaps = [] 21 | for result in results: 22 | image_id = result['question_id'] 23 | caption = result['text'] 24 | fakecaps.append({"image_id": int(image_id), "caption": caption}) 25 | return fakecaps 26 | 27 | 28 | def get_pred_idx(prediction, choices, options): 29 | """ 30 | Get the index (e.g. 2) from the prediction (e.g. 'C') 31 | """ 32 | if prediction in options[:len(choices)]: 33 | return options.index(prediction) 34 | else: 35 | return random.choice(range(len(choices))) 36 | 37 | 38 | if __name__ == "__main__": 39 | args = get_args() 40 | 41 | base_dir = args.base_dir 42 | split_indices = json.load(open(os.path.join(base_dir, "pid_splits.json")))[args.split] 43 | problems = json.load(open(os.path.join(base_dir, "problems.json"))) 44 | our_predictions = [json.loads(line) for line in open(args.our_result)] 45 | our_predictions = {pred['question_id']: pred for pred in our_predictions} 46 | split_problems = {idx: problems[idx] for idx in split_indices} 47 | 48 | gpt4_predictions = json.load(open(args.gpt4_result))['outputs'] 49 | 50 | results = defaultdict(lambda: 0) 51 | 52 | for prob_id, prob in split_problems.items(): 53 | if prob_id not in our_predictions: 54 | continue 55 | if prob_id not in gpt4_predictions: 56 | continue 57 | our_pred = our_predictions[prob_id]['text'] 58 | gpt4_pred = gpt4_predictions[prob_id] 59 | 60 | pattern = re.compile(r'The answer is ([A-Z]).') 61 | our_res = pattern.findall(our_pred) 62 | if len(our_res) == 1: 63 | our_answer = our_res[0] # 'A', 'B', ... 
64 | else: 65 | our_answer = "FAILED" 66 | gpt4_res = pattern.findall(gpt4_pred) 67 | if len(gpt4_res) == 1: 68 | gpt4_answer = gpt4_res[0] # 'A', 'B', ... 69 | else: 70 | gpt4_answer = "FAILED" 71 | 72 | our_pred_idx = get_pred_idx(our_answer, prob['choices'], args.options) 73 | gpt4_pred_idx = get_pred_idx(gpt4_answer, prob['choices'], args.options) 74 | 75 | if gpt4_answer == 'FAILED': 76 | results['gpt4_failed'] += 1 77 | # continue 78 | gpt4_pred_idx = our_pred_idx 79 | # if our_pred_idx != prob['answer']: 80 | # print(our_predictions[prob_id]['prompt']) 81 | # print('-----------------') 82 | # print(f'LECTURE: {prob["lecture"]}') 83 | # print(f'SOLUTION: {prob["solution"]}') 84 | # print('=====================') 85 | else: 86 | # continue 87 | pass 88 | # gpt4_pred_idx = our_pred_idx 89 | 90 | if gpt4_pred_idx == prob['answer']: 91 | results['correct'] += 1 92 | else: 93 | results['incorrect'] += 1 94 | 95 | 96 | if gpt4_pred_idx == prob['answer'] or our_pred_idx == prob['answer']: 97 | results['correct_upperbound'] += 1 98 | 99 | correct = results['correct'] 100 | total = results['correct'] + results['incorrect'] 101 | print(f'Total: {total}, Correct: {correct}, Accuracy: {correct / total * 100:.2f}%') 102 | print(f'Total: {total}, Correct (upper): {results["correct_upperbound"]}, Accuracy: {results["correct_upperbound"] / total * 100:.2f}%') 103 | print(f'Total: {total}, GPT-4 NO-ANS (RANDOM): {results["gpt4_failed"]}, Percentage: {results["gpt4_failed"] / total * 100:.2f}%') 104 | 105 | -------------------------------------------------------------------------------- /llava/eval/eval_science_qa_gpt4_requery.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import json 3 | import os 4 | import re 5 | import random 6 | from collections import defaultdict 7 | 8 | 9 | def get_args(): 10 | parser = argparse.ArgumentParser() 11 | parser.add_argument('--base-dir', type=str) 12 | parser.add_argument('--gpt4-result', type=str) 13 | parser.add_argument('--requery-result', type=str) 14 | parser.add_argument('--our-result', type=str) 15 | parser.add_argument('--output-result', type=str) 16 | parser.add_argument('--split', type=str, default='test') 17 | parser.add_argument('--options', type=list, default=["A", "B", "C", "D", "E"]) 18 | return parser.parse_args() 19 | 20 | 21 | def convert_caps(results): 22 | fakecaps = [] 23 | for result in results: 24 | image_id = result['question_id'] 25 | caption = result['text'] 26 | fakecaps.append({"image_id": int(image_id), "caption": caption}) 27 | return fakecaps 28 | 29 | 30 | def get_pred_idx(prediction, choices, options): 31 | """ 32 | Get the index (e.g. 2) from the prediction (e.g. 
'C') 33 | """ 34 | if prediction in options[:len(choices)]: 35 | return options.index(prediction) 36 | else: 37 | return random.choice(range(len(choices))) 38 | 39 | 40 | if __name__ == "__main__": 41 | args = get_args() 42 | 43 | base_dir = args.base_dir 44 | split_indices = json.load(open(os.path.join(base_dir, "pid_splits.json")))[args.split] 45 | problems = json.load(open(os.path.join(base_dir, "problems.json"))) 46 | our_predictions = [json.loads(line) for line in open(args.our_result)] 47 | our_predictions = {pred['question_id']: pred for pred in our_predictions} 48 | split_problems = {idx: problems[idx] for idx in split_indices} 49 | 50 | requery_predictions = [json.loads(line) for line in open(args.requery_result)] 51 | requery_predictions = {pred['question_id']: pred for pred in requery_predictions} 52 | 53 | gpt4_predictions = json.load(open(args.gpt4_result))['outputs'] 54 | 55 | results = defaultdict(lambda: 0) 56 | 57 | sqa_results = {} 58 | sqa_results['acc'] = None 59 | sqa_results['correct'] = None 60 | sqa_results['count'] = None 61 | sqa_results['results'] = {} 62 | sqa_results['outputs'] = {} 63 | 64 | for prob_id, prob in split_problems.items(): 65 | if prob_id not in our_predictions: 66 | assert False 67 | if prob_id not in gpt4_predictions: 68 | assert False 69 | our_pred = our_predictions[prob_id]['text'] 70 | gpt4_pred = gpt4_predictions[prob_id] 71 | if prob_id not in requery_predictions: 72 | results['missing_requery'] += 1 73 | requery_pred = "MISSING" 74 | else: 75 | requery_pred = requery_predictions[prob_id]['text'] 76 | 77 | pattern = re.compile(r'The answer is ([A-Z]).') 78 | our_res = pattern.findall(our_pred) 79 | if len(our_res) == 1: 80 | our_answer = our_res[0] # 'A', 'B', ... 81 | else: 82 | our_answer = "FAILED" 83 | 84 | requery_res = pattern.findall(requery_pred) 85 | if len(requery_res) == 1: 86 | requery_answer = requery_res[0] # 'A', 'B', ... 87 | else: 88 | requery_answer = "FAILED" 89 | 90 | gpt4_res = pattern.findall(gpt4_pred) 91 | if len(gpt4_res) == 1: 92 | gpt4_answer = gpt4_res[0] # 'A', 'B', ... 
93 | else: 94 | gpt4_answer = "FAILED" 95 | 96 | our_pred_idx = get_pred_idx(our_answer, prob['choices'], args.options) 97 | gpt4_pred_idx = get_pred_idx(gpt4_answer, prob['choices'], args.options) 98 | requery_pred_idx = get_pred_idx(requery_answer, prob['choices'], args.options) 99 | 100 | results['total'] += 1 101 | 102 | if gpt4_answer == 'FAILED': 103 | results['gpt4_failed'] += 1 104 | if gpt4_pred_idx == prob['answer']: 105 | results['gpt4_correct'] += 1 106 | if our_pred_idx == prob['answer']: 107 | results['gpt4_ourvisual_correct'] += 1 108 | elif gpt4_pred_idx == prob['answer']: 109 | results['gpt4_correct'] += 1 110 | results['gpt4_ourvisual_correct'] += 1 111 | 112 | if our_pred_idx == prob['answer']: 113 | results['our_correct'] += 1 114 | 115 | if requery_answer == 'FAILED': 116 | sqa_results['results'][prob_id] = our_pred_idx 117 | if our_pred_idx == prob['answer']: 118 | results['requery_correct'] += 1 119 | else: 120 | sqa_results['results'][prob_id] = requery_pred_idx 121 | if requery_pred_idx == prob['answer']: 122 | results['requery_correct'] += 1 123 | else: 124 | print(f""" 125 | Question ({args.options[prob['answer']]}): {our_predictions[prob_id]['prompt']} 126 | Our ({our_answer}): {our_pred} 127 | GPT-4 ({gpt4_answer}): {gpt4_pred} 128 | Requery ({requery_answer}): {requery_pred} 129 | print("=====================================") 130 | """) 131 | 132 | if gpt4_pred_idx == prob['answer'] or our_pred_idx == prob['answer']: 133 | results['correct_upperbound'] += 1 134 | 135 | total = results['total'] 136 | print(f'Total: {total}, Our-Correct: {results["our_correct"]}, Accuracy: {results["our_correct"] / total * 100:.2f}%') 137 | print(f'Total: {total}, GPT-4-Correct: {results["gpt4_correct"]}, Accuracy: {results["gpt4_correct"] / total * 100:.2f}%') 138 | print(f'Total: {total}, GPT-4 NO-ANS (RANDOM): {results["gpt4_failed"]}, Percentage: {results["gpt4_failed"] / total * 100:.2f}%') 139 | print(f'Total: {total}, GPT-4-OursVisual-Correct: {results["gpt4_ourvisual_correct"]}, Accuracy: {results["gpt4_ourvisual_correct"] / total * 100:.2f}%') 140 | print(f'Total: {total}, Requery-Correct: {results["requery_correct"]}, Accuracy: {results["requery_correct"] / total * 100:.2f}%') 141 | print(f'Total: {total}, Correct upper: {results["correct_upperbound"]}, Accuracy: {results["correct_upperbound"] / total * 100:.2f}%') 142 | 143 | sqa_results['acc'] = results["requery_correct"] / total * 100 144 | sqa_results['correct'] = results["requery_correct"] 145 | sqa_results['count'] = total 146 | 147 | with open(args.output_result, 'w') as f: 148 | json.dump(sqa_results, f, indent=2) 149 | 150 | -------------------------------------------------------------------------------- /llava/eval/generate_webpage_data_from_table.py: -------------------------------------------------------------------------------- 1 | """Generate json file for webpage.""" 2 | import json 3 | import os 4 | import re 5 | 6 | # models = ['llama', 'alpaca', 'gpt35', 'bard'] 7 | models = ['vicuna'] 8 | 9 | 10 | def read_jsonl(path: str, key: str=None): 11 | data = [] 12 | with open(os.path.expanduser(path)) as f: 13 | for line in f: 14 | if not line: 15 | continue 16 | data.append(json.loads(line)) 17 | if key is not None: 18 | data.sort(key=lambda x: x[key]) 19 | data = {item[key]: item for item in data} 20 | return data 21 | 22 | 23 | def trim_hanging_lines(s: str, n: int) -> str: 24 | s = s.strip() 25 | for _ in range(n): 26 | s = s.split('\n', 1)[1].strip() 27 | return s 28 | 29 | 30 | if __name__ == 
'__main__': 31 | questions = read_jsonl('table/question.jsonl', key='question_id') 32 | 33 | # alpaca_answers = read_jsonl('table/answer/answer_alpaca-13b.jsonl', key='question_id') 34 | # bard_answers = read_jsonl('table/answer/answer_bard.jsonl', key='question_id') 35 | # gpt35_answers = read_jsonl('table/answer/answer_gpt35.jsonl', key='question_id') 36 | # llama_answers = read_jsonl('table/answer/answer_llama-13b.jsonl', key='question_id') 37 | vicuna_answers = read_jsonl('table/answer/answer_vicuna-13b.jsonl', key='question_id') 38 | ours_answers = read_jsonl('table/results/llama-13b-hf-alpaca.jsonl', key='question_id') 39 | 40 | review_vicuna = read_jsonl('table/review/review_vicuna-13b_llama-13b-hf-alpaca.jsonl', key='question_id') 41 | # review_alpaca = read_jsonl('table/review/review_alpaca-13b_vicuna-13b.jsonl', key='question_id') 42 | # review_bard = read_jsonl('table/review/review_bard_vicuna-13b.jsonl', key='question_id') 43 | # review_gpt35 = read_jsonl('table/review/review_gpt35_vicuna-13b.jsonl', key='question_id') 44 | # review_llama = read_jsonl('table/review/review_llama-13b_vicuna-13b.jsonl', key='question_id') 45 | 46 | records = [] 47 | for qid in questions.keys(): 48 | r = { 49 | 'id': qid, 50 | 'category': questions[qid]['category'], 51 | 'question': questions[qid]['text'], 52 | 'answers': { 53 | # 'alpaca': alpaca_answers[qid]['text'], 54 | # 'llama': llama_answers[qid]['text'], 55 | # 'bard': bard_answers[qid]['text'], 56 | # 'gpt35': gpt35_answers[qid]['text'], 57 | 'vicuna': vicuna_answers[qid]['text'], 58 | 'ours': ours_answers[qid]['text'], 59 | }, 60 | 'evaluations': { 61 | # 'alpaca': review_alpaca[qid]['text'], 62 | # 'llama': review_llama[qid]['text'], 63 | # 'bard': review_bard[qid]['text'], 64 | 'vicuna': review_vicuna[qid]['content'], 65 | # 'gpt35': review_gpt35[qid]['text'], 66 | }, 67 | 'scores': { 68 | 'vicuna': review_vicuna[qid]['tuple'], 69 | # 'alpaca': review_alpaca[qid]['score'], 70 | # 'llama': review_llama[qid]['score'], 71 | # 'bard': review_bard[qid]['score'], 72 | # 'gpt35': review_gpt35[qid]['score'], 73 | }, 74 | } 75 | 76 | # cleanup data 77 | cleaned_evals = {} 78 | for k, v in r['evaluations'].items(): 79 | v = v.strip() 80 | lines = v.split('\n') 81 | # trim the first line if it's a pair of numbers 82 | if re.match(r'\d+[, ]+\d+', lines[0]): 83 | lines = lines[1:] 84 | v = '\n'.join(lines) 85 | cleaned_evals[k] = v.replace('Assistant 1', "**Assistant 1**").replace('Assistant 2', '**Assistant 2**') 86 | 87 | r['evaluations'] = cleaned_evals 88 | records.append(r) 89 | 90 | # Reorder the records, this is optional 91 | for r in records: 92 | if r['id'] <= 20: 93 | r['id'] += 60 94 | else: 95 | r['id'] -= 20 96 | for r in records: 97 | if r['id'] <= 50: 98 | r['id'] += 10 99 | elif 50 < r['id'] <= 60: 100 | r['id'] -= 50 101 | for r in records: 102 | if r['id'] == 7: 103 | r['id'] = 1 104 | elif r['id'] < 7: 105 | r['id'] += 1 106 | 107 | records.sort(key=lambda x: x['id']) 108 | 109 | # Write to file 110 | with open('webpage/data.json', 'w') as f: 111 | json.dump({'questions': records, 'models': models}, f, indent=2) 112 | -------------------------------------------------------------------------------- /llava/eval/model_qa.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | from transformers import AutoTokenizer, AutoModelForCausalLM, StoppingCriteria 3 | import torch 4 | import os 5 | import json 6 | from tqdm import tqdm 7 | import shortuuid 8 | 9 | from llava.conversation 
import default_conversation 10 | from llava.utils import disable_torch_init 11 | 12 | 13 | # new stopping implementation 14 | class KeywordsStoppingCriteria(StoppingCriteria): 15 | def __init__(self, keywords, tokenizer, input_ids): 16 | self.keywords = keywords 17 | self.tokenizer = tokenizer 18 | self.start_len = None 19 | self.input_ids = input_ids 20 | 21 | def __call__(self, output_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool: 22 | if self.start_len is None: 23 | self.start_len = self.input_ids.shape[1] 24 | else: 25 | outputs = self.tokenizer.batch_decode(output_ids[:, self.start_len:], skip_special_tokens=True)[0] 26 | for keyword in self.keywords: 27 | if keyword in outputs: 28 | return True 29 | return False 30 | 31 | 32 | @torch.inference_mode() 33 | def eval_model(model_name, questions_file, answers_file): 34 | # Model 35 | disable_torch_init() 36 | model_name = os.path.expanduser(model_name) 37 | tokenizer = AutoTokenizer.from_pretrained(model_name) 38 | model = AutoModelForCausalLM.from_pretrained(model_name, 39 | torch_dtype=torch.float16).cuda() 40 | 41 | 42 | ques_file = open(os.path.expanduser(questions_file), "r") 43 | ans_file = open(os.path.expanduser(answers_file), "w") 44 | for i, line in enumerate(tqdm(ques_file)): 45 | idx = json.loads(line)["question_id"] 46 | qs = json.loads(line)["text"] 47 | cat = json.loads(line)["category"] 48 | conv = default_conversation.copy() 49 | conv.append_message(conv.roles[0], qs) 50 | prompt = conv.get_prompt() 51 | inputs = tokenizer([prompt]) 52 | input_ids = torch.as_tensor(inputs.input_ids).cuda() 53 | stopping_criteria = KeywordsStoppingCriteria([conv.sep], tokenizer, input_ids) 54 | output_ids = model.generate( 55 | input_ids, 56 | do_sample=True, 57 | temperature=0.7, 58 | max_new_tokens=1024, 59 | stopping_criteria=[stopping_criteria]) 60 | outputs = tokenizer.batch_decode(output_ids, skip_special_tokens=True)[0] 61 | try: 62 | index = outputs.index(conv.sep, len(prompt)) 63 | except ValueError: 64 | outputs += conv.sep 65 | index = outputs.index(conv.sep, len(prompt)) 66 | 67 | outputs = outputs[len(prompt) + len(conv.roles[1]) + 2:index].strip() 68 | ans_id = shortuuid.uuid() 69 | ans_file.write(json.dumps({"question_id": idx, 70 | "text": outputs, 71 | "answer_id": ans_id, 72 | "model_id": model_name, 73 | "metadata": {}}) + "\n") 74 | ans_file.flush() 75 | ans_file.close() 76 | 77 | if __name__ == "__main__": 78 | parser = argparse.ArgumentParser() 79 | parser.add_argument("--model-name", type=str, default="facebook/opt-350m") 80 | parser.add_argument("--question-file", type=str, default="tables/question.jsonl") 81 | parser.add_argument("--answers-file", type=str, default="answer.jsonl") 82 | args = parser.parse_args() 83 | 84 | eval_model(args.model_name, args.question_file, args.answers_file) 85 | -------------------------------------------------------------------------------- /llava/eval/qa_baseline_gpt35.py: -------------------------------------------------------------------------------- 1 | """Generate answers with GPT-3.5""" 2 | # Note: you need to be using OpenAI Python v0.27.0 for the code below to work 3 | import argparse 4 | import json 5 | import os 6 | import time 7 | import concurrent.futures 8 | 9 | import openai 10 | import tqdm 11 | import shortuuid 12 | 13 | MODEL = 'gpt-3.5-turbo' 14 | MODEL_ID = 'gpt-3.5-turbo:20230327' 15 | 16 | def get_answer(question_id: int, question: str, max_tokens: int): 17 | ans = { 18 | 'answer_id': shortuuid.uuid(), 19 | 'question_id': question_id, 20 | 
'model_id': MODEL_ID, 21 | } 22 | for _ in range(3): 23 | try: 24 | response = openai.ChatCompletion.create( 25 | model=MODEL, 26 | messages=[{ 27 | 'role': 'system', 28 | 'content': 'You are a helpful assistant.' 29 | }, { 30 | 'role': 'user', 31 | 'content': question, 32 | }], 33 | max_tokens=max_tokens, 34 | ) 35 | ans['text'] = response['choices'][0]['message']['content'] 36 | return ans 37 | except Exception as e: 38 | print('[ERROR]', e) 39 | ans['text'] = '#ERROR#' 40 | time.sleep(1) 41 | return ans 42 | 43 | 44 | if __name__ == '__main__': 45 | parser = argparse.ArgumentParser(description='ChatGPT answer generation.') 46 | parser.add_argument('-q', '--question') 47 | parser.add_argument('-o', '--output') 48 | parser.add_argument('--max-tokens', type=int, default=1024, help='maximum number of tokens produced in the output') 49 | args = parser.parse_args() 50 | 51 | questions_dict = {} 52 | with open(os.path.expanduser(args.question)) as f: 53 | for line in f: 54 | if not line: 55 | continue 56 | q = json.loads(line) 57 | questions_dict[q['question_id']] = q['text'] 58 | 59 | answers = [] 60 | 61 | with concurrent.futures.ThreadPoolExecutor(max_workers=32) as executor: 62 | futures = [] 63 | for qid, question in questions_dict.items(): 64 | future = executor.submit(get_answer, qid, question, args.max_tokens) 65 | futures.append(future) 66 | 67 | for future in tqdm.tqdm(concurrent.futures.as_completed(futures), total=len(futures)): 68 | answers.append(future.result()) 69 | 70 | answers.sort(key=lambda x: x['question_id']) 71 | 72 | with open(os.path.expanduser(args.output), 'w') as f: 73 | table = [json.dumps(ans) for ans in answers] 74 | f.write('\n'.join(table)) 75 | -------------------------------------------------------------------------------- /llava/eval/run_llava.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | from transformers import AutoTokenizer, AutoModelForCausalLM 3 | import torch 4 | import os 5 | from llava.conversation import conv_templates, SeparatorStyle 6 | from llava.utils import disable_torch_init 7 | from transformers import CLIPVisionModel, CLIPImageProcessor, StoppingCriteria 8 | from llava.model import * 9 | from llava.model.utils import KeywordsStoppingCriteria 10 | 11 | from PIL import Image 12 | 13 | import os 14 | import requests 15 | from PIL import Image 16 | from io import BytesIO 17 | 18 | 19 | DEFAULT_IMAGE_TOKEN = "" 20 | DEFAULT_IMAGE_PATCH_TOKEN = "" 21 | DEFAULT_IM_START_TOKEN = "" 22 | DEFAULT_IM_END_TOKEN = "" 23 | 24 | 25 | def load_image(image_file): 26 | if image_file.startswith('http') or image_file.startswith('https'): 27 | response = requests.get(image_file) 28 | image = Image.open(BytesIO(response.content)).convert('RGB') 29 | else: 30 | image = Image.open(image_file).convert('RGB') 31 | return image 32 | 33 | 34 | def eval_model(args): 35 | # Model 36 | disable_torch_init() 37 | model_name = os.path.expanduser(args.model_name) 38 | tokenizer = AutoTokenizer.from_pretrained(model_name) 39 | 40 | if "mpt" in model_name.lower(): 41 | model = LlavaMPTForCausalLM.from_pretrained(model_name, low_cpu_mem_usage=True, torch_dtype=torch.float16, use_cache=True).cuda() 42 | else: 43 | model = LlavaLlamaForCausalLM.from_pretrained(model_name, low_cpu_mem_usage=True, torch_dtype=torch.float16, use_cache=True).cuda() 44 | image_processor = CLIPImageProcessor.from_pretrained(model.config.mm_vision_tower, torch_dtype=torch.float16) 45 | 46 | mm_use_im_start_end = getattr(model.config, 
"mm_use_im_start_end", False) 47 | tokenizer.add_tokens([DEFAULT_IMAGE_PATCH_TOKEN], special_tokens=True) 48 | if mm_use_im_start_end: 49 | tokenizer.add_tokens([DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN], special_tokens=True) 50 | 51 | vision_tower = model.get_model().vision_tower[0] 52 | if vision_tower.device.type == 'meta': 53 | vision_tower = CLIPVisionModel.from_pretrained(vision_tower.config._name_or_path, torch_dtype=torch.float16, low_cpu_mem_usage=True).cuda() 54 | model.get_model().vision_tower[0] = vision_tower 55 | else: 56 | vision_tower.to(device='cuda', dtype=torch.float16) 57 | vision_config = vision_tower.config 58 | vision_config.im_patch_token = tokenizer.convert_tokens_to_ids([DEFAULT_IMAGE_PATCH_TOKEN])[0] 59 | vision_config.use_im_start_end = mm_use_im_start_end 60 | if mm_use_im_start_end: 61 | vision_config.im_start_token, vision_config.im_end_token = tokenizer.convert_tokens_to_ids([DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN]) 62 | image_token_len = (vision_config.image_size // vision_config.patch_size) ** 2 63 | 64 | qs = args.query 65 | if mm_use_im_start_end: 66 | qs = qs + '\n' + DEFAULT_IM_START_TOKEN + DEFAULT_IMAGE_PATCH_TOKEN * image_token_len + DEFAULT_IM_END_TOKEN 67 | else: 68 | qs = qs + '\n' + DEFAULT_IMAGE_PATCH_TOKEN * image_token_len 69 | 70 | if "v1" in model_name.lower(): 71 | conv_mode = "llava_v1" 72 | elif "mpt" in model_name.lower(): 73 | conv_mode = "mpt_multimodal" 74 | else: 75 | conv_mode = "multimodal" 76 | 77 | if args.conv_mode is not None and conv_mode != args.conv_mode: 78 | print('[WARNING] the auto inferred conversation mode is {}, while `--conv-mode` is {}, using {}'.format(conv_mode, args.conv_mode, args.conv_mode)) 79 | else: 80 | args.conv_mode = conv_mode 81 | 82 | conv = conv_templates[args.conv_mode].copy() 83 | conv.append_message(conv.roles[0], qs) 84 | conv.append_message(conv.roles[1], None) 85 | prompt = conv.get_prompt() 86 | inputs = tokenizer([prompt]) 87 | 88 | image = load_image(args.image_file) 89 | image_tensor = image_processor.preprocess(image, return_tensors='pt')['pixel_values'][0] 90 | 91 | input_ids = torch.as_tensor(inputs.input_ids).cuda() 92 | 93 | stop_str = conv.sep if conv.sep_style != SeparatorStyle.TWO else conv.sep2 94 | keywords = [stop_str] 95 | stopping_criteria = KeywordsStoppingCriteria(keywords, tokenizer, input_ids) 96 | 97 | with torch.inference_mode(): 98 | output_ids = model.generate( 99 | input_ids, 100 | images=image_tensor.unsqueeze(0).half().cuda(), 101 | do_sample=True, 102 | temperature=0.2, 103 | max_new_tokens=1024, 104 | stopping_criteria=[stopping_criteria]) 105 | 106 | input_token_len = input_ids.shape[1] 107 | n_diff_input_output = (input_ids != output_ids[:, :input_token_len]).sum().item() 108 | if n_diff_input_output > 0: 109 | print(f'[Warning] {n_diff_input_output} output_ids are not the same as the input_ids') 110 | outputs = tokenizer.batch_decode(output_ids[:, input_token_len:], skip_special_tokens=True)[0] 111 | outputs = outputs.strip() 112 | if outputs.endswith(stop_str): 113 | outputs = outputs[:-len(stop_str)] 114 | outputs = outputs.strip() 115 | print(outputs) 116 | 117 | if __name__ == "__main__": 118 | parser = argparse.ArgumentParser() 119 | parser.add_argument("--model-name", type=str, default="facebook/opt-350m") 120 | parser.add_argument("--image-file", type=str, required=True) 121 | parser.add_argument("--query", type=str, required=True) 122 | parser.add_argument("--conv-mode", type=str, default=None) 123 | args = parser.parse_args() 124 | 125 | 
eval_model(args) 126 | -------------------------------------------------------------------------------- /llava/eval/summarize_gpt_review.py: -------------------------------------------------------------------------------- 1 | import json 2 | import os 3 | from collections import defaultdict 4 | 5 | import numpy as np 6 | 7 | 8 | if __name__ == '__main__': 9 | base_dir = "vqa/reviews/coco2014_val80" 10 | review_files = [x for x in os.listdir(base_dir) if x.endswith('.jsonl') and x.startswith('gpt4_text')] 11 | 12 | for review_file in sorted(review_files): 13 | config = review_file.replace('gpt4_text_', '').replace('.jsonl', '') 14 | scores = defaultdict(list) 15 | print(f'GPT-4 vs. {config}') 16 | with open(os.path.join(base_dir, review_file)) as f: 17 | for review_str in f: 18 | review = json.loads(review_str) 19 | scores[review['category']].append(review['tuple']) 20 | scores['all'].append(review['tuple']) 21 | for k, v in scores.items(): 22 | stats = np.asarray(v).mean(0).tolist() 23 | stats = [round(x, 3) for x in stats] 24 | print(k, stats, round(stats[1]/stats[0]*100, 1)) 25 | print('=================================') 26 | 27 | -------------------------------------------------------------------------------- /llava/eval/table/model.jsonl: -------------------------------------------------------------------------------- 1 | {"model_id": "vicuna-13b:20230322-clean-lang", "model_name": "vicuna-13b", "model_version": "20230322-clean-lang", "model_metadata": "vicuna-13b-20230322-clean-lang"} 2 | {"model_id": "alpaca-13b:v1", "model_name": "alpaca-13b", "model_version": "v1", "model_metadata": "alpaca-13b"} 3 | {"model_id": "llama-13b:v1", "model_name": "llama-13b", "model_version": "v1", "model_metadata": "hf-llama-13b"} 4 | {"model_id": "bard:20230327", "model_name": "bard", "model_version": "20230327", "model_metadata": "Google Bard 20230327"} 5 | {"model_id": "gpt-3.5-turbo:20230327", "model_name": "gpt-3.5-turbo", "model_version": "20230327", "model_metadata": "OpenAI ChatGPT gpt-3.5-turbo Chat Completion"} 6 | -------------------------------------------------------------------------------- /llava/eval/table/prompt.jsonl: -------------------------------------------------------------------------------- 1 | {"prompt_id": 1, "system_prompt": "You are a helpful and precise assistant for checking the quality of the answer.", "prompt_template": "[Question]\n{question}\n\n[Assistant 1]\n{answer_1}\n\n[End of Assistant 1]\n\n[Assistant 2]\n{answer_2}\n\n[End of Assistant 2]\n\n[System]\n{prompt}\n\n", "defaults": {"prompt": "We would like to request your feedback on the performance of two AI assistants in response to the user question displayed above.\nPlease rate the helpfulness, relevance, accuracy, level of details of their responses. Each assistant receives an overall score on a scale of 1 to 10, where a higher score indicates better overall performance.\nPlease first output a single line containing only two values indicating the scores for Assistant 1 and 2, respectively. 
The two scores are separated by a space.\nIn the subsequent line, please provide a comprehensive explanation of your evaluation, avoiding any potential bias and ensuring that the order in which the responses were presented does not affect your judgment."}, "description": "Prompt for general questions"} 2 | {"prompt_id": 2, "system_prompt": "You are a helpful and precise assistant for checking the quality of the answer.", "prompt_template": "[Question]\n{question}\n\n[Assistant 1]\n{answer_1}\n\n[End of Assistant 1]\n\n[Assistant 2]\n{answer_2}\n\n[End of Assistant 2]\n\n[System]\n{prompt}\n\n", "defaults": {"prompt": "Your task is to evaluate the coding abilities of the above two assistants. They have been asked to implement a program to solve a given problem. Please review their code submissions, paying close attention to their problem-solving approach, code structure, readability, and the inclusion of helpful comments.\n\nPlease ensure that the assistants' submissions:\n\n1. Correctly implement the given problem statement.\n2. Contain accurate and efficient code.\n3. Include clear and concise comments that explain the code's logic and functionality.\n4. Adhere to proper coding standards and best practices.\n\nOnce you have carefully reviewed both submissions, provide detailed feedback on their strengths and weaknesses, along with any suggestions for improvement. You should first output a single line containing two scores on the scale of 1-10 (1: no code/no sense; 10: perfect) for Assistant 1 and 2, respectively. Then give extra comments starting from the next line."}, "description": "Prompt for coding questions"} 3 | {"prompt_id": 3, "system_prompt": "You are a helpful and precise assistant for checking the quality of the answer.", "prompt_template": "[Question]\n{question}\n\n[Assistant 1]\n{answer_1}\n\n[End of Assistant 1]\n\n[Assistant 2]\n{answer_2}\n\n[End of Assistant 2]\n\n[System]\n{prompt}\n\n", "defaults": {"prompt": "We would like to request your feedback on the mathematical proficiency of two AI assistants regarding the given user question.\nFirstly, please solve the problem independently, without referring to the answers provided by Assistant 1 and Assistant 2.\nAfterward, please examine the problem-solving process of Assistant 1 and Assistant 2 step-by-step to ensure their correctness, identifying any incorrect steps if present. Your evaluation should take into account not only the answer but also the problem-solving steps.\nFinally, please output a Python tuple containing two numerical scores for Assistant 1 and Assistant 2, ranging from 1 to 10, respectively. If applicable, explain the reasons for any variations in their scores and determine which assistant performed better."}, "description": "Prompt for math questions"} 4 | {"prompt_id": 4, "system_prompt": "You are a helpful and precise assistant for checking the quality of the answer.", "prompt_template": "[Visual Context]\n{context}\n[Question]\n{question}\n\n[Assistant 1]\n{answer_1}\n\n[End of Assistant 1]\n\n[Assistant 2]\n{answer_2}\n\n[End of Assistant 2]\n\n[System]\n{prompt}\n\n", "defaults": {"prompt": "We would like to request your feedback on the performance of two AI assistants in response to the user question displayed above. The user asks the question on observing an image. For your reference, the visual content in the image is represented with five descriptive sentences describing the same image and the bounding box coordinates of each object in the scene. 
These coordinates are in the form of bounding boxes, represented as (x1, y1, x2, y2) with floating numbers ranging from 0 to 1. These values correspond to the top left x, top left y, bottom right x, and bottom right y. \nPlease rate the helpfulness, relevance, accuracy, level of details of their responses. Each assistant receives an overall score on a scale of 1 to 10, where a higher score indicates better overall performance.\nPlease first output a single line containing only two values indicating the scores for Assistant 1 and 2, respectively. The two scores are separated by a space.\nIn the subsequent line, please provide a comprehensive explanation of your evaluation, avoiding any potential bias and ensuring that the order in which the responses were presented does not affect your judgment."}, "description": "Prompt for visual questions"} 5 | -------------------------------------------------------------------------------- /llava/eval/table/reviewer.jsonl: -------------------------------------------------------------------------------- 1 | {"reviewer_id": "gpt-4-0328-default", "prompt_id": 1, "metadata": {"temperature": 0.2, "max_tokens": 1024}, "description": "GPT-4 for general questions"} 2 | {"reviewer_id": "gpt-4-0328-coding", "prompt_id": 2, "metadata": {"temperature": 0.2, "max_tokens": 1024}, "description": "GPT-4 for coding questions"} 3 | {"reviewer_id": "gpt-4-0328-math", "prompt_id": 3, "metadata": {"temperature": 0.2, "max_tokens": 1024}, "description": "GPT-4 for math questions"} 4 | {"reviewer_id": "gpt-4-0417-visual", "prompt_id": 4, "metadata": {"temperature": 0.2, "max_tokens": 1024}, "description": "GPT-4 for math questions"} 5 | -------------------------------------------------------------------------------- /llava/eval/table/rule.json: -------------------------------------------------------------------------------- 1 | { 2 | "coding": {"role": "Assistant", "prompt": "Your task is to evaluate the coding abilities of the above two assistants. They have been asked to implement a program to solve a given problem. Please review their code submissions, paying close attention to their problem-solving approach, code structure, readability, and the inclusion of helpful comments.\n\nPlease ensure that the assistants' submissions:\n\n1. Correctly implement the given problem statement.\n2. Contain accurate and efficient code.\n3. Include clear and concise comments that explain the code's logic and functionality.\n4. Adhere to proper coding standards and best practices.\n\nOnce you have carefully reviewed both submissions, provide detailed feedback on their strengths and weaknesses, along with any suggestions for improvement. You should first output a single line containing two scores on the scale of 1-10 (1: no code/no sense; 10: perfect) for Assistant 1 and 2, respectively. Then give extra comments starting from the next line."}, 3 | "math": {"role": "Assistant", "prompt": "We would like to request your feedback on the mathematical proficiency of two AI assistants regarding the given user question.\nFirstly, please solve the problem independently, without referring to the answers provided by Assistant 1 and Assistant 2.\nAfterward, please examine the problem-solving process of Assistant 1 and Assistant 2 step-by-step to ensure their correctness, identifying any incorrect steps if present. 
Your evaluation should take into account not only the answer but also the problem-solving steps.\nFinally, please output a Python tuple containing two numerical scores for Assistant 1 and Assistant 2, ranging from 1 to 10, respectively. If applicable, explain the reasons for any variations in their scores and determine which assistant performed better."}, 4 | "default": {"role": "Assistant", "prompt": "We would like to request your feedback on the performance of two AI assistants in response to the user question displayed above.\nPlease rate the helpfulness, relevance, accuracy, level of details of their responses. Each assistant receives an overall score on a scale of 1 to 10, where a higher score indicates better overall performance.\nPlease first output a single line containing only two values indicating the scores for Assistant 1 and 2, respectively. The two scores are separated by a space.\nIn the subsequent line, please provide a comprehensive explanation of your evaluation, avoiding any potential bias and ensuring that the order in which the responses were presented does not affect your judgment."}, 5 | "conv": {"role": "Assistant", "prompt": "We would like to request your feedback on the performance of two AI assistants in response to the user question displayed above. The user asks the question on observing an image. For your reference, the visual content in the image is represented with five descriptive sentences describing the same image and the bounding box coordinates of each object in the scene. These coordinates are in the form of bounding boxes, represented as (x1, y1, x2, y2) with floating numbers ranging from 0 to 1. These values correspond to the top left x, top left y, bottom right x, and bottom right y. \nPlease rate the helpfulness, relevance, accuracy, level of details of their responses. Each assistant receives an overall score on a scale of 1 to 10, where a higher score indicates better overall performance.\nPlease first output a single line containing only two values indicating the scores for Assistant 1 and 2, respectively. The two scores are separated by a space.\nIn the subsequent line, please provide a comprehensive explanation of your evaluation, avoiding any potential bias and ensuring that the order in which the responses were presented does not affect your judgment."}, 6 | "detail": {"role": "Assistant", "prompt": "We would like to request your feedback on the performance of two AI assistants in response to the user question displayed above. The user asks the question on observing an image. For your reference, the visual content in the image is represented with five descriptive sentences describing the same image and the bounding box coordinates of each object in the scene. These coordinates are in the form of bounding boxes, represented as (x1, y1, x2, y2) with floating numbers ranging from 0 to 1. These values correspond to the top left x, top left y, bottom right x, and bottom right y. \nPlease rate the helpfulness, relevance, accuracy, level of details of their responses. Each assistant receives an overall score on a scale of 1 to 10, where a higher score indicates better overall performance.\nPlease first output a single line containing only two values indicating the scores for Assistant 1 and 2, respectively. 
The two scores are separated by a space.\nIn the subsequent line, please provide a comprehensive explanation of your evaluation, avoiding any potential bias and ensuring that the order in which the responses were presented does not affect your judgment."}, 7 | "complex": {"role": "Assistant", "prompt": "We would like to request your feedback on the performance of two AI assistants in response to the user question displayed above. The user asks the question on observing an image. For your reference, the visual content in the image is represented with five descriptive sentences describing the same image and the bounding box coordinates of each object in the scene. These coordinates are in the form of bounding boxes, represented as (x1, y1, x2, y2) with floating numbers ranging from 0 to 1. These values correspond to the top left x, top left y, bottom right x, and bottom right y. \nPlease rate the helpfulness, relevance, accuracy, level of details of their responses. Each assistant receives an overall score on a scale of 1 to 10, where a higher score indicates better overall performance.\nPlease first output a single line containing only two values indicating the scores for Assistant 1 and 2, respectively. The two scores are separated by a space.\nIn the subsequent line, please provide a comprehensive explanation of your evaluation, avoiding any potential bias and ensuring that the order in which the responses were presented does not affect your judgment."} 8 | } -------------------------------------------------------------------------------- /llava/eval/webpage/figures/alpaca.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ebagdasa/multimodal_injection/be8f118c06d7783e4f458046947d1271e7b435d6/llava/eval/webpage/figures/alpaca.png -------------------------------------------------------------------------------- /llava/eval/webpage/figures/bard.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ebagdasa/multimodal_injection/be8f118c06d7783e4f458046947d1271e7b435d6/llava/eval/webpage/figures/bard.jpg -------------------------------------------------------------------------------- /llava/eval/webpage/figures/chatgpt.svg: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /llava/eval/webpage/figures/llama.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ebagdasa/multimodal_injection/be8f118c06d7783e4f458046947d1271e7b435d6/llava/eval/webpage/figures/llama.jpg -------------------------------------------------------------------------------- /llava/eval/webpage/figures/swords_FILL0_wght300_GRAD0_opsz48.svg: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /llava/eval/webpage/figures/vicuna.jpeg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ebagdasa/multimodal_injection/be8f118c06d7783e4f458046947d1271e7b435d6/llava/eval/webpage/figures/vicuna.jpeg -------------------------------------------------------------------------------- /llava/eval/webpage/index.html: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | Who's GPT-4's favorite? 
Battles between State-of-the-Art Chatbots
[index.html: the HTML markup did not survive extraction. Surviving text content: the page title above, the image labels "other logo" and "vicuna logo", the panel headings "Assistant #2 (Vicuna, our model)" and "GPT-4 Evaluation", and the footer note "This website is co-authored with GPT-4."]
140 | 141 | 142 | 143 | 144 | 145 | 146 | 147 | 148 | 149 | 160 | 161 | 162 | 163 | -------------------------------------------------------------------------------- /llava/eval/webpage/styles.css: -------------------------------------------------------------------------------- 1 | body { 2 | font-family: 'Segoe UI', Tahoma, Geneva, Verdana, sans-serif; 3 | background-color: #f8f9fa; 4 | } 5 | 6 | .navbar-dark .navbar-nav .nav-link { 7 | color: #f1cf68; 8 | font-size: 1.1rem; 9 | padding: 0.5rem 0.6rem; 10 | } 11 | 12 | .card-header { 13 | font-weight: bold; 14 | } 15 | 16 | .card { 17 | box-shadow: 0 4px 8px rgba(0, 0, 0, 0.1); 18 | transition: 0.3s; 19 | } 20 | 21 | .card:hover { 22 | box-shadow: 0 8px 16px rgba(0, 0, 0, 0.2); 23 | } 24 | 25 | button { 26 | transition: background-color 0.3s; 27 | } 28 | 29 | button:hover { 30 | background-color: #007bff; 31 | } 32 | 33 | @media (max-width: 767px) { 34 | .form-row .form-group { 35 | margin-bottom: 10px; 36 | } 37 | } 38 | 39 | /* Extra styles */ 40 | 41 | .expandable-card .card-text-container { 42 | max-height: 200px; 43 | overflow-y: hidden; 44 | position: relative; 45 | } 46 | 47 | .expandable-card.expanded .card-text-container { 48 | max-height: none; 49 | } 50 | 51 | .expand-btn { 52 | position: relative; 53 | display: none; 54 | background-color: rgba(255, 255, 255, 0.8); 55 | color: #510c75; 56 | border-color: transparent; 57 | } 58 | 59 | .expand-btn:hover { 60 | background-color: rgba(200, 200, 200, 0.8); 61 | text-decoration: none; 62 | border-color: transparent; 63 | color: #510c75; 64 | } 65 | 66 | .expand-btn:focus { 67 | outline: none; 68 | text-decoration: none; 69 | } 70 | 71 | .expandable-card:not(.expanded) .card-text-container:after { 72 | content: ""; 73 | position: absolute; 74 | bottom: 0; 75 | left: 0; 76 | width: 100%; 77 | height: 90px; 78 | background: linear-gradient(rgba(255, 255, 255, 0.2), rgba(255, 255, 255, 1)); 79 | } 80 | 81 | .expandable-card:not(.expanded) .expand-btn { 82 | margin-top: -40px; 83 | } 84 | 85 | .card-body { 86 | padding-bottom: 5px; 87 | } 88 | 89 | .vertical-flex-layout { 90 | justify-content: center; 91 | align-items: center; 92 | height: 100%; 93 | display: flex; 94 | flex-direction: column; 95 | gap: 5px; 96 | } 97 | 98 | .figure-img { 99 | max-width: 100%; 100 | height: auto; 101 | } 102 | 103 | .adjustable-font-size { 104 | font-size: calc(0.5rem + 2vw); 105 | } 106 | -------------------------------------------------------------------------------- /llava/model/.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ebagdasa/multimodal_injection/be8f118c06d7783e4f458046947d1271e7b435d6/llava/model/.DS_Store -------------------------------------------------------------------------------- /llava/model/__init__.py: -------------------------------------------------------------------------------- 1 | from .llava import LlavaLlamaForCausalLM, LlavaConfig 2 | from .llava_mpt import LlavaMPTForCausalLM, LlavaMPTConfig 3 | -------------------------------------------------------------------------------- /llava/model/apply_delta.py: -------------------------------------------------------------------------------- 1 | """ 2 | Usage: 3 | python3 -m fastchat.model.apply_delta --base ~/model_weights/llama-7b --target ~/model_weights/vicuna-7b --delta lmsys/vicuna-7b-delta 4 | """ 5 | import argparse 6 | 7 | import torch 8 | from tqdm import tqdm 9 | from transformers import AutoTokenizer, AutoModelForCausalLM 10 | from llava import 
LlavaLlamaForCausalLM 11 | 12 | 13 | def apply_delta(base_model_path, target_model_path, delta_path): 14 | print("Loading base model") 15 | base = AutoModelForCausalLM.from_pretrained( 16 | base_model_path, torch_dtype=torch.float16, low_cpu_mem_usage=True) 17 | 18 | print("Loading delta") 19 | delta = LlavaLlamaForCausalLM.from_pretrained(delta_path, torch_dtype=torch.float16, low_cpu_mem_usage=True) 20 | delta_tokenizer = AutoTokenizer.from_pretrained(delta_path) 21 | 22 | print("Applying delta") 23 | for name, param in tqdm(delta.state_dict().items(), desc="Applying delta"): 24 | if name not in base.state_dict(): 25 | assert name in ['model.mm_projector.weight', 'model.mm_projector.bias'], f'{name} not in base model' 26 | continue 27 | if param.data.shape == base.state_dict()[name].shape: 28 | param.data += base.state_dict()[name] 29 | else: 30 | assert name in ['model.embed_tokens.weight', 'lm_head.weight'], \ 31 | f'{name} dimension mismatch: {param.data.shape} vs {base.state_dict()[name].shape}' 32 | bparam = base.state_dict()[name] 33 | param.data[:bparam.shape[0], :bparam.shape[1]] += bparam 34 | 35 | print("Saving target model") 36 | delta.save_pretrained(target_model_path) 37 | delta_tokenizer.save_pretrained(target_model_path) 38 | 39 | 40 | if __name__ == "__main__": 41 | parser = argparse.ArgumentParser() 42 | parser.add_argument("--base-model-path", type=str, required=True) 43 | parser.add_argument("--target-model-path", type=str, required=True) 44 | parser.add_argument("--delta-path", type=str, required=True) 45 | 46 | args = parser.parse_args() 47 | 48 | apply_delta(args.base_model_path, args.target_model_path, args.delta_path) 49 | -------------------------------------------------------------------------------- /llava/model/consolidate.py: -------------------------------------------------------------------------------- 1 | """ 2 | Usage: 3 | python3 -m llava.model.consolidate --src ~/model_weights/llava-7b --dst ~/model_weights/llava-7b_consolidate 4 | """ 5 | import argparse 6 | 7 | import torch 8 | from transformers import AutoTokenizer, AutoModelForCausalLM 9 | from llava.model import * 10 | from llava.model.utils import auto_upgrade 11 | 12 | 13 | def consolidate_ckpt(src_path, dst_path): 14 | print("Loading model") 15 | auto_upgrade(src_path) 16 | src_model = AutoModelForCausalLM.from_pretrained(src_path, torch_dtype=torch.float16, low_cpu_mem_usage=True) 17 | src_tokenizer = AutoTokenizer.from_pretrained(src_path) 18 | src_model.save_pretrained(dst_path) 19 | src_tokenizer.save_pretrained(dst_path) 20 | 21 | 22 | if __name__ == "__main__": 23 | parser = argparse.ArgumentParser() 24 | parser.add_argument("--src", type=str, required=True) 25 | parser.add_argument("--dst", type=str, required=True) 26 | 27 | args = parser.parse_args() 28 | 29 | consolidate_ckpt(args.src, args.dst) 30 | -------------------------------------------------------------------------------- /llava/model/make_delta.py: -------------------------------------------------------------------------------- 1 | """ 2 | Usage: 3 | python3 -m llava.model.make_delta --base ~/model_weights/llama-7b --target ~/model_weights/llava-7b --delta ~/model_weights/llava-7b-delta --hub-repo-id liuhaotian/llava-7b-delta 4 | """ 5 | import argparse 6 | 7 | import torch 8 | from tqdm import tqdm 9 | from transformers import AutoTokenizer, AutoModelForCausalLM 10 | from llava.model.utils import auto_upgrade 11 | 12 | 13 | def make_delta(base_model_path, target_model_path, delta_path, hub_repo_id): 14 | print("Loading base 
model") 15 | base = AutoModelForCausalLM.from_pretrained( 16 | base_model_path, torch_dtype=torch.float16, low_cpu_mem_usage=True) 17 | 18 | print("Loading target model") 19 | auto_upgrade(target_model_path) 20 | target = AutoModelForCausalLM.from_pretrained(target_model_path, torch_dtype=torch.float16, low_cpu_mem_usage=True) 21 | 22 | print("Calculating delta") 23 | for name, param in tqdm(target.state_dict().items(), desc="Calculating delta"): 24 | if name not in base.state_dict(): 25 | assert name in ['model.mm_projector.weight', 'model.mm_projector.bias'], f'{name} not in base model' 26 | continue 27 | if param.data.shape == base.state_dict()[name].shape: 28 | param.data -= base.state_dict()[name] 29 | else: 30 | assert name in ['model.embed_tokens.weight', 'lm_head.weight'], f'{name} dimension mismatch: {param.data.shape} vs {base.state_dict()[name].shape}' 31 | bparam = base.state_dict()[name] 32 | param.data[:bparam.shape[0], :bparam.shape[1]] -= bparam 33 | 34 | print("Saving delta") 35 | if hub_repo_id: 36 | kwargs = {"push_to_hub": True, "repo_id": hub_repo_id} 37 | else: 38 | kwargs = {} 39 | target.save_pretrained(delta_path, **kwargs) 40 | target_tokenizer = AutoTokenizer.from_pretrained(target_model_path) 41 | target_tokenizer.save_pretrained(delta_path, **kwargs) 42 | 43 | 44 | if __name__ == "__main__": 45 | parser = argparse.ArgumentParser() 46 | parser.add_argument("--base-model-path", type=str, required=True) 47 | parser.add_argument("--target-model-path", type=str, required=True) 48 | parser.add_argument("--delta-path", type=str, required=True) 49 | parser.add_argument("--hub-repo-id", type=str, default=None) 50 | args = parser.parse_args() 51 | 52 | make_delta(args.base_model_path, args.target_model_path, args.delta_path, args.hub_repo_id) 53 | -------------------------------------------------------------------------------- /llava/model/mpt/adapt_tokenizer.py: -------------------------------------------------------------------------------- 1 | from typing import Union 2 | from transformers import AutoTokenizer, PreTrainedTokenizer, PreTrainedTokenizerFast 3 | Tokenizer = Union[PreTrainedTokenizer, PreTrainedTokenizerFast] 4 | NUM_SENTINEL_TOKENS: int = 100 5 | 6 | def adapt_tokenizer_for_denoising(tokenizer: Tokenizer): 7 | """Adds sentinel tokens and padding token (if missing). 8 | 9 | Expands the tokenizer vocabulary to include sentinel tokens 10 | used in mixture-of-denoiser tasks as well as a padding token. 11 | 12 | All added tokens are added as special tokens. No tokens are 13 | added if sentinel tokens and padding token already exist. 14 | """ 15 | sentinels_to_add = [f'' for i in range(NUM_SENTINEL_TOKENS)] 16 | tokenizer.add_tokens(sentinels_to_add, special_tokens=True) 17 | if tokenizer.pad_token is None: 18 | tokenizer.add_tokens('', special_tokens=True) 19 | tokenizer.pad_token = '' 20 | assert tokenizer.pad_token_id is not None 21 | sentinels = ''.join([f'' for i in range(NUM_SENTINEL_TOKENS)]) 22 | _sentinel_token_ids = tokenizer(sentinels, add_special_tokens=False).input_ids 23 | tokenizer.sentinel_token_ids = _sentinel_token_ids 24 | 25 | class AutoTokenizerForMOD(AutoTokenizer): 26 | """AutoTokenizer + Adaptation for MOD. 27 | 28 | A simple wrapper around AutoTokenizer to make instantiating 29 | an MOD-adapted tokenizer a bit easier. 30 | 31 | MOD-adapted tokenizers have sentinel tokens (e.g., ), 32 | a padding token, and a property to get the token ids of the 33 | sentinel tokens. 
34 | """ 35 | 36 | @classmethod 37 | def from_pretrained(cls, *args, **kwargs): 38 | """See `AutoTokenizer.from_pretrained` docstring.""" 39 | tokenizer = super().from_pretrained(*args, **kwargs) 40 | adapt_tokenizer_for_denoising(tokenizer) 41 | return tokenizer -------------------------------------------------------------------------------- /llava/model/mpt/blocks.py: -------------------------------------------------------------------------------- 1 | """GPT Blocks used for the GPT Model.""" 2 | from typing import Dict, Optional, Tuple 3 | import torch 4 | import torch.nn as nn 5 | from .attention import ATTN_CLASS_REGISTRY 6 | from .norm import NORM_CLASS_REGISTRY 7 | 8 | class MPTMLP(nn.Module): 9 | 10 | def __init__(self, d_model: int, expansion_ratio: int, device: Optional[str]=None): 11 | super().__init__() 12 | self.up_proj = nn.Linear(d_model, expansion_ratio * d_model, device=device) 13 | self.act = nn.GELU(approximate='none') 14 | self.down_proj = nn.Linear(expansion_ratio * d_model, d_model, device=device) 15 | self.down_proj._is_residual = True 16 | 17 | def forward(self, x): 18 | return self.down_proj(self.act(self.up_proj(x))) 19 | 20 | class MPTBlock(nn.Module): 21 | 22 | def __init__(self, d_model: int, n_heads: int, expansion_ratio: int, attn_config: Dict={'attn_type': 'multihead_attention', 'attn_pdrop': 0.0, 'attn_impl': 'triton', 'qk_ln': False, 'clip_qkv': None, 'softmax_scale': None, 'prefix_lm': False, 'attn_uses_sequence_id': False, 'alibi': False, 'alibi_bias_max': 8}, resid_pdrop: float=0.0, norm_type: str='low_precision_layernorm', device: Optional[str]=None, **kwargs): 23 | del kwargs 24 | super().__init__() 25 | norm_class = NORM_CLASS_REGISTRY[norm_type.lower()] 26 | attn_class = ATTN_CLASS_REGISTRY[attn_config['attn_type']] 27 | self.norm_1 = norm_class(d_model, device=device) 28 | self.attn = attn_class(attn_impl=attn_config['attn_impl'], clip_qkv=attn_config['clip_qkv'], qk_ln=attn_config['qk_ln'], softmax_scale=attn_config['softmax_scale'], attn_pdrop=attn_config['attn_pdrop'], d_model=d_model, n_heads=n_heads, device=device) 29 | self.norm_2 = norm_class(d_model, device=device) 30 | self.ffn = MPTMLP(d_model=d_model, expansion_ratio=expansion_ratio, device=device) 31 | self.resid_attn_dropout = nn.Dropout(resid_pdrop) 32 | self.resid_ffn_dropout = nn.Dropout(resid_pdrop) 33 | 34 | def forward(self, x: torch.Tensor, past_key_value: Optional[Tuple[torch.Tensor]]=None, attn_bias: Optional[torch.Tensor]=None, attention_mask: Optional[torch.ByteTensor]=None, is_causal: bool=True) -> Tuple[torch.Tensor, Optional[Tuple[torch.Tensor]]]: 35 | a = self.norm_1(x) 36 | (b, _, past_key_value) = self.attn(a, past_key_value=past_key_value, attn_bias=attn_bias, attention_mask=attention_mask, is_causal=is_causal) 37 | x = x + self.resid_attn_dropout(b) 38 | m = self.norm_2(x) 39 | n = self.ffn(m) 40 | x = x + self.resid_ffn_dropout(n) 41 | return (x, past_key_value) -------------------------------------------------------------------------------- /llava/model/mpt/meta_init_context.py: -------------------------------------------------------------------------------- 1 | from contextlib import contextmanager 2 | import torch 3 | import torch.nn as nn 4 | 5 | @contextmanager 6 | def init_empty_weights(include_buffers: bool=False): 7 | """Meta initialization context manager. 8 | 9 | A context manager under which models are initialized with all parameters 10 | on the meta device, therefore creating an empty model. 
Useful when just 11 | initializing the model would blow the available RAM. 12 | 13 | Args: 14 | include_buffers (`bool`, *optional*, defaults to `False`): Whether or 15 | not to also put all buffers on the meta device while initializing. 16 | 17 | Example: 18 | ```python 19 | import torch.nn as nn 20 | 21 | # Initialize a model with 100 billions parameters in no time and without using any RAM. 22 | with init_empty_weights(): 23 | tst = nn.Sequential(*[nn.Linear(10000, 10000) for _ in range(1000)]) 24 | ``` 25 | 26 | 27 | 28 | Any model created under this context manager has no weights. As such you can't do something like 29 | `model.to(some_device)` with it. To load weights inside your empty model, see [`load_checkpoint_and_dispatch`]. 30 | 31 | 32 | """ 33 | with init_on_device(torch.device('meta'), include_buffers=include_buffers) as f: 34 | yield f 35 | 36 | @contextmanager 37 | def init_on_device(device: torch.device, include_buffers: bool=False): 38 | """Device initialization context manager. 39 | 40 | A context manager under which models are initialized with all parameters 41 | on the specified device. 42 | 43 | Args: 44 | device (`torch.device`): Device to initialize all parameters on. 45 | include_buffers (`bool`, *optional*, defaults to `False`): Whether or 46 | not to also put all buffers on the meta device while initializing. 47 | 48 | Example: 49 | ```python 50 | import torch.nn as nn 51 | 52 | with init_on_device(device=torch.device("cuda")): 53 | tst = nn.Liner(100, 100) # on `cuda` device 54 | ``` 55 | """ 56 | old_register_parameter = nn.Module.register_parameter 57 | if include_buffers: 58 | old_register_buffer = nn.Module.register_buffer 59 | 60 | def register_empty_parameter(module, name, param): 61 | old_register_parameter(module, name, param) 62 | if param is not None: 63 | param_cls = type(module._parameters[name]) 64 | kwargs = module._parameters[name].__dict__ 65 | module._parameters[name] = param_cls(module._parameters[name].to(device), **kwargs) 66 | 67 | def register_empty_buffer(module, name, buffer): 68 | old_register_buffer(module, name, buffer) 69 | if buffer is not None: 70 | module._buffers[name] = module._buffers[name].to(device) 71 | if include_buffers: 72 | tensor_constructors_to_patch = {torch_function_name: getattr(torch, torch_function_name) for torch_function_name in ['empty', 'zeros', 'ones', 'full']} 73 | else: 74 | tensor_constructors_to_patch = {} 75 | 76 | def patch_tensor_constructor(fn): 77 | 78 | def wrapper(*args, **kwargs): 79 | kwargs['device'] = device 80 | return fn(*args, **kwargs) 81 | return wrapper 82 | try: 83 | nn.Module.register_parameter = register_empty_parameter 84 | if include_buffers: 85 | nn.Module.register_buffer = register_empty_buffer 86 | for torch_function_name in tensor_constructors_to_patch.keys(): 87 | setattr(torch, torch_function_name, patch_tensor_constructor(getattr(torch, torch_function_name))) 88 | yield 89 | finally: 90 | nn.Module.register_parameter = old_register_parameter 91 | if include_buffers: 92 | nn.Module.register_buffer = old_register_buffer 93 | for (torch_function_name, old_torch_function) in tensor_constructors_to_patch.items(): 94 | setattr(torch, torch_function_name, old_torch_function) -------------------------------------------------------------------------------- /llava/model/mpt/norm.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | def _cast_if_autocast_enabled(tensor): 4 | if torch.is_autocast_enabled(): 5 | if tensor.device.type 
== 'cuda': 6 | dtype = torch.get_autocast_gpu_dtype() 7 | elif tensor.device.type == 'cpu': 8 | dtype = torch.get_autocast_cpu_dtype() 9 | else: 10 | raise NotImplementedError() 11 | return tensor.to(dtype=dtype) 12 | return tensor 13 | 14 | class LPLayerNorm(torch.nn.LayerNorm): 15 | 16 | def __init__(self, normalized_shape, eps=1e-05, elementwise_affine=True, device=None, dtype=None): 17 | super().__init__(normalized_shape=normalized_shape, eps=eps, elementwise_affine=elementwise_affine, device=device, dtype=dtype) 18 | 19 | def forward(self, x): 20 | module_device = x.device 21 | downcast_x = _cast_if_autocast_enabled(x) 22 | downcast_weight = _cast_if_autocast_enabled(self.weight) if self.weight is not None else self.weight 23 | downcast_bias = _cast_if_autocast_enabled(self.bias) if self.bias is not None else self.bias 24 | with torch.autocast(enabled=False, device_type=module_device.type): 25 | return torch.nn.functional.layer_norm(downcast_x, self.normalized_shape, downcast_weight, downcast_bias, self.eps) 26 | 27 | def rms_norm(x, weight=None, eps=1e-05): 28 | output = x / torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + eps) 29 | if weight is not None: 30 | return output * weight 31 | return output 32 | 33 | class RMSNorm(torch.nn.Module): 34 | 35 | def __init__(self, normalized_shape, eps=1e-05, weight=True, dtype=None, device=None): 36 | super().__init__() 37 | self.eps = eps 38 | if weight: 39 | self.weight = torch.nn.Parameter(torch.ones(normalized_shape, dtype=dtype, device=device)) 40 | else: 41 | self.register_parameter('weight', None) 42 | 43 | def forward(self, x): 44 | return rms_norm(x.float(), self.weight, self.eps).to(dtype=x.dtype) 45 | 46 | class LPRMSNorm(RMSNorm): 47 | 48 | def __init__(self, normalized_shape, eps=1e-05, weight=True, dtype=None, device=None): 49 | super().__init__(normalized_shape=normalized_shape, eps=eps, weight=weight, dtype=dtype, device=device) 50 | 51 | def forward(self, x): 52 | downcast_x = _cast_if_autocast_enabled(x) 53 | downcast_weight = _cast_if_autocast_enabled(self.weight) if self.weight is not None else self.weight 54 | with torch.autocast(enabled=False, device_type=x.device.type): 55 | return rms_norm(downcast_x, downcast_weight, self.eps).to(dtype=x.dtype) 56 | NORM_CLASS_REGISTRY = {'layernorm': torch.nn.LayerNorm, 'low_precision_layernorm': LPLayerNorm, 'rmsnorm': RMSNorm, 'low_precision_rmsnorm': LPRMSNorm} -------------------------------------------------------------------------------- /llava/model/utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from llava.model import * 3 | from transformers import AutoConfig, StoppingCriteria 4 | 5 | 6 | def auto_upgrade(config): 7 | cfg = AutoConfig.from_pretrained(config) 8 | if 'llava' in config and 'llava' not in cfg.model_type: 9 | assert cfg.model_type == 'llama' 10 | print("You are using newer LLaVA code base, while the checkpoint of v0 is from older code base.") 11 | print("You must upgrade the checkpoint to the new code base (this can be done automatically).") 12 | confirm = input("Please confirm that you want to upgrade the checkpoint. 
[Y/N]") 13 | if confirm.lower() in ["y", "yes"]: 14 | print("Upgrading checkpoint...") 15 | assert len(cfg.architectures) == 1 16 | setattr(cfg.__class__, "model_type", "llava") 17 | cfg.architectures[0] = 'LlavaLlamaForCausalLM' 18 | cfg.save_pretrained(config) 19 | print("Checkpoint upgraded.") 20 | else: 21 | print("Checkpoint upgrade aborted.") 22 | exit(1) 23 | 24 | 25 | 26 | class KeywordsStoppingCriteria(StoppingCriteria): 27 | def __init__(self, keywords, tokenizer, input_ids): 28 | self.keywords = keywords 29 | self.keyword_ids = [tokenizer(keyword).input_ids for keyword in keywords] 30 | self.keyword_ids = [keyword_id[0] for keyword_id in self.keyword_ids if type(keyword_id) is list and len(keyword_id) == 1] 31 | self.tokenizer = tokenizer 32 | self.start_len = None 33 | self.input_ids = input_ids 34 | 35 | def __call__(self, output_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool: 36 | if self.start_len is None: 37 | self.start_len = self.input_ids.shape[1] 38 | else: 39 | for keyword_id in self.keyword_ids: 40 | if output_ids[0, -1] == keyword_id: 41 | return True 42 | outputs = self.tokenizer.batch_decode(output_ids[:, self.start_len:], skip_special_tokens=True)[0] 43 | for keyword in self.keywords: 44 | if keyword in outputs: 45 | return True 46 | return False 47 | -------------------------------------------------------------------------------- /llava/pyproject.toml: -------------------------------------------------------------------------------- 1 | [build-system] 2 | requires = ["setuptools>=61.0"] 3 | build-backend = "setuptools.build_meta" 4 | 5 | [project] 6 | name = "llava" 7 | version = "1.0.1" 8 | description = "Towards GPT-4 like large language and visual assistant." 9 | readme = "README.md" 10 | requires-python = ">=3.8" 11 | classifiers = [ 12 | "Programming Language :: Python :: 3", 13 | "License :: OSI Approved :: Apache Software License", 14 | ] 15 | dependencies = [ 16 | "einops", "fastapi", "gradio==3.35.2", "markdown2[all]", "numpy", 17 | "requests", "sentencepiece", "tokenizers>=0.12.1", 18 | "torch", "torchvision", "uvicorn", "wandb", 19 | "shortuuid", "httpx==0.24.0", 20 | "deepspeed==0.9.5", 21 | "peft==0.4.0", 22 | "transformers==4.31.0", 23 | "accelerate==0.21.0", 24 | "bitsandbytes==0.41.0", 25 | "scikit-learn==1.2.2", 26 | "sentencepiece==0.1.99", 27 | "einops==0.6.1", "einops-exts==0.0.4", "timm==0.6.13", 28 | "gradio_client==0.2.9", 29 | "ipykernel" # for jupyter notebook 30 | ] 31 | 32 | [project.urls] 33 | "Homepage" = "https://llava-vl.github.io" 34 | "Bug Tracker" = "https://github.com/haotian-liu/LLaVA/issues" 35 | 36 | [tool.setuptools.packages.find] 37 | exclude = ["assets*", "benchmark*", "docs", "dist*", "playground*", "scripts*", "tests*"] 38 | 39 | [tool.wheel] 40 | exclude = ["assets*", "benchmark*", "docs", "dist*", "playground*", "scripts*", "tests*"] 41 | -------------------------------------------------------------------------------- /llava/serve/.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ebagdasa/multimodal_injection/be8f118c06d7783e4f458046947d1271e7b435d6/llava/serve/.DS_Store -------------------------------------------------------------------------------- /llava/serve/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ebagdasa/multimodal_injection/be8f118c06d7783e4f458046947d1271e7b435d6/llava/serve/__init__.py 
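A minimal usage sketch of the `KeywordsStoppingCriteria` helper defined in `llava/model/utils.py` above: it halts `generate()` as soon as any keyword string appears in the newly generated text. The prompt format and the `###` stop string below are illustrative placeholders, not values taken from this repository; the model name simply reuses the scripts' `facebook/opt-350m` default.

```python
# Hypothetical sketch (assumes the llava package is importable); it mirrors how the
# repo's own eval scripts pass the criteria to model.generate().
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
from llava.model.utils import KeywordsStoppingCriteria

model_name = "facebook/opt-350m"  # placeholder model, matching the scripts' default
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name)

prompt = "USER: Say hello.\nASSISTANT:"          # illustrative prompt format
input_ids = torch.as_tensor(tokenizer([prompt]).input_ids)

# Stop as soon as the separator "###" shows up in the decoded continuation.
criteria = KeywordsStoppingCriteria(["###"], tokenizer, input_ids)
output_ids = model.generate(
    input_ids,
    do_sample=False,
    max_new_tokens=64,
    stopping_criteria=[criteria],
)
print(tokenizer.batch_decode(output_ids[:, input_ids.shape[1]:],
                             skip_special_tokens=True)[0])
```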
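The `generate_stream` function in `llava/serve/cli.py` below is a plain Python generator: it decodes one token at a time with a key/value cache and, every `stream_interval` steps, yields the full decoded text so far, truncated at the stop string once it appears. A minimal, hypothetical way to drive it outside the interactive CLI loop; the prompt, stop string, and model name are placeholders, not values from this repository:

```python
# Hypothetical sketch; assumes llava.serve.cli imports cleanly in this environment.
from transformers import AutoTokenizer, AutoModelForCausalLM
from llava.serve.cli import generate_stream

model_name = "facebook/opt-350m"  # placeholder model, matching the scripts' default
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name)

params = {
    "prompt": "Human: Name three colors.\nAssistant:",  # illustrative prompt
    "temperature": 0.7,
    "max_new_tokens": 32,
    "stop": "\nHuman:",                                  # illustrative stop string
}

text = ""
for text in generate_stream(tokenizer, model, params, device="cpu"):
    pass  # each yielded value is the full decoded text so far (prompt included)

# Crude prompt stripping, mirroring the repo's own main() loop.
print(text[len(params["prompt"]):].strip())
```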
-------------------------------------------------------------------------------- /llava/serve/cli.py: -------------------------------------------------------------------------------- 1 | """ 2 | Usage: 3 | python3 -m fastchat.serve.cli --model ~/model_weights/llama-7b 4 | """ 5 | import argparse 6 | import time 7 | 8 | import torch 9 | from transformers import AutoTokenizer, AutoModelForCausalLM 10 | 11 | from llava.conversation import conv_templates, SeparatorStyle 12 | 13 | 14 | @torch.inference_mode() 15 | def generate_stream(tokenizer, model, params, device, 16 | context_len=2048, stream_interval=2): 17 | """Adapted from fastchat/serve/model_worker.py::generate_stream""" 18 | 19 | prompt = params["prompt"] 20 | l_prompt = len(prompt) 21 | temperature = float(params.get("temperature", 1.0)) 22 | max_new_tokens = int(params.get("max_new_tokens", 256)) 23 | stop_str = params.get("stop", None) 24 | 25 | input_ids = tokenizer(prompt).input_ids 26 | output_ids = list(input_ids) 27 | 28 | max_src_len = context_len - max_new_tokens - 8 29 | input_ids = input_ids[-max_src_len:] 30 | 31 | for i in range(max_new_tokens): 32 | if i == 0: 33 | out = model( 34 | torch.as_tensor([input_ids], device=device), use_cache=True) 35 | logits = out.logits 36 | past_key_values = out.past_key_values 37 | else: 38 | attention_mask = torch.ones( 39 | 1, past_key_values[0][0].shape[-2] + 1, device=device) 40 | out = model(input_ids=torch.as_tensor([[token]], device=device), 41 | use_cache=True, 42 | attention_mask=attention_mask, 43 | past_key_values=past_key_values) 44 | logits = out.logits 45 | past_key_values = out.past_key_values 46 | 47 | last_token_logits = logits[0][-1] 48 | if temperature < 1e-4: 49 | token = int(torch.argmax(last_token_logits)) 50 | else: 51 | probs = torch.softmax(last_token_logits / temperature, dim=-1) 52 | token = int(torch.multinomial(probs, num_samples=1)) 53 | 54 | output_ids.append(token) 55 | 56 | if token == tokenizer.eos_token_id: 57 | stopped = True 58 | else: 59 | stopped = False 60 | 61 | if i % stream_interval == 0 or i == max_new_tokens - 1 or stopped: 62 | output = tokenizer.decode(output_ids, skip_special_tokens=True) 63 | pos = output.rfind(stop_str, l_prompt) 64 | if pos != -1: 65 | output = output[:pos] 66 | stopped = True 67 | yield output 68 | 69 | if stopped: 70 | break 71 | 72 | del past_key_values 73 | 74 | 75 | def main(args): 76 | model_name = args.model_name 77 | num_gpus = args.num_gpus 78 | 79 | # Model 80 | if args.device == "cuda": 81 | kwargs = {"torch_dtype": torch.float16} 82 | if num_gpus == "auto": 83 | kwargs["device_map"] = "auto" 84 | else: 85 | num_gpus = int(num_gpus) 86 | if num_gpus != 1: 87 | kwargs.update({ 88 | "device_map": "auto", 89 | "max_memory": {i: "13GiB" for i in range(num_gpus)}, 90 | }) 91 | elif args.device == "cpu": 92 | kwargs = {} 93 | else: 94 | raise ValueError(f"Invalid device: {args.device}") 95 | 96 | tokenizer = AutoTokenizer.from_pretrained(model_name) 97 | model = AutoModelForCausalLM.from_pretrained(model_name, 98 | low_cpu_mem_usage=True, **kwargs) 99 | 100 | if args.device == "cuda" and num_gpus == 1: 101 | model.cuda() 102 | 103 | # Chat 104 | conv = conv_templates[args.conv_template].copy() 105 | while True: 106 | try: 107 | inp = input(f"{conv.roles[0]}: ") 108 | except EOFError: 109 | inp = "" 110 | if not inp: 111 | print("exit...") 112 | break 113 | 114 | conv.append_message(conv.roles[0], inp) 115 | conv.append_message(conv.roles[1], None) 116 | prompt = conv.get_prompt() 117 | 118 | params = { 119 | 
"model": model_name, 120 | "prompt": prompt, 121 | "temperature": args.temperature, 122 | "max_new_tokens": args.max_new_tokens, 123 | "stop": conv.sep if conv.sep_style == SeparatorStyle.SINGLE else conv.sep2, 124 | } 125 | 126 | print(f"{conv.roles[1]}: ", end="", flush=True) 127 | pre = 0 128 | for outputs in generate_stream(tokenizer, model, params, args.device): 129 | outputs = outputs[len(prompt) + 1:].strip() 130 | outputs = outputs.split(" ") 131 | now = len(outputs) 132 | if now - 1 > pre: 133 | print(" ".join(outputs[pre:now-1]), end=" ", flush=True) 134 | pre = now - 1 135 | print(" ".join(outputs[pre:]), flush=True) 136 | 137 | conv.messages[-1][-1] = " ".join(outputs) 138 | 139 | if args.debug: 140 | print("\n", {"prompt": prompt, "outputs": outputs}, "\n") 141 | 142 | 143 | if __name__ == "__main__": 144 | parser = argparse.ArgumentParser() 145 | parser.add_argument("--model-name", type=str, default="facebook/opt-350m") 146 | parser.add_argument("--num-gpus", type=str, default="1") 147 | parser.add_argument("--device", type=str, choices=["cuda", "cpu"], default="cuda") 148 | parser.add_argument("--conv-template", type=str, default="v1") 149 | parser.add_argument("--temperature", type=float, default=0.7) 150 | parser.add_argument("--max-new-tokens", type=int, default=512) 151 | parser.add_argument("--debug", action="store_true") 152 | args = parser.parse_args() 153 | main(args) 154 | -------------------------------------------------------------------------------- /llava/serve/examples/CornellTech.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ebagdasa/multimodal_injection/be8f118c06d7783e4f458046947d1271e7b435d6/llava/serve/examples/CornellTech.png -------------------------------------------------------------------------------- /llava/serve/examples/crying_boy.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ebagdasa/multimodal_injection/be8f118c06d7783e4f458046947d1271e7b435d6/llava/serve/examples/crying_boy.jpg -------------------------------------------------------------------------------- /llava/serve/examples/extreme_ironing.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ebagdasa/multimodal_injection/be8f118c06d7783e4f458046947d1271e7b435d6/llava/serve/examples/extreme_ironing.jpg -------------------------------------------------------------------------------- /llava/serve/examples/london.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ebagdasa/multimodal_injection/be8f118c06d7783e4f458046947d1271e7b435d6/llava/serve/examples/london.jpg -------------------------------------------------------------------------------- /llava/serve/examples/math.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ebagdasa/multimodal_injection/be8f118c06d7783e4f458046947d1271e7b435d6/llava/serve/examples/math.jpg -------------------------------------------------------------------------------- /llava/serve/examples/math.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ebagdasa/multimodal_injection/be8f118c06d7783e4f458046947d1271e7b435d6/llava/serve/examples/math.png -------------------------------------------------------------------------------- 
/llava/serve/examples/poo.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ebagdasa/multimodal_injection/be8f118c06d7783e4f458046947d1271e7b435d6/llava/serve/examples/poo.jpg -------------------------------------------------------------------------------- /llava/serve/examples/porsche911.jpeg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ebagdasa/multimodal_injection/be8f118c06d7783e4f458046947d1271e7b435d6/llava/serve/examples/porsche911.jpeg -------------------------------------------------------------------------------- /llava/serve/examples/porsche911.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ebagdasa/multimodal_injection/be8f118c06d7783e4f458046947d1271e7b435d6/llava/serve/examples/porsche911.png -------------------------------------------------------------------------------- /llava/serve/examples/steak.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ebagdasa/multimodal_injection/be8f118c06d7783e4f458046947d1271e7b435d6/llava/serve/examples/steak.jpg -------------------------------------------------------------------------------- /llava/serve/examples/thief.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ebagdasa/multimodal_injection/be8f118c06d7783e4f458046947d1271e7b435d6/llava/serve/examples/thief.jpg -------------------------------------------------------------------------------- /llava/serve/examples/waterview.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ebagdasa/multimodal_injection/be8f118c06d7783e4f458046947d1271e7b435d6/llava/serve/examples/waterview.jpg -------------------------------------------------------------------------------- /llava/serve/gateway/README.md: -------------------------------------------------------------------------------- 1 | # fastchat Nginx Gateway 2 | 3 | ## Purpose of the Gateway 4 | 5 | The Nginx gateway serves the following purposes: 6 | 7 | 1. Protects Gradio servers by acting as a firewall. 8 | 2. Facilitates dynamic mounting and unmounting of Gradio servers. 9 | 3. Provides load balancing for Gradio servers. 10 | 4. Offers additional security features, such as total connection limit. 11 | 5. Reduces attack surface by requiring only a single public port to be exposed for serving. 12 | 13 | ## Deployment and Updating of the Gateway 14 | 15 | ### Installing Nginx 16 | 17 | On Debian-based distributions (e.g., Ubuntu): 18 | 19 | ```bash 20 | sudo apt update 21 | sudo apt install nginx 22 | ``` 23 | On Red Hat-based distributions (e.g., CentOS, Fedora): 24 | 25 | ```bash 26 | sudo yum install epel-release 27 | sudo yum install nginx 28 | ``` 29 | 30 | ### Deployment 31 | 32 | Copy `nginx.conf` to `/etc/nginx/nginx.conf` (need sudo permission). 33 | 34 | Replace the port number 7860 in `server localhost:7860` with the port where you deploy the Gradio web server. 35 | 36 | Modify `upstream websocket` to configure Gradio servers behind the gateway. 37 | 38 | Lastly, update Nginx. 39 | 40 | 41 | ### HTTPS Deployment with a Public Domain URL 42 | 43 | Make sure you obtain the HTTPS certificate and the private key used to generate the certificate. 
44 | 45 | Fill the addresses to your certificate and private key in the `[PATH_TO_SSL_CERT]` and `[PATH_TO_PRIVATE_KEY]` fields. 46 | 47 | If you have your own domain url to serve the chatbot, replace the chat.lmsys.org url with your own domain url. 48 | 49 | ### Updating 50 | 51 | Every time when `/etc/nginx/nginx.conf` is modified, you need to update the Nginx service: 52 | 53 | ```bash 54 | sudo nginx -t # check `/etc/nginx/nginx.conf` 55 | sudo systemctl reload nginx # restart Nginx service to load the new config 56 | sudo systemctl status nginx # check the status of the Nginx service. It should be active (running). 57 | ``` 58 | -------------------------------------------------------------------------------- /llava/serve/gateway/nginx.conf: -------------------------------------------------------------------------------- 1 | user www-data; 2 | worker_processes auto; 3 | pid /run/nginx.pid; 4 | include /etc/nginx/modules-enabled/*.conf; 5 | 6 | events { 7 | worker_connections 1024; # maximum number of connections that a worker process can handle concurrently 8 | # multi_accept on; # enabling multi_accept can help improve performance under high load, but may increase the number of simultaneous connections that a worker process can handle 9 | 10 | } 11 | 12 | http { 13 | ## 14 | # Basic Settings 15 | ## 16 | 17 | sendfile on; # enable sendfile for performance optimization 18 | tcp_nopush on; # enable TCP no-pushing 19 | tcp_nodelay on; # enable TCP no-delay 20 | keepalive_timeout 65; # sets the timeout for keep-alive connections 21 | types_hash_max_size 2048; # maximum size of the types hash table 22 | # server_tokens off; # disable server token (i.e., server signature) in response headers to improve security 23 | 24 | # server_names_hash_bucket_size 64; 25 | # server_name_in_redirect off; 26 | 27 | include /etc/nginx/mime.types; # include MIME types file 28 | default_type application/octet-stream; # default MIME type for unknown file types 29 | 30 | ## 31 | # SSL Settings 32 | ## 33 | 34 | ssl_protocols TLSv1.2; # specify SSL/TLS protocols to use 35 | ssl_prefer_server_ciphers on; # prefer server ciphers over client ciphers 36 | 37 | ## 38 | # Logging Settings 39 | ## 40 | 41 | access_log /var/log/nginx/access.log; # path to access log file 42 | error_log /var/log/nginx/error.log; # path to error log file 43 | 44 | ## 45 | # Gzip Settings 46 | ## 47 | gzip on; # enable Gzip compression 48 | 49 | ## 50 | # Virtual Host Configs 51 | ## 52 | 53 | include /etc/nginx/conf.d/*.conf; # include all configuration files in conf.d directory 54 | include /etc/nginx/sites-enabled/*; # include all enabled sites configuration files 55 | 56 | # WebSocket Proxy: https://www.nginx.com/blog/websocket-nginx/ 57 | map $http_upgrade $connection_upgrade { 58 | default upgrade; 59 | '' close; 60 | } 61 | 62 | upstream websocket { 63 | ip_hash; # load balancing by IP to guarantee session persistence 64 | server localhost:7860; # The port should be the gradio web server port 65 | # server localhost:7861; # extra gradio server if more than one 66 | } 67 | 68 | limit_conn_status 429; 69 | limit_conn_zone $binary_remote_addr zone=perip:10m; # limit number of connections per IP 70 | limit_conn_zone $server_name zone=perserver:10m; # limit number of connections per server 71 | 72 | server { 73 | listen 443 ssl; # the listening port of our server 74 | ssl_certificate [PATH_TO_SSL_CERT]; 75 | ssl_certificate_key [PATH_TO_PRIVATE_KEY]; 76 | server_name chat.lmsys.org; # replace the url with your own domain url 77 | 
limit_conn perserver 1024; # connections per server 78 | location / { 79 | proxy_pass http://websocket; # proxy all requests to the defined upstream server 80 | limit_conn perip 5; # connections per IP 81 | proxy_set_header Host $host; # set the Host header for the upstream server 82 | proxy_set_header X-Real-IP $remote_addr; # set the client IP address as the real IP for the upstream server 83 | proxy_set_header X-Forwarded-For $proxy_add_x_forwarded_for; # set the client IP addresses in the X-Forwarded-For header 84 | proxy_http_version 1.1; # use HTTP version 1.1 for upstream communication 85 | proxy_set_header Upgrade $http_upgrade; 86 | proxy_set_header Connection "Upgrade"; # set the Connection header to Upgrade to enable WebSocket communication 87 | } 88 | } 89 | 90 | # the following block routes all HTTP traffic to HTTPS via nginx 91 | server { 92 | listen 80; 93 | server_name chat.lmsys.org; 94 | return 301 https://chat.lmsys.org$request_uri; 95 | } 96 | 97 | } 98 | -------------------------------------------------------------------------------- /llava/serve/gradio_css.py: -------------------------------------------------------------------------------- 1 | code_highlight_css = ( 2 | """ 3 | #chatbot .hll { background-color: #ffffcc } 4 | #chatbot .c { color: #408080; font-style: italic } 5 | #chatbot .err { border: 1px solid #FF0000 } 6 | #chatbot .k { color: #008000; font-weight: bold } 7 | #chatbot .o { color: #666666 } 8 | #chatbot .ch { color: #408080; font-style: italic } 9 | #chatbot .cm { color: #408080; font-style: italic } 10 | #chatbot .cp { color: #BC7A00 } 11 | #chatbot .cpf { color: #408080; font-style: italic } 12 | #chatbot .c1 { color: #408080; font-style: italic } 13 | #chatbot .cs { color: #408080; font-style: italic } 14 | #chatbot .gd { color: #A00000 } 15 | #chatbot .ge { font-style: italic } 16 | #chatbot .gr { color: #FF0000 } 17 | #chatbot .gh { color: #000080; font-weight: bold } 18 | #chatbot .gi { color: #00A000 } 19 | #chatbot .go { color: #888888 } 20 | #chatbot .gp { color: #000080; font-weight: bold } 21 | #chatbot .gs { font-weight: bold } 22 | #chatbot .gu { color: #800080; font-weight: bold } 23 | #chatbot .gt { color: #0044DD } 24 | #chatbot .kc { color: #008000; font-weight: bold } 25 | #chatbot .kd { color: #008000; font-weight: bold } 26 | #chatbot .kn { color: #008000; font-weight: bold } 27 | #chatbot .kp { color: #008000 } 28 | #chatbot .kr { color: #008000; font-weight: bold } 29 | #chatbot .kt { color: #B00040 } 30 | #chatbot .m { color: #666666 } 31 | #chatbot .s { color: #BA2121 } 32 | #chatbot .na { color: #7D9029 } 33 | #chatbot .nb { color: #008000 } 34 | #chatbot .nc { color: #0000FF; font-weight: bold } 35 | #chatbot .no { color: #880000 } 36 | #chatbot .nd { color: #AA22FF } 37 | #chatbot .ni { color: #999999; font-weight: bold } 38 | #chatbot .ne { color: #D2413A; font-weight: bold } 39 | #chatbot .nf { color: #0000FF } 40 | #chatbot .nl { color: #A0A000 } 41 | #chatbot .nn { color: #0000FF; font-weight: bold } 42 | #chatbot .nt { color: #008000; font-weight: bold } 43 | #chatbot .nv { color: #19177C } 44 | #chatbot .ow { color: #AA22FF; font-weight: bold } 45 | #chatbot .w { color: #bbbbbb } 46 | #chatbot .mb { color: #666666 } 47 | #chatbot .mf { color: #666666 } 48 | #chatbot .mh { color: #666666 } 49 | #chatbot .mi { color: #666666 } 50 | #chatbot .mo { color: #666666 } 51 | #chatbot .sa { color: #BA2121 } 52 | #chatbot .sb { color: #BA2121 } 53 | #chatbot .sc { color: #BA2121 } 54 | #chatbot .dl { color: #BA2121 } 55 | 
#chatbot .sd { color: #BA2121; font-style: italic } 56 | #chatbot .s2 { color: #BA2121 } 57 | #chatbot .se { color: #BB6622; font-weight: bold } 58 | #chatbot .sh { color: #BA2121 } 59 | #chatbot .si { color: #BB6688; font-weight: bold } 60 | #chatbot .sx { color: #008000 } 61 | #chatbot .sr { color: #BB6688 } 62 | #chatbot .s1 { color: #BA2121 } 63 | #chatbot .ss { color: #19177C } 64 | #chatbot .bp { color: #008000 } 65 | #chatbot .fm { color: #0000FF } 66 | #chatbot .vc { color: #19177C } 67 | #chatbot .vg { color: #19177C } 68 | #chatbot .vi { color: #19177C } 69 | #chatbot .vm { color: #19177C } 70 | #chatbot .il { color: #666666 } 71 | """) 72 | #.highlight { background: #f8f8f8; } 73 | 74 | -------------------------------------------------------------------------------- /llava/serve/gradio_patch.py: -------------------------------------------------------------------------------- 1 | """ 2 | Adopted from https://github.com/gradio-app/gradio/blob/main/gradio/components.py 3 | Fix a markdown render problem. 4 | """ 5 | from __future__ import annotations 6 | 7 | from gradio.components import * 8 | from markdown2 import Markdown 9 | 10 | 11 | class _Keywords(Enum): 12 | NO_VALUE = "NO_VALUE" # Used as a sentinel to determine if nothing is provided as a argument for `value` in `Component.update()` 13 | FINISHED_ITERATING = "FINISHED_ITERATING" # Used to skip processing of a component's value (needed for generators + state) 14 | 15 | 16 | @document("style") 17 | class Chatbot(Changeable, Selectable, IOComponent, JSONSerializable): 18 | """ 19 | Displays a chatbot output showing both user submitted messages and responses. Supports a subset of Markdown including bold, italics, code, and images. 20 | Preprocessing: this component does *not* accept input. 21 | Postprocessing: expects function to return a {List[Tuple[str | None | Tuple, str | None | Tuple]]}, a list of tuples with user message and response messages. Messages should be strings, tuples, or Nones. If the message is a string, it can include Markdown. If it is a tuple, it should consist of (string filepath to image/video/audio, [optional string alt text]). Messages that are `None` are not displayed. 22 | 23 | Demos: chatbot_simple, chatbot_multimodal 24 | """ 25 | 26 | def __init__( 27 | self, 28 | value: List[Tuple[str | None, str | None]] | Callable | None = None, 29 | color_map: Dict[str, str] | None = None, # Parameter moved to Chatbot.style() 30 | *, 31 | label: str | None = None, 32 | every: float | None = None, 33 | show_label: bool = True, 34 | visible: bool = True, 35 | elem_id: str | None = None, 36 | elem_classes: List[str] | str | None = None, 37 | **kwargs, 38 | ): 39 | """ 40 | Parameters: 41 | value: Default value to show in chatbot. If callable, the function will be called whenever the app loads to set the initial value of the component. 42 | label: component name in interface. 43 | every: If `value` is a callable, run the function 'every' number of seconds while the client connection is open. Has no effect otherwise. Queue must be enabled. The event can be accessed (e.g. to cancel it) via this component's .load_event attribute. 44 | show_label: if True, will display label. 45 | visible: If False, component will be hidden. 46 | elem_id: An optional string that is assigned as the id of this component in the HTML DOM. Can be used for targeting CSS styles. 47 | elem_classes: An optional list of strings that are assigned as the classes of this component in the HTML DOM. Can be used for targeting CSS styles. 
48 | """ 49 | if color_map is not None: 50 | warnings.warn( 51 | "The 'color_map' parameter has been deprecated.", 52 | ) 53 | #self.md = utils.get_markdown_parser() 54 | self.md = Markdown(extras=["fenced-code-blocks", "tables", "break-on-newline"]) 55 | self.select: EventListenerMethod 56 | """ 57 | Event listener for when the user selects message from Chatbot. 58 | Uses event data gradio.SelectData to carry `value` referring to text of selected message, and `index` tuple to refer to [message, participant] index. 59 | See EventData documentation on how to use this event data. 60 | """ 61 | 62 | IOComponent.__init__( 63 | self, 64 | label=label, 65 | every=every, 66 | show_label=show_label, 67 | visible=visible, 68 | elem_id=elem_id, 69 | elem_classes=elem_classes, 70 | value=value, 71 | **kwargs, 72 | ) 73 | 74 | def get_config(self): 75 | return { 76 | "value": self.value, 77 | "selectable": self.selectable, 78 | **IOComponent.get_config(self), 79 | } 80 | 81 | @staticmethod 82 | def update( 83 | value: Any | Literal[_Keywords.NO_VALUE] | None = _Keywords.NO_VALUE, 84 | label: str | None = None, 85 | show_label: bool | None = None, 86 | visible: bool | None = None, 87 | ): 88 | updated_config = { 89 | "label": label, 90 | "show_label": show_label, 91 | "visible": visible, 92 | "value": value, 93 | "__type__": "update", 94 | } 95 | return updated_config 96 | 97 | def _process_chat_messages( 98 | self, chat_message: str | Tuple | List | Dict | None 99 | ) -> str | Dict | None: 100 | if chat_message is None: 101 | return None 102 | elif isinstance(chat_message, (tuple, list)): 103 | mime_type = processing_utils.get_mimetype(chat_message[0]) 104 | return { 105 | "name": chat_message[0], 106 | "mime_type": mime_type, 107 | "alt_text": chat_message[1] if len(chat_message) > 1 else None, 108 | "data": None, # These last two fields are filled in by the frontend 109 | "is_file": True, 110 | } 111 | elif isinstance( 112 | chat_message, dict 113 | ): # This happens for previously processed messages 114 | return chat_message 115 | elif isinstance(chat_message, str): 116 | #return self.md.render(chat_message) 117 | return str(self.md.convert(chat_message)) 118 | else: 119 | raise ValueError(f"Invalid message for Chatbot component: {chat_message}") 120 | 121 | def postprocess( 122 | self, 123 | y: List[ 124 | Tuple[str | Tuple | List | Dict | None, str | Tuple | List | Dict | None] 125 | ], 126 | ) -> List[Tuple[str | Dict | None, str | Dict | None]]: 127 | """ 128 | Parameters: 129 | y: List of tuples representing the message and response pairs. Each message and response should be a string, which may be in Markdown format. It can also be a tuple whose first element is a string filepath or URL to an image/video/audio, and second (optional) element is the alt text, in which case the media file is displayed. It can also be None, in which case that message is not displayed. 130 | Returns: 131 | List of tuples representing the message and response. Each message and response will be a string of HTML, or a dictionary with media information. 132 | """ 133 | if y is None: 134 | return [] 135 | processed_messages = [] 136 | for message_pair in y: 137 | assert isinstance( 138 | message_pair, (tuple, list) 139 | ), f"Expected a list of lists or list of tuples. Received: {message_pair}" 140 | assert ( 141 | len(message_pair) == 2 142 | ), f"Expected a list of lists of length 2 or list of tuples of length 2. 
Received: {message_pair}" 143 | processed_messages.append( 144 | ( 145 | #self._process_chat_messages(message_pair[0]), 146 | '
<pre style="font-family: var(--font-mono)">' +
147 |                     message_pair[0] + "</pre>
", 148 | self._process_chat_messages(message_pair[1]), 149 | ) 150 | ) 151 | return processed_messages 152 | 153 | def style(self, height: int | None = None, **kwargs): 154 | """ 155 | This method can be used to change the appearance of the Chatbot component. 156 | """ 157 | if height is not None: 158 | self._style["height"] = height 159 | if kwargs.get("color_map") is not None: 160 | warnings.warn("The 'color_map' parameter has been deprecated.") 161 | 162 | Component.style( 163 | self, 164 | **kwargs, 165 | ) 166 | return self 167 | 168 | 169 | -------------------------------------------------------------------------------- /llava/serve/register_worker.py: -------------------------------------------------------------------------------- 1 | """ 2 | Manually register workers. 3 | 4 | Usage: 5 | python3 -m fastchat.serve.register_worker --controller http://localhost:21001 --worker-name http://localhost:21002 6 | """ 7 | 8 | import argparse 9 | 10 | import requests 11 | 12 | if __name__ == "__main__": 13 | parser = argparse.ArgumentParser() 14 | parser.add_argument("--controller-address", type=str) 15 | parser.add_argument("--worker-name", type=str) 16 | parser.add_argument("--check-heart-beat", action="store_true") 17 | args = parser.parse_args() 18 | 19 | url = args.controller_address + "/register_worker" 20 | data = { 21 | "worker_name": args.worker_name, 22 | "check_heart_beat": args.check_heart_beat, 23 | "worker_status": None, 24 | } 25 | r = requests.post(url, json=data) 26 | assert r.status_code == 200 27 | -------------------------------------------------------------------------------- /llava/serve/test_message.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import json 3 | 4 | import requests 5 | 6 | from llava.conversation import default_conversation 7 | 8 | 9 | def main(): 10 | if args.worker_address: 11 | worker_addr = args.worker_address 12 | else: 13 | controller_addr = args.controller_address 14 | ret = requests.post(controller_addr + "/refresh_all_workers") 15 | ret = requests.post(controller_addr + "/list_models") 16 | models = ret.json()["models"] 17 | models.sort() 18 | print(f"Models: {models}") 19 | 20 | ret = requests.post(controller_addr + "/get_worker_address", 21 | json={"model": args.model_name}) 22 | worker_addr = ret.json()["address"] 23 | print(f"worker_addr: {worker_addr}") 24 | 25 | if worker_addr == "": 26 | return 27 | 28 | conv = default_conversation.copy() 29 | conv.append_message(conv.roles[0], args.message) 30 | prompt = conv.get_prompt() 31 | 32 | headers = {"User-Agent": "LLaVA Client"} 33 | pload = { 34 | "model": args.model_name, 35 | "prompt": prompt, 36 | "max_new_tokens": args.max_new_tokens, 37 | "temperature": 0.7, 38 | "stop": conv.sep, 39 | } 40 | response = requests.post(worker_addr + "/worker_generate_stream", headers=headers, 41 | json=pload, stream=True) 42 | 43 | print(prompt.replace(conv.sep, "\n"), end="") 44 | for chunk in response.iter_lines(chunk_size=8192, decode_unicode=False, delimiter=b"\0"): 45 | if chunk: 46 | data = json.loads(chunk.decode("utf-8")) 47 | output = data["text"].split(conv.sep)[-1] 48 | print(output, end="\r") 49 | print("") 50 | 51 | 52 | if __name__ == "__main__": 53 | parser = argparse.ArgumentParser() 54 | parser.add_argument("--controller-address", type=str, default="http://localhost:21001") 55 | parser.add_argument("--worker-address", type=str) 56 | parser.add_argument("--model-name", type=str, default="facebook/opt-350m") 57 | 
parser.add_argument("--max-new-tokens", type=int, default=32) 58 | parser.add_argument("--message", type=str, default= 59 | "Tell me a story with more than 1000 words.") 60 | args = parser.parse_args() 61 | 62 | main() 63 | -------------------------------------------------------------------------------- /llava/train/llama_flash_attn_monkey_patch.py: -------------------------------------------------------------------------------- 1 | # Adopted from https://github.com/lm-sys/FastChat. Below is the original copyright: 2 | from typing import List, Optional, Tuple 3 | 4 | import torch 5 | from torch import nn 6 | 7 | import transformers 8 | from transformers.models.llama.modeling_llama import apply_rotary_pos_emb 9 | 10 | from einops import rearrange 11 | 12 | from flash_attn.flash_attn_interface import flash_attn_unpadded_qkvpacked_func 13 | from flash_attn.bert_padding import unpad_input, pad_input 14 | 15 | def forward( 16 | self, 17 | hidden_states: torch.Tensor, 18 | past_key_value: Optional[Tuple[torch.Tensor]] = None, 19 | attention_mask: Optional[torch.Tensor] = None, 20 | output_attentions: bool = False, 21 | use_cache: bool = False, 22 | ) -> Tuple[torch.Tensor, Optional[torch.Tensor], 23 | Optional[Tuple[torch.Tensor]]]: 24 | """Input shape: Batch x Time x Channel 25 | 26 | attention_mask: [bsz, q_len] 27 | """ 28 | bsz, q_len, _ = hidden_states.size() 29 | 30 | query_states = self.q_proj(hidden_states).view( 31 | bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) 32 | key_states = self.k_proj(hidden_states).view( 33 | bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) 34 | value_states = self.v_proj(hidden_states).view( 35 | bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) 36 | # [bsz, q_len, nh, hd] 37 | # [bsz, nh, q_len, hd] 38 | 39 | kv_seq_len = key_states.shape[-2] 40 | offset = 0 41 | if past_key_value is not None: 42 | offset = past_key_value[0].shape[-2] 43 | kv_seq_len += offset 44 | cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len) 45 | query_states, key_states = apply_rotary_pos_emb(query_states, 46 | key_states, 47 | cos, 48 | sin, 49 | offset=offset) 50 | # [bsz, nh, t, hd] 51 | assert not output_attentions, "output_attentions is not supported" 52 | assert not use_cache, "use_cache is not supported" 53 | assert past_key_value is None, "past_key_value is not supported" 54 | 55 | # Flash attention codes from 56 | # https://github.com/HazyResearch/flash-attention/blob/main/flash_attn/flash_attention.py 57 | 58 | # transform the data into the format required by flash attention 59 | qkv = torch.stack([query_states, key_states, value_states], dim=2) # [bsz, nh, 3, q_len, hd] 60 | qkv = qkv.transpose(1, 3) # [bsz, q_len, 3, nh, hd] 61 | # We have disabled _prepare_decoder_attention_mask in LlamaModel 62 | # the attention_mask should be the same as the key_padding_mask 63 | key_padding_mask = attention_mask 64 | 65 | 66 | if key_padding_mask is None: 67 | qkv = rearrange(qkv, 'b s ... -> (b s) ...') 68 | max_s = q_len 69 | cu_q_lens = torch.arange(0, (bsz + 1) * q_len, step=q_len, dtype=torch.int32, 70 | device=qkv.device) 71 | output = flash_attn_unpadded_qkvpacked_func( 72 | qkv, cu_q_lens, max_s, 0.0, 73 | softmax_scale=None, causal=True 74 | ) 75 | output = rearrange(output, '(b s) ... 
-> b s ...', b=bsz) 76 | else: 77 | nheads = qkv.shape[-2] 78 | x = rearrange(qkv, 'b s three h d -> b s (three h d)') 79 | x_unpad, indices, cu_q_lens, max_s = unpad_input(x, key_padding_mask) 80 | x_unpad = rearrange(x_unpad, 'nnz (three h d) -> nnz three h d', three=3, h=nheads) 81 | output_unpad = flash_attn_unpadded_qkvpacked_func( 82 | x_unpad, cu_q_lens, max_s, 0.0, 83 | softmax_scale=None, causal=True 84 | ) 85 | output = rearrange(pad_input(rearrange(output_unpad, 'nnz h d -> nnz (h d)'), 86 | indices, bsz, q_len), 87 | 'b s (h d) -> b s h d', h=nheads) 88 | return self.o_proj(rearrange(output, 89 | 'b s h d -> b s (h d)')), None, None 90 | 91 | 92 | # Disable the transformation of the attention mask in LlamaModel as the flash attention 93 | # requires the attention mask to be the same as the key_padding_mask 94 | def _prepare_decoder_attention_mask(self, attention_mask, input_shape, 95 | inputs_embeds, past_key_values_length): 96 | # [bsz, seq_len] 97 | return attention_mask 98 | 99 | 100 | def replace_llama_attn_with_flash_attn(): 101 | transformers.models.llama.modeling_llama.LlamaModel._prepare_decoder_attention_mask = _prepare_decoder_attention_mask 102 | transformers.models.llama.modeling_llama.LlamaAttention.forward = forward 103 | -------------------------------------------------------------------------------- /llava/train/llava_trainer.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import torch.nn as nn 4 | 5 | from transformers import Trainer 6 | from typing import Dict, Optional, Sequence 7 | 8 | 9 | def unwrap_model(model: nn.Module) -> nn.Module: 10 | """ 11 | Recursively unwraps a model from potential containers (as used in distributed training). 12 | 13 | Args: 14 | model (`torch.nn.Module`): The model to unwrap. 15 | """ 16 | # since there could be multiple levels of wrapping, unwrap recursively 17 | if hasattr(model, "module"): 18 | return unwrap_model(model.module) 19 | else: 20 | return model 21 | 22 | 23 | class LLaVATrainer(Trainer): 24 | 25 | def _save(self, output_dir: Optional[str] = None, state_dict=None): 26 | if getattr(self.args, 'tune_mm_mlp_adapter', False): 27 | # Save the model 28 | _state_dict = state_dict 29 | if _state_dict is None: 30 | # Only save the model itself if we are using distributed training 31 | model_to_save = unwrap_model(self.model) 32 | _state_dict = model_to_save.state_dict() 33 | 34 | weight_to_save = {} 35 | keys_to_match = ['mm_projector', 'embed_tokens', 'embed_in'] 36 | for k, v in _state_dict.items(): 37 | if any(key_match in k for key_match in keys_to_match): 38 | weight_to_save[k] = v 39 | 40 | current_folder = output_dir.split('/')[-1] 41 | parent_folder = os.path.dirname(output_dir) 42 | if current_folder.startswith('checkpoint-'): 43 | mm_projector_folder = os.path.join(parent_folder, "mm_projector") 44 | os.makedirs(mm_projector_folder, exist_ok=True) 45 | torch.save(weight_to_save, os.path.join(mm_projector_folder, f'{current_folder}.bin')) 46 | else: 47 | torch.save(weight_to_save, os.path.join(output_dir, f'mm_projector.bin')) 48 | 49 | super(LLaVATrainer, self)._save(output_dir, state_dict) 50 | -------------------------------------------------------------------------------- /llava/train/train_mem.py: -------------------------------------------------------------------------------- 1 | # Adopted from https://github.com/lm-sys/FastChat. Below is the original copyright: 2 | # Adopted from tatsu-lab@stanford_alpaca. 
Below is the original copyright: 3 | # Make it more memory efficient by monkey patching the LLaMA model with FlashAttn. 4 | 5 | # Need to call this before importing transformers. 6 | from llava.train.llama_flash_attn_monkey_patch import replace_llama_attn_with_flash_attn 7 | 8 | replace_llama_attn_with_flash_attn() 9 | 10 | from llava.train.train import train 11 | 12 | if __name__ == "__main__": 13 | train() 14 | -------------------------------------------------------------------------------- /llava/utils.py: -------------------------------------------------------------------------------- 1 | import datetime 2 | import logging 3 | import logging.handlers 4 | import os 5 | import sys 6 | 7 | import requests 8 | 9 | from llava.constants import LOGDIR 10 | 11 | server_error_msg = "**NETWORK ERROR DUE TO HIGH TRAFFIC. PLEASE REGENERATE OR REFRESH THIS PAGE.**" 12 | moderation_msg = "YOUR INPUT VIOLATES OUR CONTENT MODERATION GUIDELINES. PLEASE TRY AGAIN." 13 | 14 | handler = None 15 | 16 | 17 | def build_logger(logger_name, logger_filename): 18 | global handler 19 | 20 | formatter = logging.Formatter( 21 | fmt="%(asctime)s | %(levelname)s | %(name)s | %(message)s", 22 | datefmt="%Y-%m-%d %H:%M:%S", 23 | ) 24 | 25 | # Set the format of root handlers 26 | if not logging.getLogger().handlers: 27 | logging.basicConfig(level=logging.INFO) 28 | logging.getLogger().handlers[0].setFormatter(formatter) 29 | 30 | # Redirect stdout and stderr to loggers 31 | stdout_logger = logging.getLogger("stdout") 32 | stdout_logger.setLevel(logging.INFO) 33 | sl = StreamToLogger(stdout_logger, logging.INFO) 34 | sys.stdout = sl 35 | 36 | stderr_logger = logging.getLogger("stderr") 37 | stderr_logger.setLevel(logging.ERROR) 38 | sl = StreamToLogger(stderr_logger, logging.ERROR) 39 | sys.stderr = sl 40 | 41 | # Get logger 42 | logger = logging.getLogger(logger_name) 43 | logger.setLevel(logging.INFO) 44 | 45 | # Add a file handler for all loggers 46 | if handler is None: 47 | os.makedirs(LOGDIR, exist_ok=True) 48 | filename = os.path.join(LOGDIR, logger_filename) 49 | handler = logging.handlers.TimedRotatingFileHandler( 50 | filename, when='D', utc=True) 51 | handler.setFormatter(formatter) 52 | 53 | for name, item in logging.root.manager.loggerDict.items(): 54 | if isinstance(item, logging.Logger): 55 | item.addHandler(handler) 56 | 57 | return logger 58 | 59 | 60 | class StreamToLogger(object): 61 | """ 62 | Fake file-like stream object that redirects writes to a logger instance. 63 | """ 64 | def __init__(self, logger, log_level=logging.INFO): 65 | self.terminal = sys.stdout 66 | self.logger = logger 67 | self.log_level = log_level 68 | self.linebuf = '' 69 | 70 | def __getattr__(self, attr): 71 | return getattr(self.terminal, attr) 72 | 73 | def write(self, buf): 74 | temp_linebuf = self.linebuf + buf 75 | self.linebuf = '' 76 | for line in temp_linebuf.splitlines(True): 77 | # From the io.TextIOWrapper docs: 78 | # On output, if newline is None, any '\n' characters written 79 | # are translated to the system default line separator. 80 | # By default sys.stdout.write() expects '\n' newlines and then 81 | # translates them so this is still cross platform. 
82 | if line[-1] == '\n': 83 | self.logger.log(self.log_level, line.rstrip()) 84 | else: 85 | self.linebuf += line 86 | 87 | def flush(self): 88 | if self.linebuf != '': 89 | self.logger.log(self.log_level, self.linebuf.rstrip()) 90 | self.linebuf = '' 91 | 92 | 93 | def disable_torch_init(): 94 | """ 95 | Disable the redundant torch default initialization to accelerate model creation. 96 | """ 97 | import torch 98 | setattr(torch.nn.Linear, "reset_parameters", lambda self: None) 99 | setattr(torch.nn.LayerNorm, "reset_parameters", lambda self: None) 100 | 101 | 102 | def violates_moderation(text): 103 | """ 104 | Check whether the text violates OpenAI moderation API. 105 | """ 106 | url = "https://api.openai.com/v1/moderations" 107 | headers = {"Content-Type": "application/json", 108 | "Authorization": "Bearer " + os.environ["OPENAI_API_KEY"]} 109 | text = text.replace("\n", "") 110 | data = "{" + '"input": ' + f'"{text}"' + "}" 111 | data = data.encode("utf-8") 112 | try: 113 | ret = requests.post(url, headers=headers, data=data, timeout=5) 114 | flagged = ret.json()["results"][0]["flagged"] 115 | except requests.exceptions.RequestException as e: 116 | flagged = False 117 | except KeyError as e: 118 | flagged = False 119 | 120 | return flagged 121 | 122 | 123 | def pretty_print_semaphore(semaphore): 124 | if semaphore is None: 125 | return "None" 126 | return f"Semaphore(value={semaphore._value}, locked={semaphore.locked()})" 127 | -------------------------------------------------------------------------------- /models/.gitkeep: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ebagdasa/multimodal_injection/be8f118c06d7783e4f458046947d1271e7b435d6/models/.gitkeep -------------------------------------------------------------------------------- /original_images/CT.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ebagdasa/multimodal_injection/be8f118c06d7783e4f458046947d1271e7b435d6/original_images/CT.png -------------------------------------------------------------------------------- /original_images/waterview.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ebagdasa/multimodal_injection/be8f118c06d7783e4f458046947d1271e7b435d6/original_images/waterview.jpg -------------------------------------------------------------------------------- /pandagpt/.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ebagdasa/multimodal_injection/be8f118c06d7783e4f458046947d1271e7b435d6/pandagpt/.DS_Store -------------------------------------------------------------------------------- /pandagpt/.gitkeep: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ebagdasa/multimodal_injection/be8f118c06d7783e4f458046947d1271e7b435d6/pandagpt/.gitkeep -------------------------------------------------------------------------------- /pandagpt/code/assets/audios/bird_audio.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ebagdasa/multimodal_injection/be8f118c06d7783e4f458046947d1271e7b435d6/pandagpt/code/assets/audios/bird_audio.wav -------------------------------------------------------------------------------- /pandagpt/code/assets/audios/car_audio.wav: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/ebagdasa/multimodal_injection/be8f118c06d7783e4f458046947d1271e7b435d6/pandagpt/code/assets/audios/car_audio.wav -------------------------------------------------------------------------------- /pandagpt/code/assets/audios/dog_audio.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ebagdasa/multimodal_injection/be8f118c06d7783e4f458046947d1271e7b435d6/pandagpt/code/assets/audios/dog_audio.wav -------------------------------------------------------------------------------- /pandagpt/code/assets/images/bird_image.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ebagdasa/multimodal_injection/be8f118c06d7783e4f458046947d1271e7b435d6/pandagpt/code/assets/images/bird_image.jpg -------------------------------------------------------------------------------- /pandagpt/code/assets/images/car_image.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ebagdasa/multimodal_injection/be8f118c06d7783e4f458046947d1271e7b435d6/pandagpt/code/assets/images/car_image.jpg -------------------------------------------------------------------------------- /pandagpt/code/assets/images/dog_image.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ebagdasa/multimodal_injection/be8f118c06d7783e4f458046947d1271e7b435d6/pandagpt/code/assets/images/dog_image.jpg -------------------------------------------------------------------------------- /pandagpt/code/assets/thermals/190662.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ebagdasa/multimodal_injection/be8f118c06d7783e4f458046947d1271e7b435d6/pandagpt/code/assets/thermals/190662.jpg -------------------------------------------------------------------------------- /pandagpt/code/assets/thermals/210009.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ebagdasa/multimodal_injection/be8f118c06d7783e4f458046947d1271e7b435d6/pandagpt/code/assets/thermals/210009.jpg -------------------------------------------------------------------------------- /pandagpt/code/assets/videos/a.mp4: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ebagdasa/multimodal_injection/be8f118c06d7783e4f458046947d1271e7b435d6/pandagpt/code/assets/videos/a.mp4 -------------------------------------------------------------------------------- /pandagpt/code/assets/videos/world.mp4: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ebagdasa/multimodal_injection/be8f118c06d7783e4f458046947d1271e7b435d6/pandagpt/code/assets/videos/world.mp4 -------------------------------------------------------------------------------- /pandagpt/code/config/__init__.py: -------------------------------------------------------------------------------- 1 | import yaml 2 | 3 | def load_model_config(model, mode): 4 | # load special config for each model 5 | config_path = f'config/{model}.yaml' 6 | print(f'[!] 
load configuration from {config_path}') 7 | with open(config_path) as f: 8 | configuration = yaml.load(f, Loader=yaml.FullLoader) 9 | new_config = {} 10 | for key, value in configuration.items(): 11 | if key in ['train', 'test', 'validation']: 12 | if mode == key: 13 | new_config.update(value) 14 | else: 15 | new_config[key] = value 16 | configuration = new_config 17 | return configuration 18 | 19 | def load_config(args): 20 | '''the configuration of each model can rewrite the base configuration''' 21 | # base config 22 | base_configuration = load_base_config() 23 | 24 | # load one model config 25 | configuration = load_model_config(args['model'], args['mode']) 26 | 27 | # update and append the special config for base config 28 | base_configuration.update(configuration) 29 | configuration = base_configuration 30 | return configuration 31 | 32 | def load_base_config(): 33 | config_path = f'config/base.yaml' 34 | with open(config_path) as f: 35 | configuration = yaml.load(f, Loader=yaml.FullLoader) 36 | print(f'[!] load base configuration: {config_path}') 37 | return configuration 38 | -------------------------------------------------------------------------------- /pandagpt/code/config/base.yaml: -------------------------------------------------------------------------------- 1 | models: 2 | openllama: 3 | model_name: OpenLLAMAModel 4 | agent_name: DeepSpeedAgent 5 | stage1_train_dataset: SupervisedDataset 6 | test_dataset: SelfInstructTestDataset 7 | openllama_peft: 8 | model_name: OpenLLAMAPEFTModel 9 | agent_name: DeepSpeedAgent 10 | stage1_train_dataset: SupervisedDataset 11 | test_dataset: SelfInstructTestDataset 12 | 13 | # ========= Global configuration ========== # 14 | logging_step: 5 15 | # ========= Global configuration ========== # 16 | -------------------------------------------------------------------------------- /pandagpt/code/config/openllama_peft.yaml: -------------------------------------------------------------------------------- 1 | # generation hyper-parameters 2 | max_len: 512 3 | penalty_alpha: 0.6 4 | top_k: 10 5 | top_p: 0.7 6 | random_prefix_len: 5 7 | sample_num: 2 8 | decoding_method: sampling 9 | generate_len: 512 10 | 11 | # lora hyper-parameters 12 | lora_r: 32 13 | lora_alpha: 32 14 | lora_dropout: 0.1 15 | 16 | # some train configuration, more can be found under dsconfig folder 17 | train: 18 | seed: 0 19 | warmup_rate: 0.1 20 | epochs: 2 21 | max_length: 1024 22 | max_shard_size: 10GB 23 | -------------------------------------------------------------------------------- /pandagpt/code/datasets/__init__.py: -------------------------------------------------------------------------------- 1 | from header import * 2 | from .samplers import DistributedBatchSampler 3 | from .sft_dataset import * 4 | 5 | ''' 6 | def get_tokenizer(model): 7 | tokenizer = LlamaTokenizer.from_pretrained(model) 8 | tokenizer.bos_token_id, tokenizer.eos_token_id = 1, 2 9 | tokenizer.pad_token = tokenizer.eos_token 10 | return tokenizer 11 | ''' 12 | 13 | def load_sft_dataset(args): 14 | ''' 15 | tokenizer = get_tokenizer(args['model_path']) 16 | dataset_name = args['models'][args['model']]['stage1_train_dataset'] # SupervisedDataset, str 17 | data_path = args["data_path"] 18 | data = globals()[dataset_name](data_path, tokenizer, args['max_length']) #SupervisedDataset 19 | ''' 20 | data = SupervisedDataset(args['data_path'], args['image_root_path']) 21 | 22 | sampler = torch.utils.data.RandomSampler(data) 23 | world_size = torch.distributed.get_world_size() 24 | rank = 
torch.distributed.get_rank() 25 | batch_size = args['world_size'] * args['dschf'].config['train_micro_batch_size_per_gpu'] 26 | batch_sampler = DistributedBatchSampler( 27 | sampler, 28 | batch_size, 29 | True, 30 | rank, 31 | world_size 32 | ) 33 | iter_ = DataLoader( 34 | data, 35 | batch_sampler=batch_sampler, 36 | num_workers=1, 37 | collate_fn=data.collate, 38 | pin_memory=True 39 | ) 40 | return data, iter_, sampler 41 | -------------------------------------------------------------------------------- /pandagpt/code/datasets/samplers.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | """batch samplers that work with either random or sequential data samplers""" 16 | import math 17 | import os 18 | import sys 19 | 20 | import torch 21 | from torch.utils import data 22 | import numpy as np 23 | 24 | 25 | class RandomSampler(data.sampler.Sampler): 26 | r""" 27 | Based off of pytorch RandomSampler and DistributedSampler. Essentially a RandomSampler, 28 | but this class lets the user set an epoch like DistributedSampler 29 | Samples elements randomly. If without replacement, then sample from a shuffled dataset. 30 | If with replacement, then user can specify ``num_samples`` to draw. 
31 | Arguments: 32 | data_source (Dataset): dataset to sample from 33 | num_samples (int): number of samples to draw, default=len(dataset) 34 | replacement (bool): samples are drawn with replacement if ``True``, default=False 35 | """ 36 | 37 | def __init__(self, data_source, replacement=False, num_samples=None): 38 | super(RandomSampler, self).__init__(data_source) 39 | self.data_source = data_source 40 | self.replacement = replacement 41 | self._num_samples = num_samples 42 | self.epoch = -1 43 | 44 | if self._num_samples is not None and replacement is False: 45 | raise ValueError("With replacement=False, num_samples should not be specified, " 46 | "since a random permute will be performed.") 47 | 48 | if not isinstance(self.num_samples, int) or self.num_samples <= 0: 49 | raise ValueError("num_samples should be a positive integer " 50 | "value, but got num_samples={}".format(self.num_samples)) 51 | if not isinstance(self.replacement, bool): 52 | raise ValueError("replacement should be a boolean value, but got " 53 | "replacement={}".format(self.replacement)) 54 | 55 | @property 56 | def num_samples(self): 57 | # dataset size might change at runtime 58 | if self._num_samples is None: 59 | return len(self.data_source) 60 | return self._num_samples 61 | 62 | def __iter__(self): 63 | n = len(self.data_source) 64 | g = torch.Generator() 65 | if self.epoch >= 0: 66 | g.manual_seed(self.epoch) 67 | if self.replacement: 68 | for _ in range(self.num_samples // 32): 69 | yield from torch.randint(high=n, size=(32,), dtype=torch.int64, generator=g).tolist() 70 | yield from torch.randint(high=n, size=(self.num_samples % 32,), dtype=torch.int64, 71 | generator=g).tolist() 72 | else: 73 | yield from torch.randperm(n, generator=self.generator).tolist() 74 | 75 | def __len__(self): 76 | return self.num_samples 77 | 78 | def set_epoch(self, epoch): 79 | self.epoch = epoch 80 | 81 | 82 | class DistributedSequentialSampler(data.sampler.Sampler): 83 | def __init__(self, num_samples, train_iters, batch_size, rank=-1, world_size=2): 84 | super().__init__(num_samples) 85 | if rank == -1: 86 | rank = 0 87 | world_size = 1 88 | self.num_samples = num_samples 89 | self.rank = rank 90 | self.world_size = world_size 91 | self.start_iter = 0 92 | self.train_iters = train_iters 93 | self.batch_size = batch_size 94 | self.batch_bias = [i * (num_samples // batch_size) for i in range(batch_size)] 95 | 96 | def __iter__(self): 97 | for idx in range(self.start_iter, self.train_iters * 10): 98 | batch = [(idx + bias) % self.num_samples for bias in self.batch_bias] 99 | tbatch = self._batch(batch) 100 | yield tbatch 101 | 102 | def __len__(self): 103 | return self.train_iters 104 | 105 | def _batch(self, batch): 106 | """extracts samples only pertaining to this worker's batch""" 107 | start = self.rank*self.batch_size//self.world_size 108 | end = (self.rank+1)*self.batch_size//self.world_size 109 | return batch[start:end] 110 | 111 | 112 | class DistributedBatchSampler(data.sampler.BatchSampler): 113 | """ 114 | similar to normal implementation of distributed sampler, except implementation is at the 115 | batch sampler level, instead of just the sampler level. This allows wrapping of arbitrary 116 | data samplers (sequential, random, WeightedRandomSampler, etc.) with this batch sampler. 
117 | """ 118 | def __init__(self, sampler, batch_size, drop_last, rank=-1, world_size=2, wrap_last=False, gradient_accumulation_steps=None): 119 | super(DistributedBatchSampler, self).__init__(sampler, batch_size, drop_last) 120 | if rank == -1: 121 | assert False, 'should not be here' 122 | self.rank = rank 123 | self.world_size = world_size 124 | self.sampler.wrap_around = 0 125 | self.wrap_around = 0 126 | self.wrap_last = wrap_last 127 | self.start_iter = 0 128 | self.effective_batch_size = batch_size if gradient_accumulation_steps is None else batch_size * gradient_accumulation_steps 129 | 130 | def __iter__(self): 131 | batch = [] 132 | i = 0 133 | for idx in self.data_iterator(self.sampler, wrap_around=False): 134 | batch.append(idx) 135 | if len(batch) == self.batch_size: 136 | tbatch = self._batch(batch) 137 | if i >= self.start_iter * self.effective_batch_size: 138 | yield tbatch 139 | self.start_iter = 0 140 | i += len(batch) 141 | batch = [] 142 | batch_len = len(batch) 143 | if batch_len > 0 and not self.drop_last: 144 | if self.wrap_last: 145 | self.sampler.wrap_around -= (self.batch_size) 146 | self.wrap_around += (len(batch)) 147 | self.wrap_around %= self.batch_size 148 | yield self._batch(batch) 149 | if self.wrap_last: 150 | self.sampler.wrap_around += self.batch_size 151 | 152 | def data_iterator(self, _iter, wrap_around=False): 153 | """iterates through data and handles wrap around""" 154 | for i, idx in enumerate(_iter): 155 | if i < self.wrap_around%self.batch_size: 156 | continue 157 | if wrap_around: 158 | self.wrap_around += 1 159 | self.wrap_around %= self.batch_size 160 | yield idx 161 | 162 | def _batch(self, batch): 163 | """extracts samples only pertaining to this worker's batch""" 164 | start = self.rank*self.batch_size//self.world_size 165 | end = (self.rank+1)*self.batch_size//self.world_size 166 | return batch[start:end] 167 | -------------------------------------------------------------------------------- /pandagpt/code/datasets/sft_dataset.py: -------------------------------------------------------------------------------- 1 | # Copyright 2023 Rohan Taori, Ishaan Gulrajani, Tianyi Zhang, Yann Dubois, Xuechen Li 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | 15 | import copy 16 | import os 17 | import json 18 | from tqdm import tqdm 19 | import ipdb 20 | import random 21 | from torch.nn.utils.rnn import pad_sequence 22 | from dataclasses import dataclass, field 23 | from typing import Callable, Dict, Sequence 24 | 25 | import torch 26 | import torch.distributed as dist 27 | import transformers 28 | from torch.utils.data import Dataset 29 | from tqdm import tqdm 30 | 31 | class SupervisedDataset(Dataset): 32 | """Dataset for supervised fine-tuning.""" 33 | 34 | def __init__(self, data_path: str, image_root_path: str): 35 | super(SupervisedDataset, self).__init__() 36 | 37 | with open(data_path, 'r') as f: 38 | json_data = json.load(f) 39 | # for debug: 40 | #json_data = json_data[:100000] 41 | 42 | self.image_path_list, self.caption_list = [], [] 43 | for item in json_data: 44 | one_image_name, one_caption = item["image_name"], item["conversation"] 45 | # TODO: stage 2 dataset format is invalid 46 | if not one_image_name.endswith('.jpg'): 47 | one_image_name += '.jpg' 48 | one_image_path = image_root_path + '/{}'.format(one_image_name) 49 | self.image_path_list.append(one_image_path) 50 | self.caption_list.append(one_caption) 51 | print(f'[!] collect {len(self.image_path_list)} samples for training') 52 | 53 | def __len__(self): # number of instances 54 | return len(self.image_path_list) 55 | 56 | #def __getitem__(self, i) -> Dict[str, torch.Tensor]: # how to get item, 取一个样本 57 | def __getitem__(self, i): 58 | return dict(image_paths=self.image_path_list[i], output_texts=self.caption_list[i]) 59 | 60 | def collate(self, instances): 61 | image_paths, output_texts = tuple([instance[key] for instance in instances] for key in ("image_paths", "output_texts")) 62 | return dict( 63 | image_paths=image_paths, 64 | output_texts=output_texts 65 | ) 66 | -------------------------------------------------------------------------------- /pandagpt/code/dsconfig/openllama_peft_stage_1.json: -------------------------------------------------------------------------------- 1 | { 2 | "train_batch_size": 64, 3 | "train_micro_batch_size_per_gpu": 1, 4 | "gradient_accumulation_steps": 8, 5 | "steps_per_print": 1, 6 | "gradient_clipping": 1.0, 7 | "zero_optimization": { 8 | "stage": 2, 9 | "offload_optimizer": { 10 | "device": "cpu" 11 | }, 12 | "contiguous_gradients": true, 13 | "allgather_bucket_size": 500000000, 14 | "allgather_partitions": true 15 | }, 16 | "fp16": { 17 | "enabled": true, 18 | "opt_level": "O2", 19 | "min_loss_scale": 1 20 | }, 21 | "bf16": { 22 | "enable": true 23 | }, 24 | "optimizer": { 25 | "type": "Adam", 26 | "params": { 27 | "lr": 0.0005, 28 | "betas": [ 29 | 0.9, 30 | 0.95 31 | ], 32 | "eps": 1e-8, 33 | "weight_decay": 0.001 34 | } 35 | }, 36 | "scheduler": { 37 | "type": "WarmupDecayLR", 38 | "params": { 39 | "warmup_min_lr": 0, 40 | "warmup_max_lr": 0.0005, 41 | "warmup_num_steps": 10, 42 | "total_num_steps": 10000 43 | } 44 | }, 45 | "activation_checkpointing": { 46 | "partition_activations": true, 47 | "cpu_checkpointing": true, 48 | "contiguous_memory_optimization": false, 49 | "number_checkpoints": null, 50 | "synchronize_checkpoint_boundary": false, 51 | "profile": false 52 | } 53 | 54 | } -------------------------------------------------------------------------------- /pandagpt/code/header.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import datetime 3 | import types 4 | import deepspeed 5 | from transformers.deepspeed import HfDeepSpeedConfig 6 | import transformers 
7 | import numpy as np 8 | from collections import OrderedDict 9 | from torch.utils.data import Dataset, DataLoader 10 | from torch.nn.utils import clip_grad_norm_ 11 | from torch.cuda.amp import autocast, GradScaler 12 | from torch.nn import DataParallel 13 | from torch.optim import lr_scheduler 14 | import torch.optim as optim 15 | import torch.nn as nn 16 | import torch.nn.functional as F 17 | from tqdm import tqdm 18 | import os 19 | import re 20 | import math 21 | import random 22 | import json 23 | import time 24 | import logging 25 | from copy import deepcopy 26 | import ipdb 27 | import argparse 28 | import data 29 | from transformers import LlamaTokenizer, LlamaForCausalLM, LlamaConfig 30 | from torch.nn.utils.rnn import pad_sequence 31 | from peft import LoraConfig, TaskType, get_peft_model 32 | 33 | logging.getLogger("transformers").setLevel(logging.WARNING) 34 | logging.getLogger("transformers.tokenization_utils").setLevel(logging.ERROR) 35 | os.environ['TOKENIZERS_PARALLELISM'] = 'false' 36 | -------------------------------------------------------------------------------- /pandagpt/code/model/ImageBind/CODE_OF_CONDUCT.md: -------------------------------------------------------------------------------- 1 | # Code of Conduct 2 | 3 | ## Our Pledge 4 | 5 | In the interest of fostering an open and welcoming environment, we as 6 | contributors and maintainers pledge to make participation in our project and 7 | our community a harassment-free experience for everyone, regardless of age, body 8 | size, disability, ethnicity, sex characteristics, gender identity and expression, 9 | level of experience, education, socio-economic status, nationality, personal 10 | appearance, race, religion, or sexual identity and orientation. 11 | 12 | ## Our Standards 13 | 14 | Examples of behavior that contributes to creating a positive environment 15 | include: 16 | 17 | * Using welcoming and inclusive language 18 | * Being respectful of differing viewpoints and experiences 19 | * Gracefully accepting constructive criticism 20 | * Focusing on what is best for the community 21 | * Showing empathy towards other community members 22 | 23 | Examples of unacceptable behavior by participants include: 24 | 25 | * The use of sexualized language or imagery and unwelcome sexual attention or 26 | advances 27 | * Trolling, insulting/derogatory comments, and personal or political attacks 28 | * Public or private harassment 29 | * Publishing others' private information, such as a physical or electronic 30 | address, without explicit permission 31 | * Other conduct which could reasonably be considered inappropriate in a 32 | professional setting 33 | 34 | ## Our Responsibilities 35 | 36 | Project maintainers are responsible for clarifying the standards of acceptable 37 | behavior and are expected to take appropriate and fair corrective action in 38 | response to any instances of unacceptable behavior. 39 | 40 | Project maintainers have the right and responsibility to remove, edit, or 41 | reject comments, commits, code, wiki edits, issues, and other contributions 42 | that are not aligned to this Code of Conduct, or to ban temporarily or 43 | permanently any contributor for other behaviors that they deem inappropriate, 44 | threatening, offensive, or harmful. 45 | 46 | ## Scope 47 | 48 | This Code of Conduct applies within all project spaces, and it also applies when 49 | an individual is representing the project or its community in public spaces. 
50 | Examples of representing a project or community include using an official 51 | project e-mail address, posting via an official social media account, or acting 52 | as an appointed representative at an online or offline event. Representation of 53 | a project may be further defined and clarified by project maintainers. 54 | 55 | This Code of Conduct also applies outside the project spaces when there is a 56 | reasonable belief that an individual's behavior may have a negative impact on 57 | the project or its community. 58 | 59 | ## Enforcement 60 | 61 | Instances of abusive, harassing, or otherwise unacceptable behavior may be 62 | reported by contacting the project team at . All 63 | complaints will be reviewed and investigated and will result in a response that 64 | is deemed necessary and appropriate to the circumstances. The project team is 65 | obligated to maintain confidentiality with regard to the reporter of an incident. 66 | Further details of specific enforcement policies may be posted separately. 67 | 68 | Project maintainers who do not follow or enforce the Code of Conduct in good 69 | faith may face temporary or permanent repercussions as determined by other 70 | members of the project's leadership. 71 | 72 | ## Attribution 73 | 74 | This Code of Conduct is adapted from the [Contributor Covenant][homepage], version 1.4, 75 | available at https://www.contributor-covenant.org/version/1/4/code-of-conduct.html 76 | 77 | [homepage]: https://www.contributor-covenant.org 78 | 79 | For answers to common questions about this code of conduct, see 80 | https://www.contributor-covenant.org/faq -------------------------------------------------------------------------------- /pandagpt/code/model/ImageBind/CONTRIBUTING.md: -------------------------------------------------------------------------------- 1 | # Contributing to ImageBind 2 | We want to make contributing to this project as easy and transparent as 3 | possible. 4 | 5 | ## Pull Requests 6 | We actively welcome your pull requests. 7 | 8 | 1. Fork the repo and create your branch from `main`. 9 | 2. If you've added code that should be tested, add tests. 10 | 3. If you've changed APIs, update the documentation. 11 | 4. Ensure the test suite passes. 12 | 5. Make sure your code lints. 13 | 6. If you haven't already, complete the Contributor License Agreement ("CLA"). 14 | 15 | ## Contributor License Agreement ("CLA") 16 | In order to accept your pull request, we need you to submit a CLA. You only need 17 | to do this once to work on any of Meta's open source projects. 18 | 19 | Complete your CLA here: 20 | 21 | ## Issues 22 | We use GitHub issues to track public bugs. Please ensure your description is 23 | clear and has sufficient instructions to be able to reproduce the issue. 24 | 25 | Meta has a [bounty program](https://www.facebook.com/whitehat/) for the safe 26 | disclosure of security bugs. In those cases, please go through the process 27 | outlined on that page and do not file a public issue. 28 | 29 | ## License 30 | By contributing to Omnivore, you agree that your contributions will be licensed 31 | under the [LICENSE](LICENSE) file in the root directory of this source tree. 
32 | -------------------------------------------------------------------------------- /pandagpt/code/model/ImageBind/README.md: -------------------------------------------------------------------------------- 1 | # ImageBind: One Embedding Space To Bind Them All 2 | 3 | **[FAIR, Meta AI](https://ai.facebook.com/research/)** 4 | 5 | Rohit Girdhar*, 6 | Alaaeldin El-Nouby*, 7 | Zhuang Liu, 8 | Mannat Singh, 9 | Kalyan Vasudev Alwala, 10 | Armand Joulin, 11 | Ishan Misra* 12 | 13 | To appear at CVPR 2023 (*Highlighted paper*) 14 | 15 | [[`Paper`](https://facebookresearch.github.io/ImageBind/paper)] [[`Blog`](https://ai.facebook.com/blog/imagebind-six-modalities-binding-ai/)] [[`Demo`](https://imagebind.metademolab.com/)] [[`Supplementary Video`](https://dl.fbaipublicfiles.com/imagebind/imagebind_video.mp4)] [[`BibTex`](#citing-imagebind)] 16 | 17 | PyTorch implementation and pretrained models for ImageBind. For details, see the paper: **[ImageBind: One Embedding Space To Bind Them All](https://facebookresearch.github.io/ImageBind/paper)**. 18 | 19 | ImageBind learns a joint embedding across six different modalities - images, text, audio, depth, thermal, and IMU data. It enables novel emergent applications ‘out-of-the-box’ including cross-modal retrieval, composing modalities with arithmetic, cross-modal detection and generation. 20 | 21 | 22 | 23 | ![ImageBind](https://user-images.githubusercontent.com/8495451/236859695-ffa13364-3e39-4d99-a8da-fbfab17f9a6b.gif) 24 | 25 | ## ImageBind model 26 | 27 | Emergent zero-shot classification performance. 28 | 29 | 30 | 31 | 32 | 33 | 34 | 35 | 36 | 37 | 38 | 39 | 40 | 41 | 42 | 43 | 44 | 45 | 46 | 47 | 48 | 49 | 50 | 51 |
| Model | IN1k | K400 | NYU-D | ESC | LLVIP | Ego4D | download |
|----------------|------|------|-------|------|-------|-------|----------|
| imagebind_huge | 77.7 | 50.0 | 54.0  | 66.9 | 63.4  | 25.0  | checkpoint |
52 | 53 | ## Usage 54 | 55 | Install pytorch 1.13+ and other 3rd party dependencies. 56 | 57 | ```shell 58 | conda create --name imagebind python=3.8 -y 59 | conda activate imagebind 60 | 61 | pip install -r requirements.txt 62 | ``` 63 | 64 | For windows users, you might need to install `soundfile` for reading/writing audio files. (Thanks @congyue1977) 65 | 66 | ``` 67 | pip install soundfile 68 | ``` 69 | 70 | 71 | Extract and compare features across modalities (e.g. Image, Text and Audio). 72 | 73 | ```python 74 | import data 75 | import torch 76 | from models import imagebind_model 77 | from models.imagebind_model import ModalityType 78 | 79 | text_list=["A dog.", "A car", "A bird"] 80 | image_paths=[".assets/dog_image.jpg", ".assets/car_image.jpg", ".assets/bird_image.jpg"] 81 | audio_paths=[".assets/dog_audio.wav", ".assets/car_audio.wav", ".assets/bird_audio.wav"] 82 | 83 | device = "cuda:0" if torch.cuda.is_available() else "cpu" 84 | 85 | # Instantiate model 86 | model = imagebind_model.imagebind_huge(pretrained=True) 87 | model.eval() 88 | model.to(device) 89 | 90 | # Load data 91 | inputs = { 92 | ModalityType.TEXT: data.load_and_transform_text(text_list, device), 93 | ModalityType.VISION: data.load_and_transform_vision_data(image_paths, device), 94 | ModalityType.AUDIO: data.load_and_transform_audio_data(audio_paths, device), 95 | } 96 | 97 | with torch.no_grad(): 98 | embeddings = model(inputs) 99 | 100 | print( 101 | "Vision x Text: ", 102 | torch.softmax(embeddings[ModalityType.VISION] @ embeddings[ModalityType.TEXT].T, dim=-1), 103 | ) 104 | print( 105 | "Audio x Text: ", 106 | torch.softmax(embeddings[ModalityType.AUDIO] @ embeddings[ModalityType.TEXT].T, dim=-1), 107 | ) 108 | print( 109 | "Vision x Audio: ", 110 | torch.softmax(embeddings[ModalityType.VISION] @ embeddings[ModalityType.AUDIO].T, dim=-1), 111 | ) 112 | 113 | # Expected output: 114 | # 115 | # Vision x Text: 116 | # tensor([[9.9761e-01, 2.3694e-03, 1.8612e-05], 117 | # [3.3836e-05, 9.9994e-01, 2.4118e-05], 118 | # [4.7997e-05, 1.3496e-02, 9.8646e-01]]) 119 | # 120 | # Audio x Text: 121 | # tensor([[1., 0., 0.], 122 | # [0., 1., 0.], 123 | # [0., 0., 1.]]) 124 | # 125 | # Vision x Audio: 126 | # tensor([[0.8070, 0.1088, 0.0842], 127 | # [0.1036, 0.7884, 0.1079], 128 | # [0.0018, 0.0022, 0.9960]]) 129 | 130 | ``` 131 | 132 | ## Model card 133 | Please see the [model card](model_card.md) for details. 134 | 135 | ## License 136 | 137 | ImageBind code and model weights are released under the CC-BY-NC 4.0 license. See [LICENSE](LICENSE) for additional details. 138 | 139 | ## Contributing 140 | 141 | See [contributing](CONTRIBUTING.md) and the [code of conduct](CODE_OF_CONDUCT.md). 
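Following up on the Usage section earlier in this README, the short sketch below is an illustration only (not part of the ImageBind API): it reuses the `embeddings`, `text_list`, `image_paths`, and `ModalityType` names from that example and turns the same dot-product similarities into a simple per-image retrieval result by taking the arg-max over the text prompts.

```python
# Minimal retrieval sketch reusing the variables from the Usage example above.
# `embeddings` maps each ModalityType to a (batch, dim) tensor.
import torch

# Dot-product similarity between every image and every text prompt,
# the same quantity passed to torch.softmax in the example above.
sim = embeddings[ModalityType.VISION] @ embeddings[ModalityType.TEXT].T

# For each image, report the best-matching text prompt.
best = sim.argmax(dim=-1)
for img_path, idx in zip(image_paths, best.tolist()):
    print(f"{img_path} -> {text_list[idx]}")
```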
142 | 143 | ## Citing ImageBind 144 | 145 | If you find this repository useful, please consider giving a star :star: and citation 146 | 147 | ``` 148 | @inproceedings{girdhar2023imagebind, 149 | title={ImageBind: One Embedding Space To Bind Them All}, 150 | author={Girdhar, Rohit and El-Nouby, Alaaeldin and Liu, Zhuang 151 | and Singh, Mannat and Alwala, Kalyan Vasudev and Joulin, Armand and Misra, Ishan}, 152 | booktitle={CVPR}, 153 | year={2023} 154 | } 155 | ``` 156 | -------------------------------------------------------------------------------- /pandagpt/code/model/ImageBind/__init__.py: -------------------------------------------------------------------------------- 1 | from .models import imagebind_model 2 | from .models.imagebind_model import ModalityType 3 | -------------------------------------------------------------------------------- /pandagpt/code/model/ImageBind/bpe/bpe_simple_vocab_16e6.txt.gz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ebagdasa/multimodal_injection/be8f118c06d7783e4f458046947d1271e7b435d6/pandagpt/code/model/ImageBind/bpe/bpe_simple_vocab_16e6.txt.gz -------------------------------------------------------------------------------- /pandagpt/code/model/ImageBind/model_card.md: -------------------------------------------------------------------------------- 1 | # Model Card for ImageBind 2 | 3 | Multimodal joint embedding model for image/video, text, audio, depth, IMU, and thermal images. 4 | Input any of the six modalities and get the same sized embedding that can be used for cross-modal and multimodal tasks. 5 | 6 | # Model Details 7 | 8 | ## Model Description 9 | 10 | 11 | Multimodal joint embedding model for image/video, text, audio, depth, IMU, and thermal images 12 | 13 | - **Developed by:** Meta AI 14 | - **Model type:** Multimodal model 15 | - **Language(s) (NLP):** en 16 | - **License:** CC BY-NC-SA 4.0 17 | - **Resources for more information:** 18 | - [GitHub Repo](https://github.com/facebookresearch/ImageBind) 19 | 20 | 21 | # Uses 22 | 23 | 24 | This model is intended only for research purposes. It provides a joint embedding space for different modalities -- image/video, text, audio, depth, IMU and thermal images. 25 | We hope that these joint embeddings can be used for a variety of different cross-modal research, e.g., cross-modal retrieval and combining embeddings from different modalities. 26 | 27 | ## Out-of-Scope Use 28 | 29 | 30 | 31 | 32 | This model is *NOT* intended to be used in any real world application -- commercial or otherwise. 33 | It may produce harmful associations with different inputs. 34 | The model needs to be investigated and likely re-trained on specific data for any such application. 35 | The model is expected to work better on web-based visual data since it was trained on such data. 36 | The text encoder is likely to work only on English language text because of the underlying training datasets. 37 | 38 | # Bias, Risks, and Limitations 39 | 40 | 41 | Open-domain joint embedding models are prone to producing specific biases, e.g., study from [CLIP](https://github.com/openai/CLIP/blob/main/model-card.md#bias-and-fairness). 42 | Since our model uses such models as initialization, it will exhibit such biases too. 43 | Moreover, for learning joint embeddings for other modalities such as audio, thermal, depth, and IMU we leverage datasets that are relatively small. These joint embeddings are thus limited to the concepts present in the datasets. 
For example, the thermal datasets we used are limited to outdoor street scenes, while the depth datasets are limited to indoor scenes. 44 | 45 | 46 | 47 | # Training Details 48 | 49 | ## Training Data 50 | 51 | 52 | 53 | ImageBind uses image-paired data for training -- (image, X) where X is one of text, audio, depth, IMU or thermal data. 54 | In particular, we initialize and freeze the image and text encoders using an OpenCLIP ViT-H encoder. 55 | We train audio embeddings using Audioset, depth embeddings using the SUN RGB-D dataset, IMU using the Ego4D dataset and thermal embeddings using the LLVIP dataset. 56 | We provide the exact training data details in the paper. 57 | 58 | 59 | ## Training Procedure 60 | 61 | 62 | Please refer to the research paper and github repo for exact details on this. 63 | 64 | # Evaluation 65 | 66 | ## Testing Data, Factors & Metrics 67 | 68 | We evaluate the model on a variety of different classification benchmarks for each modality. 69 | The evaluation details are presented in the paper. 70 | The models performance is measured using standard classification metrics such as accuracy and mAP. 71 | 72 | # Citation 73 | 74 | 75 | 76 | **BibTeX:** 77 | ``` 78 | @inproceedings{girdhar2023imagebind, 79 | title={ImageBind: One Embedding Space To Bind Them All}, 80 | author={Girdhar, Rohit and El-Nouby, Alaaeldin and Liu, Zhuang 81 | and Singh, Mannat and Alwala, Kalyan Vasudev and Joulin, Armand and Misra, Ishan}, 82 | booktitle={CVPR}, 83 | year={2023} 84 | } 85 | ``` 86 | 87 | 88 | # Model Card Contact 89 | 90 | Please reach out to the authors at: rgirdhar@meta.com imisra@meta.com alaaelnouby@gmail.com 91 | 92 | # How to Get Started with the Model 93 | 94 | Our github repo provides a simple example to extract embeddings from images, audio etc. 95 | -------------------------------------------------------------------------------- /pandagpt/code/model/ImageBind/models/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ebagdasa/multimodal_injection/be8f118c06d7783e4f458046947d1271e7b435d6/pandagpt/code/model/ImageBind/models/__init__.py -------------------------------------------------------------------------------- /pandagpt/code/model/ImageBind/models/helpers.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Portions Copyright (c) Meta Platforms, Inc. and affiliates. 3 | # All rights reserved. 4 | 5 | # This source code is licensed under the license found in the 6 | # LICENSE file in the root directory of this source tree. 
7 | 8 | import math 9 | 10 | import einops 11 | import numpy as np 12 | import torch 13 | 14 | import torch.nn as nn 15 | 16 | 17 | class Normalize(nn.Module): 18 | def __init__(self, dim: int) -> None: 19 | super().__init__() 20 | self.dim = dim 21 | 22 | def forward(self, x): 23 | return torch.nn.functional.normalize(x, dim=self.dim, p=2) 24 | 25 | 26 | class LearnableLogitScaling(nn.Module): 27 | def __init__( 28 | self, 29 | logit_scale_init: float = 1 / 0.07, 30 | learnable: bool = True, 31 | max_logit_scale: float = 100, 32 | ) -> None: 33 | super().__init__() 34 | self.max_logit_scale = max_logit_scale 35 | self.logit_scale_init = logit_scale_init 36 | self.learnable = learnable 37 | log_logit_scale = torch.ones([]) * np.log(self.logit_scale_init) 38 | if learnable: 39 | self.log_logit_scale = nn.Parameter(log_logit_scale) 40 | else: 41 | self.register_buffer("log_logit_scale", log_logit_scale) 42 | 43 | def forward(self, x): 44 | return torch.clip(self.log_logit_scale.exp(), max=self.max_logit_scale) * x 45 | 46 | def extra_repr(self): 47 | st = f"logit_scale_init={self.logit_scale_init},learnable={self.learnable}, max_logit_scale={self.max_logit_scale}" 48 | return st 49 | 50 | 51 | class EinOpsRearrange(nn.Module): 52 | def __init__(self, rearrange_expr: str, **kwargs) -> None: 53 | super().__init__() 54 | self.rearrange_expr = rearrange_expr 55 | self.kwargs = kwargs 56 | 57 | def forward(self, x): 58 | assert isinstance(x, torch.Tensor) 59 | return einops.rearrange(x, self.rearrange_expr, **self.kwargs) 60 | 61 | 62 | class VerboseNNModule(nn.Module): 63 | """ 64 | Wrapper around nn.Module that prints registered buffers and parameter names. 65 | """ 66 | 67 | @staticmethod 68 | def get_readable_tensor_repr(name: str, tensor: torch.Tensor) -> str: 69 | st = ( 70 | "(" 71 | + name 72 | + "): " 73 | + "tensor(" 74 | + str(tuple(tensor[1].shape)) 75 | + ", requires_grad=" 76 | + str(tensor[1].requires_grad) 77 | + ")\n" 78 | ) 79 | return st 80 | 81 | def extra_repr(self) -> str: 82 | named_modules = set() 83 | for p in self.named_modules(): 84 | named_modules.update([p[0]]) 85 | named_modules = list(named_modules) 86 | 87 | string_repr = "" 88 | for p in self.named_parameters(): 89 | name = p[0].split(".")[0] 90 | if name not in named_modules: 91 | string_repr += self.get_readable_tensor_repr(name, p) 92 | 93 | for p in self.named_buffers(): 94 | name = p[0].split(".")[0] 95 | string_repr += self.get_readable_tensor_repr(name, p) 96 | 97 | return string_repr 98 | 99 | 100 | def cast_if_src_dtype( 101 | tensor: torch.Tensor, src_dtype: torch.dtype, tgt_dtype: torch.dtype 102 | ): 103 | updated = False 104 | if tensor.dtype == src_dtype: 105 | tensor = tensor.to(dtype=tgt_dtype) 106 | updated = True 107 | return tensor, updated 108 | 109 | 110 | class QuickGELU(nn.Module): 111 | # From https://github.com/openai/CLIP/blob/d50d76daa670286dd6cacf3bcd80b5e4823fc8e1/clip/model.py#L166 112 | def forward(self, x: torch.Tensor): 113 | return x * torch.sigmoid(1.702 * x) 114 | 115 | 116 | class SelectElement(nn.Module): 117 | def __init__(self, index) -> None: 118 | super().__init__() 119 | self.index = index 120 | 121 | def forward(self, x): 122 | assert x.ndim >= 3 123 | return x[:, self.index, ...] 
124 | 125 | 126 | class SelectEOSAndProject(nn.Module): 127 | """ 128 | Text Pooling used in OpenCLIP 129 | """ 130 | 131 | def __init__(self, proj: nn.Module) -> None: 132 | super().__init__() 133 | self.proj = proj 134 | 135 | def forward(self, x, seq_len): 136 | assert x.ndim == 3 137 | # x is of shape B x L x D 138 | # take features from the eot embedding (eot_token is the highest number in each sequence) 139 | x = x[torch.arange(x.shape[0]), seq_len] 140 | x = self.proj(x) 141 | return x 142 | -------------------------------------------------------------------------------- /pandagpt/code/model/ImageBind/requirements.txt: -------------------------------------------------------------------------------- 1 | --extra-index-url https://download.pytorch.org/whl/cu113 2 | torchvision==0.14.0 3 | torchaudio==0.13.0 4 | pytorchvideo @ git+https://github.com/facebookresearch/pytorchvideo.git@28fe037d212663c6a24f373b94cc5d478c8c1a1d 5 | timm==0.6.7 6 | ftfy 7 | regex 8 | einops 9 | fvcore 10 | decord==0.6.0 11 | -------------------------------------------------------------------------------- /pandagpt/code/model/__init__.py: -------------------------------------------------------------------------------- 1 | from .agent import DeepSpeedAgent 2 | from .openllama import OpenLLAMAPEFTModel 3 | 4 | def load_model(args): 5 | agent_name = args['models'][args['model']]['agent_name'] 6 | model_name = args['models'][args['model']]['model_name'] 7 | model = globals()[model_name](**args) 8 | agent = globals()[agent_name](model, args) 9 | return agent 10 | -------------------------------------------------------------------------------- /pandagpt/code/model/agent.py: -------------------------------------------------------------------------------- 1 | from header import * 2 | 3 | class DeepSpeedAgent: 4 | 5 | def __init__(self, model, args): 6 | super(DeepSpeedAgent, self).__init__() 7 | self.args = args 8 | self.model = model 9 | if args['stage'] == 2: 10 | self.load_stage_1_parameters(args["delta_ckpt_path"]) 11 | print(f'[!] load stage 1 checkpoint from {args["delta_ckpt_path"]}') 12 | 13 | # load config parameters of deepspeed 14 | ds_params = json.load(open(self.args['ds_config_path'])) 15 | ds_params['scheduler']['params']['total_num_steps'] = self.args['total_steps'] 16 | ds_params['scheduler']['params']['warmup_num_steps'] = max(10, int(self.args['total_steps'] * self.args['warmup_rate'])) 17 | self.ds_engine, self.optimizer, _ , _ = deepspeed.initialize( 18 | model=self.model, 19 | model_parameters=self.model.parameters(), 20 | config_params=ds_params, 21 | dist_init_required=True, 22 | args=types.SimpleNamespace(**args) 23 | ) 24 | 25 | @torch.no_grad() 26 | def predict(self, batch): 27 | self.model.eval() 28 | string = self.model.generate_one_sample(batch) 29 | return string 30 | 31 | def train_model(self, batch, current_step=0, pbar=None): 32 | self.ds_engine.module.train() 33 | loss, mle_acc = self.ds_engine(batch) 34 | 35 | self.ds_engine.backward(loss) 36 | self.ds_engine.step() 37 | pbar.set_description(f'[!] loss: {round(loss.item(), 4)}; token_acc: {round(mle_acc*100, 2)}') 38 | pbar.update(1) 39 | if self.args['local_rank'] == 0 and self.args['log_path'] and current_step % self.args['logging_step'] == 0: 40 | elapsed = pbar.format_dict['elapsed'] 41 | rate = pbar.format_dict['rate'] 42 | remaining = (pbar.total - pbar.n) / rate if rate and pbar.total else 0 43 | remaining = str(datetime.timedelta(seconds=remaining)) 44 | logging.info(f'[!] 
progress: {round(pbar.n/pbar.total, 5)}; remaining time: {remaining}; loss: {round(loss.item(), 4)}; token_acc: {round(mle_acc*100, 2)}') 45 | 46 | mle_acc *= 100 47 | return mle_acc 48 | 49 | def save_model(self, path, current_step): 50 | # only save trainable model parameters 51 | param_grad_dic = { 52 | k: v.requires_grad for (k, v) in self.ds_engine.module.named_parameters() 53 | } 54 | state_dict = self.ds_engine.module.state_dict() 55 | checkpoint = OrderedDict() 56 | for k, v in self.ds_engine.module.named_parameters(): 57 | if v.requires_grad: 58 | checkpoint[k] = v 59 | torch.save(checkpoint, f'{path}/pytorch_model.pt') 60 | # save tokenizer 61 | self.model.llama_tokenizer.save_pretrained(path) 62 | # save configuration 63 | self.model.llama_model.config.save_pretrained(path) 64 | print(f'[!] save model into {path}') 65 | 66 | def load_stage_1_parameters(self, path): 67 | delta_ckpt = torch.load(path, map_location=torch.device('cpu')) 68 | self.model.load_state_dict(delta_ckpt, strict=False) 69 | -------------------------------------------------------------------------------- /pandagpt/code/scripts/train.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | deepspeed --include localhost:0,1,2,3,4,5,6,7 --master_addr 127.0.0.1 --master_port 28457 train_sft.py \ 4 | --model openllama_peft \ 5 | --stage 1\ 6 | --data_path ../data/pandagpt4_visual_instruction_data.json\ 7 | --image_root_path ../data/images/\ 8 | --imagebind_ckpt_path ../pretrained_ckpt/imagebind_ckpt/\ 9 | --vicuna_ckpt_path ../pretrained_ckpt/vicuna_ckpt/13b_v0/\ 10 | --max_tgt_len 400\ 11 | --save_path ./ckpt/pandagpt_13b_v0_peft/\ 12 | --log_path ./ckpt/pandagpt_13b_v0_peft/log_rest/ 13 | -------------------------------------------------------------------------------- /pandagpt/code/train_sft.py: -------------------------------------------------------------------------------- 1 | from header import * 2 | from datasets import * 3 | from model import * 4 | from config import * 5 | 6 | def parser_args(): 7 | parser = argparse.ArgumentParser(description='train parameters') 8 | parser.add_argument('--model', type=str) 9 | parser.add_argument('--data_path', type=str) 10 | parser.add_argument('--local_rank', default=0, type=int) 11 | parser.add_argument('--save_path', type=str) 12 | parser.add_argument('--log_path', type=str) 13 | # model configurations 14 | parser.add_argument('--image_root_path', type=str) # the directory that stores all images 15 | parser.add_argument('--imagebind_ckpt_path', type=str) # the path that stores the imagebind checkpoint 16 | parser.add_argument('--vicuna_ckpt_path', type=str) # the path that stores the vicuna checkpoint 17 | parser.add_argument('--delta_ckpt_path', type=str) # the delta parameters trained in stage 1 18 | parser.add_argument('--max_tgt_len', type=int) # the maximum sequence length 19 | parser.add_argument('--stage', type=int) # the maximum sequence length 20 | return parser.parse_args() 21 | 22 | def initialize_distributed(args): 23 | args['master_ip'] = os.getenv('MASTER_ADDR', 'localhost') 24 | args['master_port'] = os.getenv('MASTER_PORT', '6000') 25 | args['world_size'] = int(os.getenv('WORLD_SIZE', '1')) 26 | args['local_rank'] = int(os.getenv('RANK', '0')) % torch.cuda.device_count() 27 | device = args['local_rank'] % torch.cuda.device_count() 28 | torch.cuda.set_device(device) 29 | deepspeed.init_distributed(dist_backend='nccl') 30 | 31 | def set_random_seed(seed): 32 | if seed is not None and seed > 0: 33 | 
random.seed(seed) 34 | np.random.seed(seed) 35 | torch.manual_seed(seed) 36 | torch.random.manual_seed(seed) 37 | torch.cuda.manual_seed(seed) 38 | torch.cuda.manual_seed_all(seed) 39 | 40 | def config_env(args): 41 | args['root_dir'] = '../' 42 | args['mode'] = 'train' 43 | config = load_config(args) 44 | args.update(config) 45 | initialize_distributed(args) 46 | set_random_seed(args['seed']) 47 | 48 | def build_directory(path): 49 | if os.path.exists(path): 50 | pass 51 | else: # recursively construct directory 52 | os.makedirs(path, exist_ok=True) 53 | 54 | def main(**args): 55 | config_env(args) 56 | args['ds_config_path'] = f'dsconfig/{args["model"]}_stage_{args["stage"]}.json' 57 | dschf = HfDeepSpeedConfig(args['ds_config_path']) 58 | args['dschf'] = dschf 59 | 60 | build_directory(args['save_path']) 61 | build_directory(args['log_path']) 62 | 63 | if args['log_path']: 64 | logging.basicConfig( 65 | format='%(asctime)s - %(pathname)s[line:%(lineno)d] - %(levelname)s: %(message)s', 66 | level=logging.DEBUG, 67 | filename=f'{args["log_path"]}/train_{time.asctime()}.log', 68 | filemode='w' 69 | ) 70 | 71 | train_data, train_iter, sampler = load_sft_dataset(args) 72 | 73 | length = args['epochs'] * len(train_data) // args['world_size'] // dschf.config['train_micro_batch_size_per_gpu'] 74 | total_steps = args['epochs'] * len(train_data) // dschf.config['train_batch_size'] 75 | args['total_steps'] = total_steps 76 | agent = load_model(args) 77 | torch.distributed.barrier() 78 | 79 | # begin to train 80 | pbar = tqdm(total=length) # maximum total number 81 | current_step = 0 82 | for epoch_i in tqdm(range(args['epochs'])): 83 | for batch in train_iter: 84 | agent.train_model( 85 | batch, 86 | current_step=current_step, 87 | pbar=pbar 88 | ) 89 | current_step += 1 90 | # save at the end of the training 91 | torch.distributed.barrier() 92 | agent.save_model(args['save_path'], 0) 93 | 94 | if __name__ == "__main__": 95 | args = parser_args() 96 | args = vars(args) 97 | main(**args) 98 | -------------------------------------------------------------------------------- /pandagpt/data/empty.txt: -------------------------------------------------------------------------------- 1 | empty text -------------------------------------------------------------------------------- /pandagpt/pretrained_ckpt/README.md: -------------------------------------------------------------------------------- 1 | # 1. Prepare Vicuna Checkpoint: 2 | 3 | The language decoder of PandaGPT is based on Vicuna version 0. Given the distribution license of LLaMA, you need to restore the weights of Vicuna manually. To restore the weights, please follow the instructions below. In the following, we showcase how to restore the 7B version of Vicuna v0. To obtain the 13B version of Vicuna, you can take similar procedures. 4 | 5 | ## 1.1. Obtain LLaMA Weights: 6 | * Request the weights of LLaMA from Meta using [this form](https://docs.google.com/forms/d/e/1FAIpQLSfqNECQnMkycAp2jP4Z9TFX0cGR4uf7b_fBxjY_OjhJILlKGA/viewform). 7 | * After obtaining the weights of a specific LLaMA (e.g. 7B, 13B), following [instructions](https://huggingface.co/docs/transformers/main/model_doc/llama) provided by Huggingface to convert it into Huggingface format. 8 | 9 | > **** After conversion, the directory should look like: 10 | 11 | . 
12 | └── ./{path_to_llama_weights}/ 13 | ├── config.json 14 | ├── generation_config.json 15 | ├── pytorch_model-00001-of-00002.bin 16 | ├── pytorch_model-00002-of-00002.bin 17 | ├── pytorch_model.bin.index.json 18 | ├── special_tokens_map.json 19 | ├── tokenizer.model 20 | └── tokenizer_config.json 21 | 22 | `{path_to_llama_weights}` is where you store the checkpoints. 23 | 24 | 25 | ## 1.2. Obtain the Delta Weights of Vicuna: 26 | 27 | Then, you should download the delta weights of Vicuna provided by the original authors. You can find the corresponding links to 7B/13B Vicuna models in the table below. 28 | 29 | |**Model Size**|**Delta Weights Address**|**Version**| 30 | |:-------------:|:-------------:|:-------------:| 31 | |7B|[[Link]](https://huggingface.co/lmsys/vicuna-7b-delta-v0)|0| 32 | |13B|[[Link]](https://huggingface.co/lmsys/vicuna-13b-delta-v0)|0| 33 | 34 | 35 | 36 | > **** After conversion, the directory should look like: 37 | 38 | . 39 | └── ./{path_to_delta_vicuna_weights}/ 40 | ├── config.json 41 | ├── generation_config.json 42 | ├── pytorch_model-00001-of-00002.bin 43 | ├── pytorch_model-00002-of-00002.bin 44 | ├── pytorch_model.bin.index.json 45 | ├── special_tokens_map.json 46 | ├── tokenizer.model 47 | └── tokenizer_config.json 48 | 49 | `{path_to_delta_vicuna_weights}` is where you store the delta weights of Vicuna. 50 | 51 | ## 1.3. Combine the Weights: 52 | 53 | When the two sets of weights are ready, you can combine them using tools from the Vicuna team. 54 | 55 | First, install the required library. 56 | ```yaml 57 | pip install git+https://github.com/lm-sys/FastChat.git@v0.1.10 58 | ``` 59 | 60 | Then, run the following command. 61 | ```yaml 62 | python -m fastchat.model.apply_delta --base {path_to_llama_weights} --target ./vicuna_ckpt/7b_v0/ --delta {path_to_delta_vicuna_weights} 63 | ``` 64 | 65 | > **** Now, the final weights are ready as: 66 | 67 | . 
68 | └── ./vicuna_ckpt/7b_v0/ 69 | ├── config.json 70 | ├── generation_config.json 71 | ├── pytorch_model-00001-of-00002.bin 72 | ├── pytorch_model-00002-of-00002.bin 73 | ├── pytorch_model.bin.index.json 74 | ├── special_tokens_map.json 75 | ├── tokenizer.model 76 | └── tokenizer_config.json 77 | 78 | 79 | -------------------------------------------------------------------------------- /pandagpt/pretrained_ckpt/imagebind_ckpt/empty.txt: -------------------------------------------------------------------------------- 1 | empty placeholder -------------------------------------------------------------------------------- /pandagpt/pretrained_ckpt/pandagpt_ckpt/13b/empty.txt: -------------------------------------------------------------------------------- 1 | empty placeholder -------------------------------------------------------------------------------- /pandagpt/pretrained_ckpt/pandagpt_ckpt/7b/empty.txt: -------------------------------------------------------------------------------- 1 | empty placeholder -------------------------------------------------------------------------------- /pandagpt/pretrained_ckpt/vicuna_ckpt/13b_v0/empty.txt: -------------------------------------------------------------------------------- 1 | empty placeholder -------------------------------------------------------------------------------- /pandagpt/pretrained_ckpt/vicuna_ckpt/7b_v0/empty.txt: -------------------------------------------------------------------------------- 1 | empty placeholder -------------------------------------------------------------------------------- /pandagpt/requirements.txt: -------------------------------------------------------------------------------- 1 | timm==0.6.7 2 | deepspeed==0.9.2 3 | data 4 | einops==0.6.1 5 | ftfy==6.1.1 6 | iopath==0.1.10 7 | ipdb==0.13.13 8 | numpy==1.24.3 9 | peft==0.3.0 10 | Pillow==9.5.0 11 | PyYAML==6.0 12 | regex==2022.10.31 13 | torchvision==0.14.1 14 | torchaudio==0.13.1 15 | pytorchvideo 16 | fvcore 17 | decord==0.6.0 18 | tqdm==4.64.1 19 | transformers==4.29.1 20 | jupyter 21 | sentencepiece -------------------------------------------------------------------------------- /result_audios/.gitkeep: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ebagdasa/multimodal_injection/be8f118c06d7783e4f458046947d1271e7b435d6/result_audios/.gitkeep -------------------------------------------------------------------------------- /result_audios/bird_malicious.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ebagdasa/multimodal_injection/be8f118c06d7783e4f458046947d1271e7b435d6/result_audios/bird_malicious.pt -------------------------------------------------------------------------------- /result_audios/bird_malicious.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ebagdasa/multimodal_injection/be8f118c06d7783e4f458046947d1271e7b435d6/result_audios/bird_malicious.wav -------------------------------------------------------------------------------- /result_audios/panda-italy-baseline.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ebagdasa/multimodal_injection/be8f118c06d7783e4f458046947d1271e7b435d6/result_audios/panda-italy-baseline.png -------------------------------------------------------------------------------- /result_audios/panda-italy.png: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/ebagdasa/multimodal_injection/be8f118c06d7783e4f458046947d1271e7b435d6/result_audios/panda-italy.png -------------------------------------------------------------------------------- /result_images/.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ebagdasa/multimodal_injection/be8f118c06d7783e4f458046947d1271e7b435d6/result_images/.DS_Store -------------------------------------------------------------------------------- /result_images/llava-baby-baseline.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ebagdasa/multimodal_injection/be8f118c06d7783e4f458046947d1271e7b435d6/result_images/llava-baby-baseline.png -------------------------------------------------------------------------------- /result_images/llava-crying-baby.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ebagdasa/multimodal_injection/be8f118c06d7783e4f458046947d1271e7b435d6/result_images/llava-crying-baby.png -------------------------------------------------------------------------------- /result_images/llava-pirate.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ebagdasa/multimodal_injection/be8f118c06d7783e4f458046947d1271e7b435d6/result_images/llava-pirate.png -------------------------------------------------------------------------------- /result_images/llava-potter.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ebagdasa/multimodal_injection/be8f118c06d7783e4f458046947d1271e7b435d6/result_images/llava-potter.png -------------------------------------------------------------------------------- /result_images/llava/harrypotter_partial.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ebagdasa/multimodal_injection/be8f118c06d7783e4f458046947d1271e7b435d6/result_images/llava/harrypotter_partial.png -------------------------------------------------------------------------------- /result_images/llava/harrypotter_partial.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ebagdasa/multimodal_injection/be8f118c06d7783e4f458046947d1271e7b435d6/result_images/llava/harrypotter_partial.pt -------------------------------------------------------------------------------- /result_images/llava/perturb_full_X.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ebagdasa/multimodal_injection/be8f118c06d7783e4f458046947d1271e7b435d6/result_images/llava/perturb_full_X.jpg -------------------------------------------------------------------------------- /result_images/llava/perturb_full_X.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ebagdasa/multimodal_injection/be8f118c06d7783e4f458046947d1271e7b435d6/result_images/llava/perturb_full_X.pt -------------------------------------------------------------------------------- /result_images/llava/perturb_partial_X.jpg: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/ebagdasa/multimodal_injection/be8f118c06d7783e4f458046947d1271e7b435d6/result_images/llava/perturb_partial_X.jpg -------------------------------------------------------------------------------- /result_images/llava/perturb_partial_X.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ebagdasa/multimodal_injection/be8f118c06d7783e4f458046947d1271e7b435d6/result_images/llava/perturb_partial_X.pt -------------------------------------------------------------------------------- /result_images/panda-audio-phishing.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ebagdasa/multimodal_injection/be8f118c06d7783e4f458046947d1271e7b435d6/result_images/panda-audio-phishing.png -------------------------------------------------------------------------------- /result_images/pandagpt/panda_cow_full.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ebagdasa/multimodal_injection/be8f118c06d7783e4f458046947d1271e7b435d6/result_images/pandagpt/panda_cow_full.jpg -------------------------------------------------------------------------------- /result_images/pandagpt/panda_cow_full.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ebagdasa/multimodal_injection/be8f118c06d7783e4f458046947d1271e7b435d6/result_images/pandagpt/panda_cow_full.pt -------------------------------------------------------------------------------- /result_images/pandagpt/panda_cow_partial.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ebagdasa/multimodal_injection/be8f118c06d7783e4f458046947d1271e7b435d6/result_images/pandagpt/panda_cow_partial.jpg -------------------------------------------------------------------------------- /result_images/pandagpt/panda_cow_partial.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ebagdasa/multimodal_injection/be8f118c06d7783e4f458046947d1271e7b435d6/result_images/pandagpt/panda_cow_partial.pt --------------------------------------------------------------------------------