├── .gitignore ├── LICENSE ├── README.md ├── app.py ├── config.py ├── dockerignore.txt ├── editorconfig.txt ├── gitattributes.txt ├── gitignore.txt ├── images ├── Framework.png ├── HealthGPT.png ├── chatUI.jpg └── intro.png ├── llava ├── __init__.py ├── __pycache__ │ ├── __init__.cpython-310.pyc │ ├── constants.cpython-310.pyc │ ├── conversation.cpython-310.pyc │ └── mm_utils.cpython-310.pyc ├── constants.py ├── conversation.py ├── demo │ ├── __init__.py │ ├── __pycache__ │ │ └── utils.cpython-310.pyc │ ├── com_infer.py │ ├── com_infer.sh │ ├── com_infer_phi4.py │ ├── com_infer_phi4.sh │ ├── com_infer_qwen2_5.py │ ├── com_infer_qwen2_5.sh │ ├── gen_infer.py │ ├── gen_infer.sh │ └── utils.py ├── eval │ ├── eval_gpt_review.py │ ├── eval_gpt_review_bench.py │ ├── eval_gpt_review_visual.py │ ├── eval_pope.py │ ├── eval_science_qa.py │ ├── eval_science_qa_gpt4.py │ ├── eval_science_qa_gpt4_requery.py │ ├── eval_textvqa.py │ ├── generate_webpage_data_from_table.py │ ├── m4c_evaluator.py │ ├── model_qa.py │ ├── model_vqa.py │ ├── model_vqa_loader.py │ ├── model_vqa_mmbench.py │ ├── model_vqa_science.py │ ├── qa_baseline_gpt35.py │ ├── run_llava.py │ ├── summarize_gpt_review.py │ ├── table │ │ ├── answer │ │ │ ├── answer_alpaca-13b.jsonl │ │ │ ├── answer_bard.jsonl │ │ │ ├── answer_gpt35.jsonl │ │ │ ├── answer_llama-13b.jsonl │ │ │ └── answer_vicuna-13b.jsonl │ │ ├── caps_boxes_coco2014_val_80.jsonl │ │ ├── model.jsonl │ │ ├── prompt.jsonl │ │ ├── question.jsonl │ │ ├── results │ │ │ ├── test_sqa_llava_13b_v0.json │ │ │ └── test_sqa_llava_lcs_558k_sqa_12e_vicuna_v1_3_13b.json │ │ ├── review │ │ │ ├── review_alpaca-13b_vicuna-13b.jsonl │ │ │ ├── review_bard_vicuna-13b.jsonl │ │ │ ├── review_gpt35_vicuna-13b.jsonl │ │ │ └── review_llama-13b_vicuna-13b.jsonl │ │ ├── reviewer.jsonl │ │ └── rule.json │ └── webpage │ │ ├── figures │ │ ├── alpaca.png │ │ ├── bard.jpg │ │ ├── chatgpt.svg │ │ ├── llama.jpg │ │ ├── swords_FILL0_wght300_GRAD0_opsz48.svg │ │ └── vicuna.jpeg │ │ ├── 
index.html │ │ ├── script.js │ │ └── styles.css ├── mm_utils.py ├── model │ ├── __init__.py │ ├── __pycache__ │ │ ├── __init__.cpython-310.pyc │ │ └── llava_arch.cpython-310.pyc │ ├── apply_delta.py │ ├── builder.py │ ├── consolidate.py │ ├── language_model │ │ ├── __pycache__ │ │ │ ├── llava_llama.cpython-310.pyc │ │ │ ├── llava_mistral.cpython-310.pyc │ │ │ ├── llava_mpt.cpython-310.pyc │ │ │ └── llava_phi3.cpython-310.pyc │ │ ├── llava_llama.py │ │ ├── llava_mistral.py │ │ ├── llava_mpt.py │ │ ├── llava_phi3.py │ │ └── llava_qwen.py │ ├── llava_arch.py │ ├── make_delta.py │ ├── multimodal_encoder │ │ ├── __pycache__ │ │ │ ├── builder.cpython-310.pyc │ │ │ └── clip_encoder.cpython-310.pyc │ │ ├── builder.py │ │ └── clip_encoder.py │ ├── multimodal_projector │ │ ├── __pycache__ │ │ │ └── builder.cpython-310.pyc │ │ └── builder.py │ └── utils.py ├── peft │ ├── __init__.py │ ├── __pycache__ │ │ ├── __init__.cpython-310.pyc │ │ ├── __init__.cpython-39.pyc │ │ ├── mapping.cpython-310.pyc │ │ ├── mapping.cpython-39.pyc │ │ ├── peft_model.cpython-310.pyc │ │ └── peft_model.cpython-39.pyc │ ├── mapping.py │ ├── peft_model.py │ ├── tuners │ │ ├── __init__.py │ │ ├── __pycache__ │ │ │ ├── __init__.cpython-310.pyc │ │ │ ├── __init__.cpython-39.pyc │ │ │ ├── lora.cpython-310.pyc │ │ │ ├── lora.cpython-310.pyc.139831930933664 │ │ │ ├── lora.cpython-310.pyc.140023258402944 │ │ │ ├── lora.cpython-310.pyc.140104265109632 │ │ │ ├── lora.cpython-310.pyc.140108281172096 │ │ │ ├── lora.cpython-310.pyc.140160632344704 │ │ │ ├── lora.cpython-310.pyc.140480966316160 │ │ │ ├── lora.cpython-310.pyc.140577961184688 │ │ │ ├── lora.cpython-39.pyc │ │ │ ├── lora_moe.cpython-310.pyc │ │ │ ├── p_tuning.cpython-310.pyc │ │ │ ├── p_tuning.cpython-39.pyc │ │ │ ├── prefix_tuning.cpython-310.pyc │ │ │ ├── prefix_tuning.cpython-39.pyc │ │ │ ├── prompt_tuning.cpython-310.pyc │ │ │ └── prompt_tuning.cpython-39.pyc │ │ ├── lora.py │ │ ├── p_tuning.py │ │ ├── prefix_tuning.py │ │ └── prompt_tuning.py │ 
└── utils │ │ ├── __init__.py │ │ ├── __pycache__ │ │ ├── __init__.cpython-310.pyc │ │ ├── __init__.cpython-39.pyc │ │ ├── adapters_utils.cpython-310.pyc │ │ ├── adapters_utils.cpython-39.pyc │ │ ├── config.cpython-310.pyc │ │ ├── config.cpython-39.pyc │ │ ├── other.cpython-310.pyc │ │ ├── other.cpython-39.pyc │ │ ├── save_and_load.cpython-310.pyc │ │ └── save_and_load.cpython-39.pyc │ │ ├── adapters_utils.py │ │ ├── config.py │ │ ├── other.py │ │ └── save_and_load.py ├── serve │ ├── __init__.py │ ├── cli.py │ ├── controller.py │ ├── examples │ │ ├── extreme_ironing.jpg │ │ └── waterview.jpg │ ├── gradio_web_server.py │ ├── model_worker.py │ ├── register_worker.py │ ├── sglang_worker.py │ └── test_message.py └── utils.py ├── model.py ├── requirements.txt ├── requirements_qwen_2_5.txt ├── scripts ├── convert_gqa_for_eval.py ├── convert_mmbench_for_submission.py ├── convert_mmvet_for_eval.py ├── convert_seed_for_submission.py ├── convert_sqa_to_llava.py ├── convert_sqa_to_llava_base_prompt.py ├── convert_vizwiz_for_submission.py ├── convert_vqav2_for_submission.py ├── extract_mm_projector.py ├── finetune.sh ├── finetune_full_schedule.sh ├── finetune_lora.sh ├── finetune_qlora.sh ├── finetune_sqa.sh ├── merge_lora_weights.py ├── pretrain.sh ├── pretrain_xformers.sh ├── sqa_eval_batch.sh ├── sqa_eval_gather.sh ├── upload_pypi.sh ├── v1_5 │ ├── eval │ │ ├── gqa.sh │ │ ├── llavabench.sh │ │ ├── mmbench.sh │ │ ├── mmbench_cn.sh │ │ ├── mme.sh │ │ ├── mmvet.sh │ │ ├── pope.sh │ │ ├── qbench.sh │ │ ├── qbench_zh.sh │ │ ├── seed.sh │ │ ├── sqa.sh │ │ ├── textvqa.sh │ │ ├── vizwiz.sh │ │ └── vqav2.sh │ ├── finetune.sh │ ├── finetune_lora.sh │ ├── finetune_task.sh │ ├── finetune_task_lora.sh │ └── pretrain.sh ├── zero2.json ├── zero3.json └── zero3_offload.json └── taming_transformers ├── License.txt ├── ckpt └── model.yaml ├── environment.yaml ├── idx2img.py ├── main.py ├── scripts ├── extract_depth.py ├── extract_segmentation.py ├── extract_submodel.py ├── make_samples.py 
├── make_scene_samples.py ├── sample_conditional.py └── sample_fast.py ├── setup.py └── taming ├── data ├── ade20k.py ├── annotated_objects_coco.py ├── annotated_objects_dataset.py ├── annotated_objects_open_images.py ├── base.py ├── coco.py ├── conditional_builder │ ├── objects_bbox.py │ ├── objects_center_points.py │ └── utils.py ├── custom.py ├── faceshq.py ├── helper_types.py ├── image_transforms.py ├── imagenet.py ├── open_images_helper.py ├── sflckr.py └── utils.py ├── lr_scheduler.py ├── models ├── __pycache__ │ └── vqgan.cpython-310.pyc ├── cond_transformer.py ├── dummy_cond_stage.py └── vqgan.py ├── modules ├── __pycache__ │ └── util.cpython-310.pyc ├── diffusionmodules │ ├── __pycache__ │ │ └── model.cpython-310.pyc │ └── model.py ├── discriminator │ ├── __pycache__ │ │ └── model.cpython-310.pyc │ └── model.py ├── losses │ ├── __init__.py │ ├── __pycache__ │ │ ├── __init__.cpython-310.pyc │ │ ├── lpips.cpython-310.pyc │ │ └── vqperceptual.cpython-310.pyc │ ├── lpips.py │ ├── segmentation.py │ └── vqperceptual.py ├── misc │ └── coord.py ├── transformer │ ├── mingpt.py │ └── permuter.py ├── util.py └── vqvae │ ├── __pycache__ │ └── quantize.cpython-310.pyc │ └── quantize.py └── util.py /app.py: -------------------------------------------------------------------------------- 1 | import traceback 2 | 3 | from model import HealthGPT, HealthGPT_Agent 4 | from config import HealthGPTConfig_M3_COM, HealthGPTConfig_M3_GEN, HealthGPTConfig_L14_COM 5 | 6 | configs = { 7 | "HealthGPT-M3-COM": HealthGPTConfig_M3_COM(), 8 | "HealthGPT-M3-GEN": HealthGPTConfig_M3_GEN(), 9 | "HealthGPT-L14-COM": HealthGPTConfig_L14_COM() 10 | } 11 | 12 | agent = HealthGPT_Agent(configs=configs, model_name=None) 13 | 14 | # HealthGPT interface 15 | import gradio as gr 16 | from PIL import Image, ImageDraw 17 | 18 | def process_input(option, model_name, text, image): 19 | if not text.strip(): 20 | return gr.update(value="⚠️ Please input your question.", visible=True), None, 
gr.update(visible=True), gr.update(visible=False) 21 | try: 22 | if option == "Analyze Image": 23 | model_name = model_name + "-COM" 24 | try: 25 | agent.load_model(model_name=model_name) 26 | resp = agent.process(option, text, image) 27 | except Exception as e: 28 | agent.load_model(model_name=model_name) 29 | resp = agent.process(option, text, image) 30 | return resp, None, gr.update(visible=True), gr.update(visible=False) 31 | 32 | elif option == "Generate Image": 33 | model_name = model_name + "-GEN" 34 | try: 35 | agent.load_model(model_name=model_name) 36 | resp = agent.process(option, text, image) 37 | except Exception as e: 38 | agent.load_model(model_name=model_name) 39 | resp = agent.process(option, text, image) 40 | return None, resp, gr.update(visible=False), gr.update(visible=True) 41 | except Exception as e: 42 | print(traceback.format_exc()) 43 | return gr.update(value=f"⚠️ {e.args[0] if e.args else type(e).__name__}", visible=True), None, gr.update(visible=True), gr.update(visible=False)  # e.args is empty for exceptions raised without a message; indexing it unguarded would crash this error handler 44 | 45 | 46 | with gr.Blocks() as demo: 47 | # gr.Markdown("# 🖼️ HealthGPT") 48 | gr.Markdown("

🖼️ HealthGPT

") 49 | 50 | # Option A / B 51 | with gr.Row(): 52 | option = gr.Radio(["Analyze Image", "Generate Image"], label="🔍Choose the task", value="Analyze Image", interactive=True) 53 | model_name = gr.Radio(["HealthGPT-M3", "HealthGPT-L14"], label="🧠Choose the model", value="HealthGPT-M3", interactive=True) 54 | 55 | with gr.Row(): 56 | with gr.Column(): 57 | gr.Markdown("### 🔹 Input") 58 | text_input = gr.Textbox(label="Question", placeholder="Text here...", lines=3, value="Could you explain what this mass in the MRI means for my health? Is it very serious?") 59 | image_input = gr.Image(type="pil", label="Upload an image...") 60 | 61 | with gr.Column(): 62 | gr.Markdown("### 🔹 Output") 63 | process_button = gr.Button("🚀 Process", variant="primary") 64 | text_output = gr.Textbox(label="HealthGPT Answer", visible=True, lines=20) 65 | image_output = gr.Image(label="Generated Image", visible=False) 66 | 67 | process_button.click( 68 | process_input, 69 | inputs=[option, model_name, text_input, image_input], 70 | outputs=[text_output, image_output, text_output, image_output] # 用 gr.update() 代替 bool 71 | ) 72 | 73 | gr.Markdown("""### Terms of use 74 | By using this service, users are required to agree to the following terms: 75 | The service is a research preview intended for non-commercial use only. It only provides limited safety measures and may generate offensive content. It must not be used for any illegal, harmful, violent, racist, or sexual purposes. The service may collect user dialogue data for future research. 
76 | For an optimal experience, please use desktop computers for this demo, as mobile devices may compromise its quality.""") 77 | 78 | demo.css = """footer {display: none !important;}""" 79 | 80 | # Start Gradio website 81 | demo.launch(server_name="0.0.0.0", server_port=5011, show_api=False) 82 | -------------------------------------------------------------------------------- /config.py: -------------------------------------------------------------------------------- 1 | # HealthGPT config 2 | class HealthGPTConfig_M3_COM: 3 | model_name_or_path = "./Phi-3-mini-4k-instruct" 4 | dtype = "FP16" 5 | attn_implementation = None 6 | hlora_r = 64 7 | hlora_alpha = 128 8 | hlora_dropout = 0.0 9 | hlora_nums = 4 10 | vq_idx_nums = 8192 11 | instruct_template = "phi3_instruct" 12 | vit_path = "./clip-vit-large-patch14-336/" 13 | hlora_path = "./HealthGPT-M3/com_hlora_weights.bin" 14 | fusion_layer_path = None 15 | do_sample = False 16 | temperature = 0.0 17 | top_p = None 18 | num_beams = 1 19 | max_new_tokens = 2048 20 | task_type = "comprehension" 21 | 22 | 23 | class HealthGPTConfig_M3_GEN: 24 | model_name_or_path = "./Phi-3-mini-4k-instruct" 25 | dtype = "FP16" 26 | attn_implementation = None 27 | hlora_r = 256 28 | hlora_alpha = 512 29 | hlora_dropout = 0.0 30 | hlora_nums = 4 31 | vq_idx_nums = 8192 32 | instruct_template = "phi3_instruct" 33 | vit_path = "./clip-vit-large-patch14-336/" 34 | hlora_path = "./HealthGPT-M3/gen_hlora_weights.bin" 35 | fusion_layer_path = "./HealthGPT-M3/fusion_layer_weights.bin" 36 | do_sample = False 37 | temperature = 0.0 38 | top_p = None 39 | num_beams = 1 40 | max_new_tokens = 2048 41 | save_path = "output.png" 42 | task_type = "generation" 43 | 44 | 45 | class HealthGPTConfig_L14_COM: 46 | model_name_or_path = "./phi-4" 47 | dtype = "FP16" 48 | attn_implementation = None 49 | hlora_r = 32 50 | hlora_alpha = 64 51 | hlora_dropout = 0.0 52 | hlora_nums = 4 53 | vq_idx_nums = 8192 54 | instruct_template = "phi4_instruct" 55 | vit_path 
= "./clip-vit-large-patch14-336/" 56 | hlora_path = "./HealthGPT-L14/com_hlora_weights_phi4.bin" 57 | fusion_layer_path = None 58 | do_sample = False 59 | temperature = 0.0 60 | top_p = None 61 | num_beams = 1 62 | max_new_tokens = 2048 63 | task_type = "comprehension" 64 | -------------------------------------------------------------------------------- /dockerignore.txt: -------------------------------------------------------------------------------- 1 | # The .dockerignore file excludes files from the container build process. 2 | # 3 | # https://docs.docker.com/engine/reference/builder/#dockerignore-file 4 | 5 | # Exclude Git files 6 | .git 7 | .github 8 | .gitignore 9 | 10 | # Exclude Python cache files 11 | __pycache__ 12 | .mypy_cache 13 | .pytest_cache 14 | .ruff_cache 15 | 16 | # Exclude Python virtual environment 17 | /venv 18 | 19 | # Exclude some weights 20 | /openai 21 | /liuhaotian 22 | -------------------------------------------------------------------------------- /editorconfig.txt: -------------------------------------------------------------------------------- 1 | root = true 2 | 3 | # Unix-style newlines with a newline ending every file 4 | [*] 5 | end_of_line = lf 6 | insert_final_newline = true 7 | trim_trailing_whitespace = true 8 | charset = utf-8 9 | 10 | # 4 space indentation 11 | [*.{py,json}] 12 | indent_style = space 13 | indent_size = 4 14 | 15 | # 2 space indentation 16 | [*.{md,sh,yaml,yml}] 17 | indent_style = space 18 | indent_size = 2 -------------------------------------------------------------------------------- /gitattributes.txt: -------------------------------------------------------------------------------- 1 | # https://git-scm.com/docs/gitattributes 2 | 3 | # Set the default behavior, in case people don't have core.autocrlf set. 
4 | # https://git-scm.com/docs/gitattributes#_end_of_line_conversion 5 | * text=auto 6 | 7 | # common python attributes, taken from https://github.com/alexkaratarakis/gitattributes/blob/710900479a2bedeec7003d381719521ffbb18bf8/Python.gitattributes 8 | # Source files 9 | # ============ 10 | *.pxd text diff=python 11 | *.py text diff=python 12 | *.py3 text diff=python 13 | *.pyw text diff=python 14 | *.pyx text diff=python 15 | *.pyz text diff=python 16 | *.pyi text diff=python 17 | 18 | # Binary files 19 | # ============ 20 | *.db binary 21 | *.p binary 22 | *.pkl binary 23 | *.pickle binary 24 | *.pyc binary export-ignore 25 | *.pyo binary export-ignore 26 | *.pyd binary 27 | 28 | # Jupyter notebook 29 | *.ipynb text eol=lf 30 | -------------------------------------------------------------------------------- /gitignore.txt: -------------------------------------------------------------------------------- 1 | # Python 2 | __pycache__ 3 | *.pyc 4 | *.egg-info 5 | dist 6 | 7 | # Log 8 | *.log 9 | *.log.* 10 | *.json 11 | *.jsonl 12 | 13 | # Data 14 | !**/alpaca-data-conversation.json 15 | 16 | # Editor 17 | .idea 18 | *.swp 19 | 20 | # Other 21 | .DS_Store 22 | wandb 23 | output 24 | 25 | checkpoints 26 | ckpts* 27 | 28 | .ipynb_checkpoints 29 | *.ipynb 30 | 31 | # DevContainer 32 | !.devcontainer/* 33 | 34 | # Demo 35 | serve_images/ 36 | -------------------------------------------------------------------------------- /images/Framework.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DCDmllm/HealthGPT/c044a13254b76c5eec8c2e6c55e3324318c27940/images/Framework.png -------------------------------------------------------------------------------- /images/HealthGPT.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DCDmllm/HealthGPT/c044a13254b76c5eec8c2e6c55e3324318c27940/images/HealthGPT.png 
-------------------------------------------------------------------------------- /images/chatUI.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DCDmllm/HealthGPT/c044a13254b76c5eec8c2e6c55e3324318c27940/images/chatUI.jpg -------------------------------------------------------------------------------- /images/intro.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DCDmllm/HealthGPT/c044a13254b76c5eec8c2e6c55e3324318c27940/images/intro.png -------------------------------------------------------------------------------- /llava/__init__.py: -------------------------------------------------------------------------------- 1 | from .model import LlavaPhiForCausalLM -------------------------------------------------------------------------------- /llava/__pycache__/__init__.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DCDmllm/HealthGPT/c044a13254b76c5eec8c2e6c55e3324318c27940/llava/__pycache__/__init__.cpython-310.pyc -------------------------------------------------------------------------------- /llava/__pycache__/constants.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DCDmllm/HealthGPT/c044a13254b76c5eec8c2e6c55e3324318c27940/llava/__pycache__/constants.cpython-310.pyc -------------------------------------------------------------------------------- /llava/__pycache__/conversation.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DCDmllm/HealthGPT/c044a13254b76c5eec8c2e6c55e3324318c27940/llava/__pycache__/conversation.cpython-310.pyc -------------------------------------------------------------------------------- /llava/__pycache__/mm_utils.cpython-310.pyc: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/DCDmllm/HealthGPT/c044a13254b76c5eec8c2e6c55e3324318c27940/llava/__pycache__/mm_utils.cpython-310.pyc -------------------------------------------------------------------------------- /llava/constants.py: -------------------------------------------------------------------------------- 1 | CONTROLLER_HEART_BEAT_EXPIRATION = 30 2 | WORKER_HEART_BEAT_INTERVAL = 15 3 | 4 | LOGDIR = "." 5 | 6 | # Model Constants 7 | IGNORE_INDEX = -100 8 | IMAGE_TOKEN_INDEX = -200 9 | DEFAULT_IMAGE_TOKEN = "<image>" 10 | DEFAULT_IMAGE_PATCH_TOKEN = "<im_patch>" 11 | DEFAULT_IM_START_TOKEN = "<im_start>" 12 | DEFAULT_IM_END_TOKEN = "<im_end>" 13 | IMAGE_PLACEHOLDER = "<image-placeholder>" 14 | -------------------------------------------------------------------------------- /llava/demo/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DCDmllm/HealthGPT/c044a13254b76c5eec8c2e6c55e3324318c27940/llava/demo/__init__.py -------------------------------------------------------------------------------- /llava/demo/__pycache__/utils.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DCDmllm/HealthGPT/c044a13254b76c5eec8c2e6c55e3324318c27940/llava/demo/__pycache__/utils.cpython-310.pyc -------------------------------------------------------------------------------- /llava/demo/com_infer.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | MODEL_NAME_OR_PATH="microsoft/Phi-3-mini-4k-instruct" 4 | VIT_PATH="openai/clip-vit-large-patch14-336/" 5 | HLORA_PATH="com_hlora_weights.bin" 6 | FUSION_LAYER_PATH="fusion_layer_weights.bin" 7 | 8 | python3 com_infer.py \ 9 | --model_name_or_path "$MODEL_NAME_OR_PATH" \ 10 | --dtype "FP16" \ 11 | --hlora_r "64" \ 12 | --hlora_alpha "128" \ 13 | --hlora_nums "4" \ 14 | --vq_idx_nums "8192" \ 15
| --instruct_template "phi3_instruct" \ 16 | --vit_path "$VIT_PATH" \ 17 | --hlora_path "$HLORA_PATH" \ 18 | --fusion_layer_path "$FUSION_LAYER_PATH" \ 19 | --question "Your question" \ 20 | --img_path "path/to/image.jpg" \ 21 | -------------------------------------------------------------------------------- /llava/demo/com_infer_phi4.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | MODEL_NAME_OR_PATH="microsoft/Phi-4" 4 | VIT_PATH="openai/clip-vit-large-patch14-336/" 5 | HLORA_PATH="com_hlora_weights_phi4.bin" 6 | 7 | python3 com_infer_phi4.py \ 8 | --model_name_or_path "$MODEL_NAME_OR_PATH" \ 9 | --dtype "FP16" \ 10 | --hlora_r "32" \ 11 | --hlora_alpha "64" \ 12 | --hlora_nums "4" \ 13 | --vq_idx_nums "8192" \ 14 | --instruct_template "phi4_instruct" \ 15 | --vit_path "$VIT_PATH" \ 16 | --hlora_path "$HLORA_PATH" \ 17 | --question "Your question" \ 18 | --img_path "path/to/image.jpg" 19 | -------------------------------------------------------------------------------- /llava/demo/com_infer_qwen2_5.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | MODEL_NAME_OR_PATH="Qwen/Qwen2.5-32B-Instruct" 4 | VIT_PATH="openai/clip-vit-large-patch14-336/" 5 | HLORA_PATH="com_hlora_weights_QWEN_32B.bin" 6 | 7 | python3 com_infer_qwen2_5.py \ 8 | --model_name_or_path "$MODEL_NAME_OR_PATH" \ 9 | --dtype "FP16" \ 10 | --hlora_r "32" \ 11 | --hlora_alpha "64" \ 12 | --hlora_nums "4" \ 13 | --vq_idx_nums "8192" \ 14 | --instruct_template "qwen_2" \ 15 | --vit_path "$VIT_PATH" \ 16 | --hlora_path "$HLORA_PATH" \ 17 | --question "Your question" \ 18 | --img_path "path/to/image.jpg" 19 | -------------------------------------------------------------------------------- /llava/demo/gen_infer.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | MODEL_NAME_OR_PATH="microsoft/Phi-3-mini-4k-instruct" 4 | 
VIT_PATH="openai/clip-vit-large-patch14-336/" 5 | HLORA_PATH="gen_hlora_weights.bin" 6 | FUSION_LAYER_PATH="fusion_layer_weights.bin" 7 | 8 | python3 gen_infer.py \ 9 | --model_name_or_path "$MODEL_NAME_OR_PATH" \ 10 | --dtype "FP16" \ 11 | --hlora_r "256" \ 12 | --hlora_alpha "512" \ 13 | --hlora_nums "4" \ 14 | --vq_idx_nums "8192" \ 15 | --instruct_template "phi3_instruct" \ 16 | --vit_path "$VIT_PATH" \ 17 | --hlora_path "$HLORA_PATH" \ 18 | --fusion_layer_path "$FUSION_LAYER_PATH" \ 19 | --question "Reconstruct the image." \ 20 | --img_path "path/to/image.jpg" \ 21 | --save_path "path/to/save.jpg" 22 | -------------------------------------------------------------------------------- /llava/demo/utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import transformers 3 | import tokenizers 4 | import os, sys 5 | from dataclasses import dataclass, field 6 | import argparse 7 | from PIL import Image 8 | 9 | def expand2square(pil_img, background_color): 10 | width, height = pil_img.size 11 | if width == height: 12 | return pil_img 13 | elif width > height: 14 | result = Image.new(pil_img.mode, (width, width), background_color) 15 | result.paste(pil_img, (0, (width - height) // 2)) 16 | return result 17 | else: 18 | result = Image.new(pil_img.mode, (height, height), background_color) 19 | result.paste(pil_img, ((height - width) // 2, 0)) 20 | return result 21 | 22 | def find_all_linear_names(model): 23 | cls = torch.nn.Linear 24 | lora_module_names = set() 25 | multimodal_keywords = ['mm_projector', 'vision_tower', 'vision_resampler'] 26 | for name, module in model.named_modules(): 27 | if any(mm_keyword in name for mm_keyword in multimodal_keywords): 28 | continue 29 | if isinstance(module, cls): 30 | names = name.split('.') 31 | lora_module_names.add(names[0] if len(names) == 1 else names[-1]) 32 | 33 | if 'lm_head' in lora_module_names: # needed for 16-bit 34 | lora_module_names.remove('lm_head') 35 | 
return list(lora_module_names) 36 | 37 | def add_special_tokens_and_resize_model(tokenizer, model, vq_idx_nums=8192): 38 | if len(tokenizer.additional_special_tokens) != 0: 39 | return tokenizer.additional_special_tokens 40 | index_tokens = [f"" for i in range(vq_idx_nums)] 41 | special_tokens = { 42 | 'additional_special_tokens': [''] + index_tokens + [''] + [''] 43 | } 44 | num_new_tokens = tokenizer.add_special_tokens(special_tokens) 45 | model.resize_token_embeddings(len(tokenizer)) 46 | if num_new_tokens > 0: 47 | input_embeddings = model.get_input_embeddings().weight.data 48 | output_embeddings = model.get_output_embeddings().weight.data 49 | 50 | input_embeddings_avg = input_embeddings[:-num_new_tokens].mean( 51 | dim=0, keepdim=True) 52 | output_embeddings_avg = output_embeddings[:-num_new_tokens].mean( 53 | dim=0, keepdim=True) 54 | 55 | input_embeddings[-num_new_tokens:] = input_embeddings_avg 56 | output_embeddings[-num_new_tokens:] = output_embeddings_avg 57 | 58 | return num_new_tokens 59 | 60 | com_vision_args = argparse.Namespace( 61 | freeze_backbone=False, 62 | mm_patch_merge_type='flat', 63 | mm_projector_type='mlp2x_gelu', 64 | mm_use_im_patch_token=False, 65 | mm_use_im_start_end=False, 66 | mm_vision_select_feature='patch', 67 | mm_vision_select_layer=-2, 68 | model_name_or_path=None, 69 | pretrain_mm_mlp_adapter=None, 70 | tune_mm_mlp_adapter=False, 71 | version=None, 72 | vision_tower=None 73 | ) 74 | 75 | gen_vision_args = argparse.Namespace( 76 | freeze_backbone=False, 77 | mm_patch_merge_type='flat', 78 | mm_projector_type='mlp2x_gelu', 79 | mm_use_im_patch_token=False, 80 | mm_use_im_start_end=False, 81 | mm_vision_select_feature='patch', 82 | mm_vision_select_layer=1, 83 | model_name_or_path=None, 84 | pretrain_mm_mlp_adapter=None, 85 | tune_mm_mlp_adapter=False, 86 | version=None, 87 | vision_tower=None 88 | ) 89 | 90 | def load_weights(model, hlora_path, fusion_layer_path=None): 91 | hlora_weights = torch.load(hlora_path) 92 | 
hlora_unexpected_keys = model.load_state_dict(hlora_weights, strict=False)[1] 93 | if hlora_unexpected_keys: 94 | print(f"Warning: Unexpected keys in hlora checkpoint: {hlora_unexpected_keys}") 95 | 96 | if fusion_layer_path: 97 | fusion_layer_weights = torch.load(fusion_layer_path) 98 | fusion_layer_unexpected_keys = model.load_state_dict(fusion_layer_weights, strict=False)[1] 99 | if fusion_layer_unexpected_keys: 100 | print(f"Warning: Unexpected keys in fusion_layer checkpoint: {fusion_layer_unexpected_keys}") 101 | 102 | return model 103 | 104 | -------------------------------------------------------------------------------- /llava/eval/eval_gpt_review.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import json 3 | import os 4 | 5 | import openai 6 | import tqdm 7 | import ray 8 | import time 9 | 10 | NUM_SECONDS_TO_SLEEP = 3 11 | 12 | @ray.remote(num_cpus=4) 13 | def get_eval(content: str, max_tokens: int): 14 | while True: 15 | try: 16 | response = openai.ChatCompletion.create( 17 | model='gpt-4', 18 | messages=[{ 19 | 'role': 'system', 20 | 'content': 'You are a helpful and precise assistant for checking the quality of the answer.' 
21 | }, { 22 | 'role': 'user', 23 | 'content': content, 24 | }], 25 | temperature=0.2, # TODO: figure out which temperature is best for evaluation 26 | max_tokens=max_tokens, 27 | ) 28 | break 29 | except openai.error.RateLimitError: 30 | pass 31 | except Exception as e: 32 | print(e) 33 | time.sleep(NUM_SECONDS_TO_SLEEP) 34 | 35 | print('success!') 36 | return response['choices'][0]['message']['content'] 37 | 38 | 39 | def parse_score(review): 40 | try: 41 | score_pair = review.split('\n')[0] 42 | score_pair = score_pair.replace(',', ' ') 43 | sp = score_pair.split(' ') 44 | if len(sp) == 2: 45 | return [float(sp[0]), float(sp[1])] 46 | else: 47 | print('error', review) 48 | return [-1, -1] 49 | except Exception as e: 50 | print(e) 51 | print('error', review) 52 | return [-1, -1] 53 | 54 | 55 | if __name__ == '__main__': 56 | parser = argparse.ArgumentParser(description='ChatGPT-based QA evaluation.') 57 | parser.add_argument('-q', '--question') 58 | # parser.add_argument('-a', '--answer') 59 | parser.add_argument('-a', '--answer-list', nargs='+', default=[]) 60 | parser.add_argument('-r', '--rule') 61 | parser.add_argument('-o', '--output') 62 | parser.add_argument('--max-tokens', type=int, default=1024, help='maximum number of tokens produced in the output') 63 | args = parser.parse_args() 64 | 65 | ray.init() 66 | 67 | f_q = open(os.path.expanduser(args.question)) 68 | f_ans1 = open(os.path.expanduser(args.answer_list[0])) 69 | f_ans2 = open(os.path.expanduser(args.answer_list[1])) 70 | rule_dict = json.load(open(os.path.expanduser(args.rule), 'r')) 71 | 72 | review_file = open(f'{args.output}', 'w') 73 | 74 | js_list = [] 75 | handles = [] 76 | idx = 0 77 | for ques_js, ans1_js, ans2_js in zip(f_q, f_ans1, f_ans2): 78 | # if idx == 1: 79 | # break 80 | 81 | ques = json.loads(ques_js) 82 | ans1 = json.loads(ans1_js) 83 | ans2 = json.loads(ans2_js) 84 | 85 | category = json.loads(ques_js)['category'] 86 | if category in rule_dict: 87 | rule = rule_dict[category] 
88 | else: 89 | rule = rule_dict['default'] 90 | prompt = rule['prompt'] 91 | role = rule['role'] 92 | content = (f'[Question]\n{ques["text"]}\n\n' 93 | f'[{role} 1]\n{ans1["text"]}\n\n[End of {role} 1]\n\n' 94 | f'[{role} 2]\n{ans2["text"]}\n\n[End of {role} 2]\n\n' 95 | f'[System]\n{prompt}\n\n') 96 | js_list.append({ 97 | 'id': idx+1, 98 | 'question_id': ques['question_id'], 99 | 'answer1_id': ans1['answer_id'], 100 | 'answer2_id': ans2['answer_id'], 101 | 'category': category}) 102 | idx += 1 103 | handles.append(get_eval.remote(content, args.max_tokens)) 104 | # To avoid the rate limit set by OpenAI 105 | time.sleep(NUM_SECONDS_TO_SLEEP) 106 | 107 | reviews = ray.get(handles) 108 | for idx, review in enumerate(reviews): 109 | scores = parse_score(review) 110 | js_list[idx]['content'] = review 111 | js_list[idx]['tuple'] = scores 112 | review_file.write(json.dumps(js_list[idx]) + '\n') 113 | review_file.close() 114 | -------------------------------------------------------------------------------- /llava/eval/eval_gpt_review_visual.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import json 3 | import os 4 | 5 | import openai 6 | import time 7 | 8 | NUM_SECONDS_TO_SLEEP = 0.5 9 | 10 | 11 | def get_eval(content: str, max_tokens: int): 12 | while True: 13 | try: 14 | response = openai.ChatCompletion.create( 15 | model='gpt-4-0314', 16 | messages=[{ 17 | 'role': 'system', 18 | 'content': 'You are a helpful and precise assistant for checking the quality of the answer.' 
19 | }, { 20 | 'role': 'user', 21 | 'content': content, 22 | }], 23 | temperature=0.2, # TODO: figure out which temperature is best for evaluation 24 | max_tokens=max_tokens, 25 | ) 26 | break 27 | except openai.error.RateLimitError: 28 | pass 29 | except Exception as e: 30 | print(e) 31 | time.sleep(NUM_SECONDS_TO_SLEEP) 32 | 33 | return response['choices'][0]['message']['content'] 34 | 35 | 36 | def parse_score(review): 37 | try: 38 | score_pair = review.split('\n')[0] 39 | score_pair = score_pair.replace(',', ' ') 40 | sp = score_pair.split(' ') 41 | if len(sp) == 2: 42 | return [float(sp[0]), float(sp[1])] 43 | else: 44 | print('error', review) 45 | return [-1, -1] 46 | except Exception as e: 47 | print(e) 48 | print('error', review) 49 | return [-1, -1] 50 | 51 | 52 | if __name__ == '__main__': 53 | parser = argparse.ArgumentParser(description='ChatGPT-based QA evaluation.') 54 | parser.add_argument('-q', '--question') 55 | parser.add_argument('-c', '--context') 56 | parser.add_argument('-a', '--answer-list', nargs='+', default=[]) 57 | parser.add_argument('-r', '--rule') 58 | parser.add_argument('-o', '--output') 59 | parser.add_argument('--max-tokens', type=int, default=1024, help='maximum number of tokens produced in the output') 60 | args = parser.parse_args() 61 | 62 | f_q = open(os.path.expanduser(args.question)) 63 | f_ans1 = open(os.path.expanduser(args.answer_list[0])) 64 | f_ans2 = open(os.path.expanduser(args.answer_list[1])) 65 | rule_dict = json.load(open(os.path.expanduser(args.rule), 'r')) 66 | 67 | if os.path.isfile(os.path.expanduser(args.output)): 68 | cur_reviews = [json.loads(line) for line in open(os.path.expanduser(args.output))] 69 | else: 70 | cur_reviews = [] 71 | 72 | review_file = open(f'{args.output}', 'a') 73 | 74 | context_list = [json.loads(line) for line in open(os.path.expanduser(args.context))] 75 | image_to_context = {context['image']: context for context in context_list} 76 | 77 | handles = [] 78 | idx = 0 79 | for ques_js, 
ans1_js, ans2_js in zip(f_q, f_ans1, f_ans2): 80 | ques = json.loads(ques_js) 81 | ans1 = json.loads(ans1_js) 82 | ans2 = json.loads(ans2_js) 83 | 84 | inst = image_to_context[ques['image']] 85 | cap_str = '\n'.join(inst['captions']) 86 | box_str = '\n'.join([f'{instance["category"]}: {instance["bbox"]}' for instance in inst['instances']]) 87 | 88 | category = json.loads(ques_js)['category'] 89 | if category in rule_dict: 90 | rule = rule_dict[category] 91 | else: 92 | assert False, f"Visual QA category not found in rule file: {category}." 93 | prompt = rule['prompt'] 94 | role = rule['role'] 95 | content = (f'[Context]\n{cap_str}\n\n{box_str}\n\n' 96 | f'[Question]\n{ques["text"]}\n\n' 97 | f'[{role} 1]\n{ans1["text"]}\n\n[End of {role} 1]\n\n' 98 | f'[{role} 2]\n{ans2["text"]}\n\n[End of {role} 2]\n\n' 99 | f'[System]\n{prompt}\n\n') 100 | cur_js = { 101 | 'id': idx+1, 102 | 'question_id': ques['question_id'], 103 | 'answer1_id': ans1.get('answer_id', ans1['question_id']), 104 | 'answer2_id': ans2.get('answer_id', ans2['answer_id']), 105 | 'category': category 106 | } 107 | if idx >= len(cur_reviews): 108 | review = get_eval(content, args.max_tokens) 109 | scores = parse_score(review) 110 | cur_js['content'] = review 111 | cur_js['tuple'] = scores 112 | review_file.write(json.dumps(cur_js) + '\n') 113 | review_file.flush() 114 | else: 115 | print(f'Skipping {idx} as we already have it.') 116 | idx += 1 117 | print(idx) 118 | review_file.close() 119 | -------------------------------------------------------------------------------- /llava/eval/eval_pope.py: -------------------------------------------------------------------------------- 1 | import os 2 | import json 3 | import argparse 4 | 5 | def eval_pope(answers, label_file): 6 | label_list = [json.loads(q)['label'] for q in open(label_file, 'r')] 7 | 8 | for answer in answers: 9 | text = answer['text'] 10 | 11 | # Only keep the first sentence 12 | if text.find('.') != -1: 13 | text = text.split('.')[0] 14 | 15 
| text = text.replace(',', '') 16 | words = text.split(' ') 17 | if 'No' in words or 'not' in words or 'no' in words: 18 | answer['text'] = 'no' 19 | else: 20 | answer['text'] = 'yes' 21 | 22 | for i in range(len(label_list)): 23 | if label_list[i] == 'no': 24 | label_list[i] = 0 25 | else: 26 | label_list[i] = 1 27 | 28 | pred_list = [] 29 | for answer in answers: 30 | if answer['text'] == 'no': 31 | pred_list.append(0) 32 | else: 33 | pred_list.append(1) 34 | 35 | pos = 1 36 | neg = 0 37 | yes_ratio = pred_list.count(1) / len(pred_list) 38 | 39 | TP, TN, FP, FN = 0, 0, 0, 0 40 | for pred, label in zip(pred_list, label_list): 41 | if pred == pos and label == pos: 42 | TP += 1 43 | elif pred == pos and label == neg: 44 | FP += 1 45 | elif pred == neg and label == neg: 46 | TN += 1 47 | elif pred == neg and label == pos: 48 | FN += 1 49 | 50 | print('TP\tFP\tTN\tFN\t') 51 | print('{}\t{}\t{}\t{}'.format(TP, FP, TN, FN)) 52 | 53 | precision = float(TP) / float(TP + FP) 54 | recall = float(TP) / float(TP + FN) 55 | f1 = 2*precision*recall / (precision + recall) 56 | acc = (TP + TN) / (TP + TN + FP + FN) 57 | print('Accuracy: {}'.format(acc)) 58 | print('Precision: {}'.format(precision)) 59 | print('Recall: {}'.format(recall)) 60 | print('F1 score: {}'.format(f1)) 61 | print('Yes ratio: {}'.format(yes_ratio)) 62 | print('%.3f, %.3f, %.3f, %.3f, %.3f' % (f1, acc, precision, recall, yes_ratio) ) 63 | 64 | if __name__ == "__main__": 65 | parser = argparse.ArgumentParser() 66 | parser.add_argument("--annotation-dir", type=str) 67 | parser.add_argument("--question-file", type=str) 68 | parser.add_argument("--result-file", type=str) 69 | args = parser.parse_args() 70 | 71 | questions = [json.loads(line) for line in open(args.question_file)] 72 | questions = {question['question_id']: question for question in questions} 73 | answers = [json.loads(q) for q in open(args.result_file)] 74 | for file in os.listdir(args.annotation_dir): 75 | assert file.startswith('coco_pope_') 76 | 
assert file.endswith('.json') 77 | category = file[10:-5] 78 | cur_answers = [x for x in answers if questions[x['question_id']]['category'] == category] 79 | print('Category: {}, # samples: {}'.format(category, len(cur_answers))) 80 | eval_pope(cur_answers, os.path.join(args.annotation_dir, file)) 81 | print("====================================") 82 | -------------------------------------------------------------------------------- /llava/eval/eval_science_qa.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import json 3 | import os 4 | import re 5 | import random 6 | 7 | 8 | def get_args(): 9 | parser = argparse.ArgumentParser() 10 | parser.add_argument('--base-dir', type=str) 11 | parser.add_argument('--result-file', type=str) 12 | parser.add_argument('--output-file', type=str) 13 | parser.add_argument('--output-result', type=str) 14 | parser.add_argument('--split', type=str, default='test') 15 | parser.add_argument('--options', type=list, default=["A", "B", "C", "D", "E"]) 16 | return parser.parse_args() 17 | 18 | 19 | def convert_caps(results): 20 | fakecaps = [] 21 | for result in results: 22 | image_id = result['question_id'] 23 | caption = result['text'] 24 | fakecaps.append({"image_id": int(image_id), "caption": caption}) 25 | return fakecaps 26 | 27 | 28 | def get_pred_idx(prediction, choices, options): 29 | """ 30 | Get the index (e.g. 2) from the prediction (e.g. 
'C') 31 | """ 32 | if prediction in options[:len(choices)]: 33 | return options.index(prediction) 34 | else: 35 | return -1 36 | return random.choice(range(len(choices))) 37 | 38 | 39 | if __name__ == "__main__": 40 | args = get_args() 41 | 42 | base_dir = args.base_dir 43 | split_indices = json.load(open(os.path.join(base_dir, "pid_splits.json")))[args.split] 44 | problems = json.load(open(os.path.join(base_dir, "problems.json"))) 45 | predictions = [json.loads(line) for line in open(args.result_file)] 46 | predictions = {pred['question_id']: pred for pred in predictions} 47 | split_problems = {idx: problems[idx] for idx in split_indices} 48 | 49 | results = {'correct': [], 'incorrect': []} 50 | sqa_results = {} 51 | sqa_results['acc'] = None 52 | sqa_results['correct'] = None 53 | sqa_results['count'] = None 54 | sqa_results['results'] = {} 55 | sqa_results['outputs'] = {} 56 | 57 | for prob_id, prob in split_problems.items(): 58 | if prob_id not in predictions: 59 | pred = {'text': 'FAILED', 'prompt': 'Unknown'} 60 | pred_text = 'FAILED' 61 | else: 62 | pred = predictions[prob_id] 63 | pred_text = pred['text'] 64 | 65 | if pred_text in args.options: 66 | answer = pred_text 67 | elif len(pred_text) >= 3 and pred_text[0] in args.options and pred_text[1:3] == ". ": 68 | answer = pred_text[0] 69 | else: 70 | pattern = re.compile(r'The answer is ([A-Z]).') 71 | res = pattern.findall(pred_text) 72 | if len(res) == 1: 73 | answer = res[0] # 'A', 'B', ... 
74 | else: 75 | answer = "FAILED" 76 | 77 | pred_idx = get_pred_idx(answer, prob['choices'], args.options) 78 | 79 | analysis = { 80 | 'question_id': prob_id, 81 | 'parsed_ans': answer, 82 | 'ground_truth': args.options[prob['answer']], 83 | 'question': pred['prompt'], 84 | 'pred': pred_text, 85 | 'is_multimodal': '' in pred['prompt'], 86 | } 87 | 88 | sqa_results['results'][prob_id] = get_pred_idx(answer, prob['choices'], args.options) 89 | sqa_results['outputs'][prob_id] = pred_text 90 | 91 | if pred_idx == prob['answer']: 92 | results['correct'].append(analysis) 93 | else: 94 | results['incorrect'].append(analysis) 95 | 96 | correct = len(results['correct']) 97 | total = len(results['correct']) + len(results['incorrect']) 98 | 99 | ###### IMG ###### 100 | multimodal_correct = len([x for x in results['correct'] if x['is_multimodal']]) 101 | multimodal_incorrect = len([x for x in results['incorrect'] if x['is_multimodal']]) 102 | multimodal_total = multimodal_correct + multimodal_incorrect 103 | ###### IMG ###### 104 | 105 | print(f'Total: {total}, Correct: {correct}, Accuracy: {correct / total * 100:.2f}%, IMG-Accuracy: {multimodal_correct / multimodal_total * 100:.2f}%') 106 | 107 | sqa_results['acc'] = correct / total * 100 108 | sqa_results['correct'] = correct 109 | sqa_results['count'] = total 110 | 111 | with open(args.output_file, 'w') as f: 112 | json.dump(results, f, indent=2) 113 | with open(args.output_result, 'w') as f: 114 | json.dump(sqa_results, f, indent=2) 115 | -------------------------------------------------------------------------------- /llava/eval/eval_science_qa_gpt4.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import json 3 | import os 4 | import re 5 | import random 6 | from collections import defaultdict 7 | 8 | 9 | def get_args(): 10 | parser = argparse.ArgumentParser() 11 | parser.add_argument('--base-dir', type=str) 12 | parser.add_argument('--gpt4-result', type=str) 13 
| parser.add_argument('--our-result', type=str) 14 | parser.add_argument('--split', type=str, default='test') 15 | parser.add_argument('--options', type=list, default=["A", "B", "C", "D", "E"]) 16 | return parser.parse_args() 17 | 18 | 19 | def convert_caps(results): 20 | fakecaps = [] 21 | for result in results: 22 | image_id = result['question_id'] 23 | caption = result['text'] 24 | fakecaps.append({"image_id": int(image_id), "caption": caption}) 25 | return fakecaps 26 | 27 | 28 | def get_pred_idx(prediction, choices, options): 29 | """ 30 | Get the index (e.g. 2) from the prediction (e.g. 'C') 31 | """ 32 | if prediction in options[:len(choices)]: 33 | return options.index(prediction) 34 | else: 35 | return random.choice(range(len(choices))) 36 | 37 | 38 | if __name__ == "__main__": 39 | args = get_args() 40 | 41 | base_dir = args.base_dir 42 | split_indices = json.load(open(os.path.join(base_dir, "pid_splits.json")))[args.split] 43 | problems = json.load(open(os.path.join(base_dir, "problems.json"))) 44 | our_predictions = [json.loads(line) for line in open(args.our_result)] 45 | our_predictions = {pred['question_id']: pred for pred in our_predictions} 46 | split_problems = {idx: problems[idx] for idx in split_indices} 47 | 48 | gpt4_predictions = json.load(open(args.gpt4_result))['outputs'] 49 | 50 | results = defaultdict(lambda: 0) 51 | 52 | for prob_id, prob in split_problems.items(): 53 | if prob_id not in our_predictions: 54 | continue 55 | if prob_id not in gpt4_predictions: 56 | continue 57 | our_pred = our_predictions[prob_id]['text'] 58 | gpt4_pred = gpt4_predictions[prob_id] 59 | 60 | pattern = re.compile(r'The answer is ([A-Z]).') 61 | our_res = pattern.findall(our_pred) 62 | if len(our_res) == 1: 63 | our_answer = our_res[0] # 'A', 'B', ... 64 | else: 65 | our_answer = "FAILED" 66 | gpt4_res = pattern.findall(gpt4_pred) 67 | if len(gpt4_res) == 1: 68 | gpt4_answer = gpt4_res[0] # 'A', 'B', ... 
69 | else: 70 | gpt4_answer = "FAILED" 71 | 72 | our_pred_idx = get_pred_idx(our_answer, prob['choices'], args.options) 73 | gpt4_pred_idx = get_pred_idx(gpt4_answer, prob['choices'], args.options) 74 | 75 | if gpt4_answer == 'FAILED': 76 | results['gpt4_failed'] += 1 77 | # continue 78 | gpt4_pred_idx = our_pred_idx 79 | # if our_pred_idx != prob['answer']: 80 | # print(our_predictions[prob_id]['prompt']) 81 | # print('-----------------') 82 | # print(f'LECTURE: {prob["lecture"]}') 83 | # print(f'SOLUTION: {prob["solution"]}') 84 | # print('=====================') 85 | else: 86 | # continue 87 | pass 88 | # gpt4_pred_idx = our_pred_idx 89 | 90 | if gpt4_pred_idx == prob['answer']: 91 | results['correct'] += 1 92 | else: 93 | results['incorrect'] += 1 94 | 95 | 96 | if gpt4_pred_idx == prob['answer'] or our_pred_idx == prob['answer']: 97 | results['correct_upperbound'] += 1 98 | 99 | correct = results['correct'] 100 | total = results['correct'] + results['incorrect'] 101 | print(f'Total: {total}, Correct: {correct}, Accuracy: {correct / total * 100:.2f}%') 102 | print(f'Total: {total}, Correct (upper): {results["correct_upperbound"]}, Accuracy: {results["correct_upperbound"] / total * 100:.2f}%') 103 | print(f'Total: {total}, GPT-4 NO-ANS (RANDOM): {results["gpt4_failed"]}, Percentage: {results["gpt4_failed"] / total * 100:.2f}%') 104 | 105 | -------------------------------------------------------------------------------- /llava/eval/eval_textvqa.py: -------------------------------------------------------------------------------- 1 | import os 2 | import argparse 3 | import json 4 | import re 5 | 6 | from llava.eval.m4c_evaluator import TextVQAAccuracyEvaluator 7 | 8 | 9 | def get_args(): 10 | parser = argparse.ArgumentParser() 11 | parser.add_argument('--annotation-file', type=str) 12 | parser.add_argument('--result-file', type=str) 13 | parser.add_argument('--result-dir', type=str) 14 | return parser.parse_args() 15 | 16 | 17 | def prompt_processor(prompt): 18 | 
if prompt.startswith('OCR tokens: '): 19 | pattern = r"Question: (.*?) Short answer:" 20 | match = re.search(pattern, prompt, re.DOTALL) 21 | question = match.group(1) 22 | elif 'Reference OCR token: ' in prompt and len(prompt.split('\n')) == 3: 23 | if prompt.startswith('Reference OCR token:'): 24 | question = prompt.split('\n')[1] 25 | else: 26 | question = prompt.split('\n')[0] 27 | elif len(prompt.split('\n')) == 2: 28 | question = prompt.split('\n')[0] 29 | else: 30 | assert False 31 | 32 | return question.lower() 33 | 34 | 35 | def eval_single(annotation_file, result_file): 36 | experiment_name = os.path.splitext(os.path.basename(result_file))[0] 37 | print(experiment_name) 38 | annotations = json.load(open(annotation_file))['data'] 39 | annotations = {(annotation['image_id'], annotation['question'].lower()): annotation for annotation in annotations} 40 | results = [json.loads(line) for line in open(result_file)] 41 | 42 | pred_list = [] 43 | for result in results: 44 | annotation = annotations[(result['question_id'], prompt_processor(result['prompt']))] 45 | pred_list.append({ 46 | "pred_answer": result['text'], 47 | "gt_answers": annotation['answers'], 48 | }) 49 | 50 | evaluator = TextVQAAccuracyEvaluator() 51 | print('Samples: {}\nAccuracy: {:.2f}%\n'.format(len(pred_list), 100. 
* evaluator.eval_pred_list(pred_list))) 52 | 53 | 54 | if __name__ == "__main__": 55 | args = get_args() 56 | 57 | if args.result_file is not None: 58 | eval_single(args.annotation_file, args.result_file) 59 | 60 | if args.result_dir is not None: 61 | for result_file in sorted(os.listdir(args.result_dir)): 62 | if not result_file.endswith('.jsonl'): 63 | print(f'Skipping {result_file}') 64 | continue 65 | eval_single(args.annotation_file, os.path.join(args.result_dir, result_file)) 66 | -------------------------------------------------------------------------------- /llava/eval/generate_webpage_data_from_table.py: -------------------------------------------------------------------------------- 1 | """Generate json file for webpage.""" 2 | import json 3 | import os 4 | import re 5 | 6 | # models = ['llama', 'alpaca', 'gpt35', 'bard'] 7 | models = ['vicuna'] 8 | 9 | 10 | def read_jsonl(path: str, key: str=None): 11 | data = [] 12 | with open(os.path.expanduser(path)) as f: 13 | for line in f: 14 | if not line: 15 | continue 16 | data.append(json.loads(line)) 17 | if key is not None: 18 | data.sort(key=lambda x: x[key]) 19 | data = {item[key]: item for item in data} 20 | return data 21 | 22 | 23 | def trim_hanging_lines(s: str, n: int) -> str: 24 | s = s.strip() 25 | for _ in range(n): 26 | s = s.split('\n', 1)[1].strip() 27 | return s 28 | 29 | 30 | if __name__ == '__main__': 31 | questions = read_jsonl('table/question.jsonl', key='question_id') 32 | 33 | # alpaca_answers = read_jsonl('table/answer/answer_alpaca-13b.jsonl', key='question_id') 34 | # bard_answers = read_jsonl('table/answer/answer_bard.jsonl', key='question_id') 35 | # gpt35_answers = read_jsonl('table/answer/answer_gpt35.jsonl', key='question_id') 36 | # llama_answers = read_jsonl('table/answer/answer_llama-13b.jsonl', key='question_id') 37 | vicuna_answers = read_jsonl('table/answer/answer_vicuna-13b.jsonl', key='question_id') 38 | ours_answers = read_jsonl('table/results/llama-13b-hf-alpaca.jsonl', 
key='question_id') 39 | 40 | review_vicuna = read_jsonl('table/review/review_vicuna-13b_llama-13b-hf-alpaca.jsonl', key='question_id') 41 | # review_alpaca = read_jsonl('table/review/review_alpaca-13b_vicuna-13b.jsonl', key='question_id') 42 | # review_bard = read_jsonl('table/review/review_bard_vicuna-13b.jsonl', key='question_id') 43 | # review_gpt35 = read_jsonl('table/review/review_gpt35_vicuna-13b.jsonl', key='question_id') 44 | # review_llama = read_jsonl('table/review/review_llama-13b_vicuna-13b.jsonl', key='question_id') 45 | 46 | records = [] 47 | for qid in questions.keys(): 48 | r = { 49 | 'id': qid, 50 | 'category': questions[qid]['category'], 51 | 'question': questions[qid]['text'], 52 | 'answers': { 53 | # 'alpaca': alpaca_answers[qid]['text'], 54 | # 'llama': llama_answers[qid]['text'], 55 | # 'bard': bard_answers[qid]['text'], 56 | # 'gpt35': gpt35_answers[qid]['text'], 57 | 'vicuna': vicuna_answers[qid]['text'], 58 | 'ours': ours_answers[qid]['text'], 59 | }, 60 | 'evaluations': { 61 | # 'alpaca': review_alpaca[qid]['text'], 62 | # 'llama': review_llama[qid]['text'], 63 | # 'bard': review_bard[qid]['text'], 64 | 'vicuna': review_vicuna[qid]['content'], 65 | # 'gpt35': review_gpt35[qid]['text'], 66 | }, 67 | 'scores': { 68 | 'vicuna': review_vicuna[qid]['tuple'], 69 | # 'alpaca': review_alpaca[qid]['score'], 70 | # 'llama': review_llama[qid]['score'], 71 | # 'bard': review_bard[qid]['score'], 72 | # 'gpt35': review_gpt35[qid]['score'], 73 | }, 74 | } 75 | 76 | # cleanup data 77 | cleaned_evals = {} 78 | for k, v in r['evaluations'].items(): 79 | v = v.strip() 80 | lines = v.split('\n') 81 | # trim the first line if it's a pair of numbers 82 | if re.match(r'\d+[, ]+\d+', lines[0]): 83 | lines = lines[1:] 84 | v = '\n'.join(lines) 85 | cleaned_evals[k] = v.replace('Assistant 1', "**Assistant 1**").replace('Assistant 2', '**Assistant 2**') 86 | 87 | r['evaluations'] = cleaned_evals 88 | records.append(r) 89 | 90 | # Reorder the records, this is 
optional 91 | for r in records: 92 | if r['id'] <= 20: 93 | r['id'] += 60 94 | else: 95 | r['id'] -= 20 96 | for r in records: 97 | if r['id'] <= 50: 98 | r['id'] += 10 99 | elif 50 < r['id'] <= 60: 100 | r['id'] -= 50 101 | for r in records: 102 | if r['id'] == 7: 103 | r['id'] = 1 104 | elif r['id'] < 7: 105 | r['id'] += 1 106 | 107 | records.sort(key=lambda x: x['id']) 108 | 109 | # Write to file 110 | with open('webpage/data.json', 'w') as f: 111 | json.dump({'questions': records, 'models': models}, f, indent=2) 112 | -------------------------------------------------------------------------------- /llava/eval/model_qa.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | from transformers import AutoTokenizer, AutoModelForCausalLM, StoppingCriteria 3 | import torch 4 | import os 5 | import json 6 | from tqdm import tqdm 7 | import shortuuid 8 | 9 | from llava.conversation import default_conversation 10 | from llava.utils import disable_torch_init 11 | 12 | 13 | @torch.inference_mode() 14 | def eval_model(model_name, questions_file, answers_file): 15 | # Model 16 | disable_torch_init() 17 | model_name = os.path.expanduser(model_name) 18 | tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=False) 19 | model = AutoModelForCausalLM.from_pretrained(model_name, 20 | torch_dtype=torch.float16).cuda() 21 | 22 | 23 | ques_file = open(os.path.expanduser(questions_file), "r") 24 | ans_file = open(os.path.expanduser(answers_file), "w") 25 | for i, line in enumerate(tqdm(ques_file)): 26 | idx = json.loads(line)["question_id"] 27 | qs = json.loads(line)["text"] 28 | cat = json.loads(line)["category"] 29 | conv = default_conversation.copy() 30 | conv.append_message(conv.roles[0], qs) 31 | prompt = conv.get_prompt() 32 | inputs = tokenizer([prompt]) 33 | input_ids = torch.as_tensor(inputs.input_ids).cuda() 34 | output_ids = model.generate( 35 | input_ids, 36 | do_sample=True, 37 | use_cache=True, 38 | 
temperature=0.7, 39 | max_new_tokens=1024,) 40 | outputs = tokenizer.batch_decode(output_ids, skip_special_tokens=True)[0] 41 | try: 42 | index = outputs.index(conv.sep, len(prompt)) 43 | except ValueError: 44 | outputs += conv.sep 45 | index = outputs.index(conv.sep, len(prompt)) 46 | 47 | outputs = outputs[len(prompt) + len(conv.roles[1]) + 2:index].strip() 48 | ans_id = shortuuid.uuid() 49 | ans_file.write(json.dumps({"question_id": idx, 50 | "text": outputs, 51 | "answer_id": ans_id, 52 | "model_id": model_name, 53 | "metadata": {}}) + "\n") 54 | ans_file.flush() 55 | ans_file.close() 56 | 57 | if __name__ == "__main__": 58 | parser = argparse.ArgumentParser() 59 | parser.add_argument("--model-name", type=str, default="facebook/opt-350m") 60 | parser.add_argument("--question-file", type=str, default="tables/question.jsonl") 61 | parser.add_argument("--answers-file", type=str, default="answer.jsonl") 62 | args = parser.parse_args() 63 | 64 | eval_model(args.model_name, args.question_file, args.answers_file) 65 | -------------------------------------------------------------------------------- /llava/eval/model_vqa.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import torch 3 | import os 4 | import json 5 | from tqdm import tqdm 6 | import shortuuid 7 | 8 | from llava.constants import IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN 9 | from llava.conversation import conv_templates, SeparatorStyle 10 | from llava.model.builder import load_pretrained_model 11 | from llava.utils import disable_torch_init 12 | from llava.mm_utils import tokenizer_image_token, process_images, get_model_name_from_path 13 | 14 | from PIL import Image 15 | import math 16 | 17 | 18 | def split_list(lst, n): 19 | """Split a list into n (roughly) equal-sized chunks""" 20 | chunk_size = math.ceil(len(lst) / n) # ceiling division: yields at most n chunks (possibly fewer); NOTE(review): an empty lst gives chunk_size 0 
and the range() step below raises ValueError 21 | return [lst[i:i+chunk_size] for i in range(0, len(lst),
chunk_size)] 22 | 23 | 24 | def get_chunk(lst, n, k): 25 | chunks = split_list(lst, n) 26 | return chunks[k] 27 | 28 | 29 | def eval_model(args): 30 | # Model 31 | disable_torch_init() 32 | model_path = os.path.expanduser(args.model_path) 33 | model_name = get_model_name_from_path(model_path) 34 | tokenizer, model, image_processor, context_len = load_pretrained_model(model_path, args.model_base, model_name) 35 | 36 | questions = [json.loads(q) for q in open(os.path.expanduser(args.question_file), "r")] 37 | questions = get_chunk(questions, args.num_chunks, args.chunk_idx) 38 | answers_file = os.path.expanduser(args.answers_file) 39 | os.makedirs(os.path.dirname(answers_file), exist_ok=True) 40 | ans_file = open(answers_file, "w") 41 | for line in tqdm(questions): 42 | idx = line["question_id"] 43 | image_file = line["image"] 44 | qs = line["text"] 45 | cur_prompt = qs 46 | if model.config.mm_use_im_start_end: 47 | qs = DEFAULT_IM_START_TOKEN + DEFAULT_IMAGE_TOKEN + DEFAULT_IM_END_TOKEN + '\n' + qs 48 | else: 49 | qs = DEFAULT_IMAGE_TOKEN + '\n' + qs 50 | 51 | conv = conv_templates[args.conv_mode].copy() 52 | conv.append_message(conv.roles[0], qs) 53 | conv.append_message(conv.roles[1], None) 54 | prompt = conv.get_prompt() 55 | 56 | input_ids = tokenizer_image_token(prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors='pt').unsqueeze(0).cuda() 57 | 58 | image = Image.open(os.path.join(args.image_folder, image_file)).convert('RGB') 59 | image_tensor = process_images([image], image_processor, model.config)[0] 60 | 61 | with torch.inference_mode(): 62 | output_ids = model.generate( 63 | input_ids, 64 | images=image_tensor.unsqueeze(0).half().cuda(), 65 | image_sizes=[image.size], 66 | do_sample=True if args.temperature > 0 else False, 67 | temperature=args.temperature, 68 | top_p=args.top_p, 69 | num_beams=args.num_beams, 70 | # no_repeat_ngram_size=3, 71 | max_new_tokens=1024, 72 | use_cache=True) 73 | 74 | outputs = tokenizer.batch_decode(output_ids, 
skip_special_tokens=True)[0].strip() 75 | 76 | ans_id = shortuuid.uuid() 77 | ans_file.write(json.dumps({"question_id": idx, 78 | "prompt": cur_prompt, 79 | "text": outputs, 80 | "answer_id": ans_id, 81 | "model_id": model_name, 82 | "metadata": {}}) + "\n") 83 | ans_file.flush() 84 | ans_file.close() 85 | 86 | if __name__ == "__main__": 87 | parser = argparse.ArgumentParser() 88 | parser.add_argument("--model-path", type=str, default="facebook/opt-350m") 89 | parser.add_argument("--model-base", type=str, default=None) 90 | parser.add_argument("--image-folder", type=str, default="") 91 | parser.add_argument("--question-file", type=str, default="tables/question.jsonl") 92 | parser.add_argument("--answers-file", type=str, default="answer.jsonl") 93 | parser.add_argument("--conv-mode", type=str, default="llava_v1") 94 | parser.add_argument("--num-chunks", type=int, default=1) 95 | parser.add_argument("--chunk-idx", type=int, default=0) 96 | parser.add_argument("--temperature", type=float, default=0.2) 97 | parser.add_argument("--top_p", type=float, default=None) 98 | parser.add_argument("--num_beams", type=int, default=1) 99 | args = parser.parse_args() 100 | 101 | eval_model(args) 102 | -------------------------------------------------------------------------------- /llava/eval/qa_baseline_gpt35.py: -------------------------------------------------------------------------------- 1 | """Generate answers with GPT-3.5""" 2 | # Note: you need to be using OpenAI Python v0.27.0 for the code below to work 3 | import argparse 4 | import json 5 | import os 6 | import time 7 | import concurrent.futures 8 | 9 | import openai 10 | import tqdm 11 | import shortuuid 12 | 13 | MODEL = 'gpt-3.5-turbo' 14 | MODEL_ID = 'gpt-3.5-turbo:20230327' 15 | 16 | def get_answer(question_id: int, question: str, max_tokens: int): 17 | ans = { 18 | 'answer_id': shortuuid.uuid(), 19 | 'question_id': question_id, 20 | 'model_id': MODEL_ID, 21 | } 22 | for _ in range(3): 23 | try: 24 | response = 
openai.ChatCompletion.create( 25 | model=MODEL, 26 | messages=[{ 27 | 'role': 'system', 28 | 'content': 'You are a helpful assistant.' 29 | }, { 30 | 'role': 'user', 31 | 'content': question, 32 | }], 33 | max_tokens=max_tokens, 34 | ) 35 | ans['text'] = response['choices'][0]['message']['content'] 36 | return ans 37 | except Exception as e: 38 | print('[ERROR]', e) 39 | ans['text'] = '#ERROR#' 40 | time.sleep(1) 41 | return ans 42 | 43 | 44 | if __name__ == '__main__': 45 | parser = argparse.ArgumentParser(description='ChatGPT answer generation.') 46 | parser.add_argument('-q', '--question') 47 | parser.add_argument('-o', '--output') 48 | parser.add_argument('--max-tokens', type=int, default=1024, help='maximum number of tokens produced in the output') 49 | args = parser.parse_args() 50 | 51 | questions_dict = {} 52 | with open(os.path.expanduser(args.question)) as f: 53 | for line in f: 54 | if not line: 55 | continue 56 | q = json.loads(line) 57 | questions_dict[q['question_id']] = q['text'] 58 | 59 | answers = [] 60 | 61 | with concurrent.futures.ThreadPoolExecutor(max_workers=32) as executor: 62 | futures = [] 63 | for qid, question in questions_dict.items(): 64 | future = executor.submit(get_answer, qid, question, args.max_tokens) 65 | futures.append(future) 66 | 67 | for future in tqdm.tqdm(concurrent.futures.as_completed(futures), total=len(futures)): 68 | answers.append(future.result()) 69 | 70 | answers.sort(key=lambda x: x['question_id']) 71 | 72 | with open(os.path.expanduser(args.output), 'w') as f: 73 | table = [json.dumps(ans) for ans in answers] 74 | f.write('\n'.join(table)) 75 | -------------------------------------------------------------------------------- /llava/eval/summarize_gpt_review.py: -------------------------------------------------------------------------------- 1 | import json 2 | import os 3 | from collections import defaultdict 4 | 5 | import numpy as np 6 | 7 | import argparse 8 | 9 | def parse_args(): 10 | parser = 
argparse.ArgumentParser(description='ChatGPT-based QA evaluation.') 11 | parser.add_argument('-d', '--dir', default=None) 12 | parser.add_argument('-v', '--version', default=None) 13 | parser.add_argument('-s', '--select', nargs='*', default=None) 14 | parser.add_argument('-f', '--files', nargs='*', default=[]) 15 | parser.add_argument('-i', '--ignore', nargs='*', default=[]) 16 | return parser.parse_args() 17 | 18 | 19 | if __name__ == '__main__': 20 | args = parse_args() 21 | 22 | if args.ignore is not None: 23 | args.ignore = [int(x) for x in args.ignore] 24 | 25 | if len(args.files) > 0: 26 | review_files = args.files 27 | else: 28 | review_files = [x for x in os.listdir(args.dir) if x.endswith('.jsonl') and (x.startswith('gpt4_text') or x.startswith('reviews_') or x.startswith('review_') or 'review' in args.dir)] 29 | 30 | for review_file in sorted(review_files): 31 | config = os.path.basename(review_file).replace('gpt4_text_', '').replace('.jsonl', '') 32 | if args.select is not None and any(x not in config for x in args.select): 33 | continue 34 | if '0613' in config: 35 | version = '0613' 36 | else: 37 | version = '0314' 38 | if args.version is not None and args.version != version: 39 | continue 40 | scores = defaultdict(list) 41 | print(config) 42 | with open(os.path.join(args.dir, review_file) if args.dir is not None else review_file) as f: 43 | for review_str in f: 44 | review = json.loads(review_str) 45 | if review['question_id'] in args.ignore: 46 | continue 47 | if 'category' in review: 48 | scores[review['category']].append(review['tuple']) 49 | scores['all'].append(review['tuple']) 50 | else: 51 | if 'tuple' in review: 52 | scores['all'].append(review['tuple']) 53 | else: 54 | scores['all'].append(review['score']) 55 | for k, v in sorted(scores.items()): 56 | stats = np.asarray(v).mean(0).tolist() 57 | stats = [round(x, 3) for x in stats] 58 | # print(k, stats, round(stats[1]/stats[0]*100, 1)) 59 | print(k, round(stats[1]/stats[0]*100, 1), 
round(stats[0] * 10, 1), round(stats[1] * 10, 1)) 60 | print('=================================') 61 | -------------------------------------------------------------------------------- /llava/eval/table/model.jsonl: -------------------------------------------------------------------------------- 1 | {"model_id": "vicuna-13b:20230322-clean-lang", "model_name": "vicuna-13b", "model_version": "20230322-clean-lang", "model_metadata": "vicuna-13b-20230322-clean-lang"} 2 | {"model_id": "alpaca-13b:v1", "model_name": "alpaca-13b", "model_version": "v1", "model_metadata": "alpaca-13b"} 3 | {"model_id": "llama-13b:v1", "model_name": "llama-13b", "model_version": "v1", "model_metadata": "hf-llama-13b"} 4 | {"model_id": "bard:20230327", "model_name": "bard", "model_version": "20230327", "model_metadata": "Google Bard 20230327"} 5 | {"model_id": "gpt-3.5-turbo:20230327", "model_name": "gpt-3.5-turbo", "model_version": "20230327", "model_metadata": "OpenAI ChatGPT gpt-3.5-turbo Chat Completion"} 6 | -------------------------------------------------------------------------------- /llava/eval/table/reviewer.jsonl: -------------------------------------------------------------------------------- 1 | {"reviewer_id": "gpt-4-0328-default", "prompt_id": 1, "metadata": {"temperature": 0.2, "max_tokens": 1024}, "description": "GPT-4 for general questions"} 2 | {"reviewer_id": "gpt-4-0328-coding", "prompt_id": 2, "metadata": {"temperature": 0.2, "max_tokens": 1024}, "description": "GPT-4 for coding questions"} 3 | {"reviewer_id": "gpt-4-0328-math", "prompt_id": 3, "metadata": {"temperature": 0.2, "max_tokens": 1024}, "description": "GPT-4 for math questions"} 4 | {"reviewer_id": "gpt-4-0417-visual", "prompt_id": 4, "metadata": {"temperature": 0.2, "max_tokens": 1024}, "description": "GPT-4 for math questions"} 5 | -------------------------------------------------------------------------------- /llava/eval/webpage/figures/alpaca.png: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/DCDmllm/HealthGPT/c044a13254b76c5eec8c2e6c55e3324318c27940/llava/eval/webpage/figures/alpaca.png -------------------------------------------------------------------------------- /llava/eval/webpage/figures/bard.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DCDmllm/HealthGPT/c044a13254b76c5eec8c2e6c55e3324318c27940/llava/eval/webpage/figures/bard.jpg -------------------------------------------------------------------------------- /llava/eval/webpage/figures/chatgpt.svg: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /llava/eval/webpage/figures/llama.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DCDmllm/HealthGPT/c044a13254b76c5eec8c2e6c55e3324318c27940/llava/eval/webpage/figures/llama.jpg -------------------------------------------------------------------------------- /llava/eval/webpage/figures/swords_FILL0_wght300_GRAD0_opsz48.svg: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /llava/eval/webpage/figures/vicuna.jpeg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DCDmllm/HealthGPT/c044a13254b76c5eec8c2e6c55e3324318c27940/llava/eval/webpage/figures/vicuna.jpeg -------------------------------------------------------------------------------- /llava/eval/webpage/styles.css: -------------------------------------------------------------------------------- 1 | body { 2 | font-family: 'Segoe UI', Tahoma, Geneva, Verdana, sans-serif; 3 | background-color: #f8f9fa; 4 | } 5 | 6 | .navbar-dark .navbar-nav .nav-link { 
7 | color: #f1cf68; 8 | font-size: 1.1rem; 9 | padding: 0.5rem 0.6rem; 10 | } 11 | 12 | .card-header { 13 | font-weight: bold; 14 | } 15 | 16 | .card { 17 | box-shadow: 0 4px 8px rgba(0, 0, 0, 0.1); 18 | transition: 0.3s; 19 | } 20 | 21 | .card:hover { 22 | box-shadow: 0 8px 16px rgba(0, 0, 0, 0.2); 23 | } 24 | 25 | button { 26 | transition: background-color 0.3s; 27 | } 28 | 29 | button:hover { 30 | background-color: #007bff; 31 | } 32 | 33 | @media (max-width: 767px) { 34 | .form-row .form-group { 35 | margin-bottom: 10px; 36 | } 37 | } 38 | 39 | /* Extra styles */ 40 | 41 | .expandable-card .card-text-container { 42 | max-height: 200px; 43 | overflow-y: hidden; 44 | position: relative; 45 | } 46 | 47 | .expandable-card.expanded .card-text-container { 48 | max-height: none; 49 | } 50 | 51 | .expand-btn { 52 | position: relative; 53 | display: none; 54 | background-color: rgba(255, 255, 255, 0.8); 55 | color: #510c75; 56 | border-color: transparent; 57 | } 58 | 59 | .expand-btn:hover { 60 | background-color: rgba(200, 200, 200, 0.8); 61 | text-decoration: none; 62 | border-color: transparent; 63 | color: #510c75; 64 | } 65 | 66 | .expand-btn:focus { 67 | outline: none; 68 | text-decoration: none; 69 | } 70 | 71 | .expandable-card:not(.expanded) .card-text-container:after { 72 | content: ""; 73 | position: absolute; 74 | bottom: 0; 75 | left: 0; 76 | width: 100%; 77 | height: 90px; 78 | background: linear-gradient(rgba(255, 255, 255, 0.2), rgba(255, 255, 255, 1)); 79 | } 80 | 81 | .expandable-card:not(.expanded) .expand-btn { 82 | margin-top: -40px; 83 | } 84 | 85 | .card-body { 86 | padding-bottom: 5px; 87 | } 88 | 89 | .vertical-flex-layout { 90 | justify-content: center; 91 | align-items: center; 92 | height: 100%; 93 | display: flex; 94 | flex-direction: column; 95 | gap: 5px; 96 | } 97 | 98 | .figure-img { 99 | max-width: 100%; 100 | height: auto; 101 | } 102 | 103 | .adjustable-font-size { 104 | font-size: calc(0.5rem + 2vw); 105 | } 106 | 
-------------------------------------------------------------------------------- /llava/model/__init__.py: -------------------------------------------------------------------------------- 1 | try: 2 | from .language_model.llava_llama import LlavaLlamaForCausalLM, LlavaConfig 3 | from .language_model.llava_mpt import LlavaMptForCausalLM, LlavaMptConfig 4 | from .language_model.llava_mistral import LlavaMistralForCausalLM, LlavaMistralConfig 5 | from .language_model.llava_phi3 import LlavaPhiForCausalLM, LlavaPhiConfig 6 | from .language_model.llava_qwen import LlavaQwen2ForCausalLM 7 | except: 8 | pass 9 | -------------------------------------------------------------------------------- /llava/model/__pycache__/__init__.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DCDmllm/HealthGPT/c044a13254b76c5eec8c2e6c55e3324318c27940/llava/model/__pycache__/__init__.cpython-310.pyc -------------------------------------------------------------------------------- /llava/model/__pycache__/llava_arch.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DCDmllm/HealthGPT/c044a13254b76c5eec8c2e6c55e3324318c27940/llava/model/__pycache__/llava_arch.cpython-310.pyc -------------------------------------------------------------------------------- /llava/model/apply_delta.py: -------------------------------------------------------------------------------- 1 | """ 2 | Usage: 3 | python3 -m fastchat.model.apply_delta --base ~/model_weights/llama-7b --target ~/model_weights/vicuna-7b --delta lmsys/vicuna-7b-delta 4 | """ 5 | import argparse 6 | 7 | import torch 8 | from tqdm import tqdm 9 | from transformers import AutoTokenizer, AutoModelForCausalLM 10 | from llava import LlavaLlamaForCausalLM 11 | 12 | 13 | def apply_delta(base_model_path, target_model_path, delta_path): 14 | print("Loading base model") 15 | base = 
AutoModelForCausalLM.from_pretrained( 16 | base_model_path, torch_dtype=torch.float16, low_cpu_mem_usage=True) 17 | 18 | print("Loading delta") 19 | delta = LlavaLlamaForCausalLM.from_pretrained(delta_path, torch_dtype=torch.float16, low_cpu_mem_usage=True) 20 | delta_tokenizer = AutoTokenizer.from_pretrained(delta_path) 21 | 22 | print("Applying delta") 23 | for name, param in tqdm(delta.state_dict().items(), desc="Applying delta"): 24 | if name not in base.state_dict(): 25 | assert name in ['model.mm_projector.weight', 'model.mm_projector.bias'], f'{name} not in base model' 26 | continue 27 | if param.data.shape == base.state_dict()[name].shape: 28 | param.data += base.state_dict()[name] 29 | else: 30 | assert name in ['model.embed_tokens.weight', 'lm_head.weight'], \ 31 | f'{name} dimension mismatch: {param.data.shape} vs {base.state_dict()[name].shape}' 32 | bparam = base.state_dict()[name] 33 | param.data[:bparam.shape[0], :bparam.shape[1]] += bparam 34 | 35 | print("Saving target model") 36 | delta.save_pretrained(target_model_path) 37 | delta_tokenizer.save_pretrained(target_model_path) 38 | 39 | 40 | if __name__ == "__main__": 41 | parser = argparse.ArgumentParser() 42 | parser.add_argument("--base-model-path", type=str, required=True) 43 | parser.add_argument("--target-model-path", type=str, required=True) 44 | parser.add_argument("--delta-path", type=str, required=True) 45 | 46 | args = parser.parse_args() 47 | 48 | apply_delta(args.base_model_path, args.target_model_path, args.delta_path) 49 | -------------------------------------------------------------------------------- /llava/model/consolidate.py: -------------------------------------------------------------------------------- 1 | """ 2 | Usage: 3 | python3 -m llava.model.consolidate --src ~/model_weights/llava-7b --dst ~/model_weights/llava-7b_consolidate 4 | """ 5 | import argparse 6 | 7 | import torch 8 | from transformers import AutoTokenizer, AutoModelForCausalLM 9 | from llava.model import * 10 
| from llava.model.utils import auto_upgrade 11 | 12 | 13 | def consolidate_ckpt(src_path, dst_path): 14 | print("Loading model") 15 | auto_upgrade(src_path) 16 | src_model = AutoModelForCausalLM.from_pretrained(src_path, torch_dtype=torch.float16, low_cpu_mem_usage=True) 17 | src_tokenizer = AutoTokenizer.from_pretrained(src_path, use_fast=False) 18 | src_model.save_pretrained(dst_path) 19 | src_tokenizer.save_pretrained(dst_path) 20 | 21 | 22 | if __name__ == "__main__": 23 | parser = argparse.ArgumentParser() 24 | parser.add_argument("--src", type=str, required=True) 25 | parser.add_argument("--dst", type=str, required=True) 26 | 27 | args = parser.parse_args() 28 | 29 | consolidate_ckpt(args.src, args.dst) 30 | -------------------------------------------------------------------------------- /llava/model/language_model/__pycache__/llava_llama.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DCDmllm/HealthGPT/c044a13254b76c5eec8c2e6c55e3324318c27940/llava/model/language_model/__pycache__/llava_llama.cpython-310.pyc -------------------------------------------------------------------------------- /llava/model/language_model/__pycache__/llava_mistral.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DCDmllm/HealthGPT/c044a13254b76c5eec8c2e6c55e3324318c27940/llava/model/language_model/__pycache__/llava_mistral.cpython-310.pyc -------------------------------------------------------------------------------- /llava/model/language_model/__pycache__/llava_mpt.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DCDmllm/HealthGPT/c044a13254b76c5eec8c2e6c55e3324318c27940/llava/model/language_model/__pycache__/llava_mpt.cpython-310.pyc -------------------------------------------------------------------------------- 
/llava/model/language_model/__pycache__/llava_phi3.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DCDmllm/HealthGPT/c044a13254b76c5eec8c2e6c55e3324318c27940/llava/model/language_model/__pycache__/llava_phi3.cpython-310.pyc -------------------------------------------------------------------------------- /llava/model/language_model/llava_mpt.py: -------------------------------------------------------------------------------- 1 | # Copyright 2023 Haotian Liu 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | 15 | 16 | from typing import Optional, Tuple 17 | 18 | import torch 19 | 20 | from transformers import AutoConfig, AutoModelForCausalLM, \ 21 | MptConfig, MptForCausalLM, MptModel 22 | from llava.model.llava_arch import LlavaMetaModel, LlavaMetaForCausalLM 23 | 24 | 25 | class LlavaMptConfig(MptConfig): 26 | model_type = "llava_mpt" 27 | 28 | 29 | class LlavaMptModel(LlavaMetaModel, MptModel): 30 | config_class = LlavaMptConfig 31 | 32 | def __init__(self, config: MptConfig): 33 | config.hidden_size = config.d_model 34 | super(LlavaMptModel, self).__init__(config) 35 | 36 | def embed_tokens(self, x): 37 | return self.wte(x) 38 | 39 | 40 | class LlavaMptForCausalLM(MptForCausalLM, LlavaMetaForCausalLM): 41 | config_class = LlavaMptConfig 42 | supports_gradient_checkpointing = True 43 | 44 | def __init__(self, config): 45 | super(MptForCausalLM, self).__init__(config) 46 | 47 | self.transformer = LlavaMptModel(config) 48 | self.lm_head = torch.nn.Linear(config.hidden_size, config.vocab_size, bias=False) 49 | 50 | # Initialize weights and apply final processing 51 | self.post_init() 52 | 53 | def get_model(self): 54 | return self.transformer 55 | 56 | def _set_gradient_checkpointing(self, module, value=False): 57 | if isinstance(module, LlavaMptModel): 58 | module.gradient_checkpointing = value 59 | 60 | def forward( 61 | self, 62 | input_ids: Optional[torch.LongTensor] = None, 63 | past_key_values: Optional[Tuple[Tuple[torch.Tensor, torch.Tensor], ...]] = None, 64 | attention_mask: Optional[torch.Tensor] = None, 65 | inputs_embeds: Optional[torch.Tensor] = None, 66 | labels: Optional[torch.Tensor] = None, 67 | use_cache: Optional[bool] = None, 68 | output_attentions: Optional[bool] = None, 69 | output_hidden_states: Optional[bool] = None, 70 | return_dict: Optional[bool] = None, 71 | images=None): 72 | 73 | input_ids, attention_mask, past_key_values, inputs_embeds, labels = self.prepare_inputs_labels_for_multimodal(input_ids, attention_mask, past_key_values, 
labels, images) 74 | 75 | return super().forward( 76 | input_ids, 77 | past_key_values=past_key_values, 78 | attention_mask=attention_mask, 79 | inputs_embeds=inputs_embeds, 80 | labels=labels, 81 | use_cache=use_cache, 82 | output_attentions=output_attentions, 83 | output_hidden_states=output_hidden_states, 84 | return_dict=return_dict, 85 | ) 86 | 87 | def prepare_inputs_for_generation(self, input_ids, past_key_values=None, inputs_embeds=None, **kwargs): 88 | images = kwargs.pop("images", None) 89 | _inputs = super().prepare_inputs_for_generation( 90 | input_ids, past_key_values=past_key_values, inputs_embeds=inputs_embeds, **kwargs 91 | ) 92 | _inputs['images'] = images 93 | return _inputs 94 | 95 | 96 | AutoConfig.register("llava_mpt", LlavaMptConfig) 97 | AutoModelForCausalLM.register(LlavaMptConfig, LlavaMptForCausalLM) 98 | -------------------------------------------------------------------------------- /llava/model/make_delta.py: -------------------------------------------------------------------------------- 1 | """ 2 | Usage: 3 | python3 -m llava.model.make_delta --base ~/model_weights/llama-7b --target ~/model_weights/llava-7b --delta ~/model_weights/llava-7b-delta --hub-repo-id liuhaotian/llava-7b-delta 4 | """ 5 | import argparse 6 | 7 | import torch 8 | from tqdm import tqdm 9 | from transformers import AutoTokenizer, AutoModelForCausalLM 10 | from llava.model.utils import auto_upgrade 11 | 12 | 13 | def make_delta(base_model_path, target_model_path, delta_path, hub_repo_id): 14 | print("Loading base model") 15 | base = AutoModelForCausalLM.from_pretrained( 16 | base_model_path, torch_dtype=torch.float16, low_cpu_mem_usage=True) 17 | 18 | print("Loading target model") 19 | auto_upgrade(target_model_path) 20 | target = AutoModelForCausalLM.from_pretrained(target_model_path, torch_dtype=torch.float16, low_cpu_mem_usage=True) 21 | 22 | print("Calculating delta") 23 | for name, param in tqdm(target.state_dict().items(), desc="Calculating delta"): 24 | if 
name not in base.state_dict(): 25 | assert name in ['model.mm_projector.weight', 'model.mm_projector.bias'], f'{name} not in base model' 26 | continue 27 | if param.data.shape == base.state_dict()[name].shape: 28 | param.data -= base.state_dict()[name] 29 | else: 30 | assert name in ['model.embed_tokens.weight', 'lm_head.weight'], f'{name} dimension mismatch: {param.data.shape} vs {base.state_dict()[name].shape}' 31 | bparam = base.state_dict()[name] 32 | param.data[:bparam.shape[0], :bparam.shape[1]] -= bparam 33 | 34 | print("Saving delta") 35 | if hub_repo_id: 36 | kwargs = {"push_to_hub": True, "repo_id": hub_repo_id} 37 | else: 38 | kwargs = {} 39 | target.save_pretrained(delta_path, **kwargs) 40 | target_tokenizer = AutoTokenizer.from_pretrained(target_model_path) 41 | target_tokenizer.save_pretrained(delta_path, **kwargs) 42 | 43 | 44 | if __name__ == "__main__": 45 | parser = argparse.ArgumentParser() 46 | parser.add_argument("--base-model-path", type=str, required=True) 47 | parser.add_argument("--target-model-path", type=str, required=True) 48 | parser.add_argument("--delta-path", type=str, required=True) 49 | parser.add_argument("--hub-repo-id", type=str, default=None) 50 | args = parser.parse_args() 51 | 52 | make_delta(args.base_model_path, args.target_model_path, args.delta_path, args.hub_repo_id) 53 | -------------------------------------------------------------------------------- /llava/model/multimodal_encoder/__pycache__/builder.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DCDmllm/HealthGPT/c044a13254b76c5eec8c2e6c55e3324318c27940/llava/model/multimodal_encoder/__pycache__/builder.cpython-310.pyc -------------------------------------------------------------------------------- /llava/model/multimodal_encoder/__pycache__/clip_encoder.cpython-310.pyc: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/DCDmllm/HealthGPT/c044a13254b76c5eec8c2e6c55e3324318c27940/llava/model/multimodal_encoder/__pycache__/clip_encoder.cpython-310.pyc -------------------------------------------------------------------------------- /llava/model/multimodal_encoder/builder.py: -------------------------------------------------------------------------------- 1 | import os 2 | from .clip_encoder import CLIPVisionTower, CLIPVisionTowerS2 3 | 4 | 5 | def build_vision_tower(vision_tower_cfg, **kwargs): 6 | vision_tower = getattr(vision_tower_cfg, 'mm_vision_tower', getattr(vision_tower_cfg, 'vision_tower', None)) 7 | is_absolute_path_exists = os.path.exists(vision_tower) 8 | use_s2 = getattr(vision_tower_cfg, 's2', False) 9 | if is_absolute_path_exists or vision_tower.startswith("openai") or vision_tower.startswith("laion") or "ShareGPT4V" in vision_tower: 10 | if use_s2: 11 | return CLIPVisionTowerS2(vision_tower, args=vision_tower_cfg, **kwargs) 12 | else: 13 | return CLIPVisionTower(vision_tower, args=vision_tower_cfg, **kwargs) 14 | 15 | raise ValueError(f'Unknown vision tower: {vision_tower}') 16 | -------------------------------------------------------------------------------- /llava/model/multimodal_projector/__pycache__/builder.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DCDmllm/HealthGPT/c044a13254b76c5eec8c2e6c55e3324318c27940/llava/model/multimodal_projector/__pycache__/builder.cpython-310.pyc -------------------------------------------------------------------------------- /llava/model/multimodal_projector/builder.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import re 4 | 5 | 6 | class IdentityMap(nn.Module): 7 | def __init__(self): 8 | super().__init__() 9 | 10 | def forward(self, x, *args, **kwargs): 11 | return x 12 | 13 | @property 14 | def config(self): 15 | return 
{"mm_projector_type": 'identity'} 16 | 17 | 18 | class SimpleResBlock(nn.Module): 19 | def __init__(self, channels): 20 | super().__init__() 21 | self.pre_norm = nn.LayerNorm(channels) 22 | 23 | self.proj = nn.Sequential( 24 | nn.Linear(channels, channels), 25 | nn.GELU(), 26 | nn.Linear(channels, channels) 27 | ) 28 | def forward(self, x): 29 | x = self.pre_norm(x) 30 | return x + self.proj(x) 31 | 32 | 33 | def build_vision_projector(config, delay_load=False, **kwargs): 34 | projector_type = getattr(config, 'mm_projector_type', 'linear') 35 | 36 | if projector_type == 'linear': 37 | return nn.Linear(config.mm_hidden_size, config.hidden_size) 38 | 39 | mlp_gelu_match = re.match(r'^mlp(\d+)x_gelu$', projector_type) 40 | if mlp_gelu_match: 41 | mlp_depth = int(mlp_gelu_match.group(1)) 42 | modules = [nn.Linear(config.mm_hidden_size, config.hidden_size)] 43 | for _ in range(1, mlp_depth): 44 | modules.append(nn.GELU()) 45 | modules.append(nn.Linear(config.hidden_size, config.hidden_size)) 46 | return nn.Sequential(*modules) 47 | 48 | if projector_type == 'identity': 49 | return IdentityMap() 50 | 51 | raise ValueError(f'Unknown projector type: {projector_type}') 52 | -------------------------------------------------------------------------------- /llava/model/utils.py: -------------------------------------------------------------------------------- 1 | from transformers import AutoConfig 2 | 3 | 4 | def auto_upgrade(config): 5 | cfg = AutoConfig.from_pretrained(config) 6 | if 'llava' in config and 'llava' not in cfg.model_type: 7 | assert cfg.model_type == 'llama' 8 | print("You are using newer LLaVA code base, while the checkpoint of v0 is from older code base.") 9 | print("You must upgrade the checkpoint to the new code base (this can be done automatically).") 10 | confirm = input("Please confirm that you want to upgrade the checkpoint. 
[Y/N]") 11 | if confirm.lower() in ["y", "yes"]: 12 | print("Upgrading checkpoint...") 13 | assert len(cfg.architectures) == 1 14 | setattr(cfg.__class__, "model_type", "llava") 15 | cfg.architectures[0] = 'LlavaLlamaForCausalLM' 16 | cfg.save_pretrained(config) 17 | print("Checkpoint upgraded.") 18 | else: 19 | print("Checkpoint upgrade aborted.") 20 | exit(1) 21 | -------------------------------------------------------------------------------- /llava/peft/__init__.py: -------------------------------------------------------------------------------- 1 | # flake8: noqa 2 | # There's no way to ignore "F401 '...' imported but unused" warnings in this 3 | # module, but to preserve other warnings. So, don't check this module at all. 4 | 5 | # coding=utf-8 6 | # Copyright 2023-present the HuggingFace Inc. team. 7 | # 8 | # Licensed under the Apache License, Version 2.0 (the "License"); 9 | # you may not use this file except in compliance with the License. 10 | # You may obtain a copy of the License at 11 | # 12 | # http://www.apache.org/licenses/LICENSE-2.0 13 | # 14 | # Unless required by applicable law or agreed to in writing, software 15 | # distributed under the License is distributed on an "AS IS" BASIS, 16 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 17 | # See the License for the specific language governing permissions and 18 | # limitations under the License. 
19 | 20 | __version__ = "0.3.0.dev0" 21 | 22 | from .mapping import MODEL_TYPE_TO_PEFT_MODEL_MAPPING, PEFT_TYPE_TO_CONFIG_MAPPING, get_peft_config, get_peft_model 23 | from .peft_model import ( 24 | PeftModel, 25 | PeftModelForCausalLM, 26 | PeftModelForSeq2SeqLM, 27 | PeftModelForSequenceClassification, 28 | PeftModelForTokenClassification, 29 | ) 30 | from .tuners import ( 31 | LoraConfig, 32 | LoraModel, 33 | PrefixEncoder, 34 | PrefixTuningConfig, 35 | PromptEmbedding, 36 | PromptEncoder, 37 | PromptEncoderConfig, 38 | PromptEncoderReparameterizationType, 39 | PromptTuningConfig, 40 | PromptTuningInit, 41 | ) 42 | from .utils import ( 43 | TRANSFORMERS_MODELS_TO_PREFIX_TUNING_POSTPROCESS_MAPPING, 44 | PeftConfig, 45 | PeftType, 46 | PromptLearningConfig, 47 | TaskType, 48 | bloom_model_postprocess_past_key_value, 49 | get_peft_model_state_dict, 50 | # prepare_model_for_int8_training, 51 | set_peft_model_state_dict, 52 | shift_tokens_right, 53 | ) 54 | -------------------------------------------------------------------------------- /llava/peft/__pycache__/__init__.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DCDmllm/HealthGPT/c044a13254b76c5eec8c2e6c55e3324318c27940/llava/peft/__pycache__/__init__.cpython-310.pyc -------------------------------------------------------------------------------- /llava/peft/__pycache__/__init__.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DCDmllm/HealthGPT/c044a13254b76c5eec8c2e6c55e3324318c27940/llava/peft/__pycache__/__init__.cpython-39.pyc -------------------------------------------------------------------------------- /llava/peft/__pycache__/mapping.cpython-310.pyc: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/DCDmllm/HealthGPT/c044a13254b76c5eec8c2e6c55e3324318c27940/llava/peft/__pycache__/mapping.cpython-310.pyc -------------------------------------------------------------------------------- /llava/peft/__pycache__/mapping.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DCDmllm/HealthGPT/c044a13254b76c5eec8c2e6c55e3324318c27940/llava/peft/__pycache__/mapping.cpython-39.pyc -------------------------------------------------------------------------------- /llava/peft/__pycache__/peft_model.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DCDmllm/HealthGPT/c044a13254b76c5eec8c2e6c55e3324318c27940/llava/peft/__pycache__/peft_model.cpython-310.pyc -------------------------------------------------------------------------------- /llava/peft/__pycache__/peft_model.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DCDmllm/HealthGPT/c044a13254b76c5eec8c2e6c55e3324318c27940/llava/peft/__pycache__/peft_model.cpython-39.pyc -------------------------------------------------------------------------------- /llava/peft/tuners/__init__.py: -------------------------------------------------------------------------------- 1 | # flake8: noqa 2 | # There's no way to ignore "F401 '...' imported but unused" warnings in this 3 | # module, but to preserve other warnings. So, don't check this module at all 4 | 5 | # coding=utf-8 6 | # Copyright 2023-present the HuggingFace Inc. team. 7 | # 8 | # Licensed under the Apache License, Version 2.0 (the "License"); 9 | # you may not use this file except in compliance with the License. 
10 | # You may obtain a copy of the License at 11 | # 12 | # http://www.apache.org/licenses/LICENSE-2.0 13 | # 14 | # Unless required by applicable law or agreed to in writing, software 15 | # distributed under the License is distributed on an "AS IS" BASIS, 16 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 17 | # See the License for the specific language governing permissions and 18 | # limitations under the License. 19 | 20 | from .lora import LoraConfig, LoraModel 21 | from .p_tuning import PromptEncoder, PromptEncoderConfig, PromptEncoderReparameterizationType 22 | from .prefix_tuning import PrefixEncoder, PrefixTuningConfig 23 | from .prompt_tuning import PromptEmbedding, PromptTuningConfig, PromptTuningInit 24 | -------------------------------------------------------------------------------- /llava/peft/tuners/__pycache__/__init__.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DCDmllm/HealthGPT/c044a13254b76c5eec8c2e6c55e3324318c27940/llava/peft/tuners/__pycache__/__init__.cpython-310.pyc -------------------------------------------------------------------------------- /llava/peft/tuners/__pycache__/__init__.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DCDmllm/HealthGPT/c044a13254b76c5eec8c2e6c55e3324318c27940/llava/peft/tuners/__pycache__/__init__.cpython-39.pyc -------------------------------------------------------------------------------- /llava/peft/tuners/__pycache__/lora.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DCDmllm/HealthGPT/c044a13254b76c5eec8c2e6c55e3324318c27940/llava/peft/tuners/__pycache__/lora.cpython-310.pyc -------------------------------------------------------------------------------- /llava/peft/tuners/__pycache__/lora.cpython-310.pyc.139831930933664: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/DCDmllm/HealthGPT/c044a13254b76c5eec8c2e6c55e3324318c27940/llava/peft/tuners/__pycache__/lora.cpython-310.pyc.139831930933664 -------------------------------------------------------------------------------- /llava/peft/tuners/__pycache__/lora.cpython-310.pyc.140023258402944: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DCDmllm/HealthGPT/c044a13254b76c5eec8c2e6c55e3324318c27940/llava/peft/tuners/__pycache__/lora.cpython-310.pyc.140023258402944 -------------------------------------------------------------------------------- /llava/peft/tuners/__pycache__/lora.cpython-310.pyc.140104265109632: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DCDmllm/HealthGPT/c044a13254b76c5eec8c2e6c55e3324318c27940/llava/peft/tuners/__pycache__/lora.cpython-310.pyc.140104265109632 -------------------------------------------------------------------------------- /llava/peft/tuners/__pycache__/lora.cpython-310.pyc.140108281172096: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DCDmllm/HealthGPT/c044a13254b76c5eec8c2e6c55e3324318c27940/llava/peft/tuners/__pycache__/lora.cpython-310.pyc.140108281172096 -------------------------------------------------------------------------------- /llava/peft/tuners/__pycache__/lora.cpython-310.pyc.140160632344704: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DCDmllm/HealthGPT/c044a13254b76c5eec8c2e6c55e3324318c27940/llava/peft/tuners/__pycache__/lora.cpython-310.pyc.140160632344704 -------------------------------------------------------------------------------- /llava/peft/tuners/__pycache__/lora.cpython-310.pyc.140480966316160: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/DCDmllm/HealthGPT/c044a13254b76c5eec8c2e6c55e3324318c27940/llava/peft/tuners/__pycache__/lora.cpython-310.pyc.140480966316160 -------------------------------------------------------------------------------- /llava/peft/tuners/__pycache__/lora.cpython-310.pyc.140577961184688: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DCDmllm/HealthGPT/c044a13254b76c5eec8c2e6c55e3324318c27940/llava/peft/tuners/__pycache__/lora.cpython-310.pyc.140577961184688 -------------------------------------------------------------------------------- /llava/peft/tuners/__pycache__/lora.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DCDmllm/HealthGPT/c044a13254b76c5eec8c2e6c55e3324318c27940/llava/peft/tuners/__pycache__/lora.cpython-39.pyc -------------------------------------------------------------------------------- /llava/peft/tuners/__pycache__/lora_moe.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DCDmllm/HealthGPT/c044a13254b76c5eec8c2e6c55e3324318c27940/llava/peft/tuners/__pycache__/lora_moe.cpython-310.pyc -------------------------------------------------------------------------------- /llava/peft/tuners/__pycache__/p_tuning.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DCDmllm/HealthGPT/c044a13254b76c5eec8c2e6c55e3324318c27940/llava/peft/tuners/__pycache__/p_tuning.cpython-310.pyc -------------------------------------------------------------------------------- /llava/peft/tuners/__pycache__/p_tuning.cpython-39.pyc: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/DCDmllm/HealthGPT/c044a13254b76c5eec8c2e6c55e3324318c27940/llava/peft/tuners/__pycache__/p_tuning.cpython-39.pyc -------------------------------------------------------------------------------- /llava/peft/tuners/__pycache__/prefix_tuning.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DCDmllm/HealthGPT/c044a13254b76c5eec8c2e6c55e3324318c27940/llava/peft/tuners/__pycache__/prefix_tuning.cpython-310.pyc -------------------------------------------------------------------------------- /llava/peft/tuners/__pycache__/prefix_tuning.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DCDmllm/HealthGPT/c044a13254b76c5eec8c2e6c55e3324318c27940/llava/peft/tuners/__pycache__/prefix_tuning.cpython-39.pyc -------------------------------------------------------------------------------- /llava/peft/tuners/__pycache__/prompt_tuning.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DCDmllm/HealthGPT/c044a13254b76c5eec8c2e6c55e3324318c27940/llava/peft/tuners/__pycache__/prompt_tuning.cpython-310.pyc -------------------------------------------------------------------------------- /llava/peft/tuners/__pycache__/prompt_tuning.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DCDmllm/HealthGPT/c044a13254b76c5eec8c2e6c55e3324318c27940/llava/peft/tuners/__pycache__/prompt_tuning.cpython-39.pyc -------------------------------------------------------------------------------- /llava/peft/tuners/prefix_tuning.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2023-present the HuggingFace Inc. team. 
3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | 17 | from dataclasses import dataclass, field 18 | 19 | import torch 20 | 21 | from ..utils import PeftType, PromptLearningConfig 22 | 23 | 24 | @dataclass 25 | class PrefixTuningConfig(PromptLearningConfig): 26 | """ 27 | This is the configuration class to store the configuration of a [`~peft.PrefixEncoder`]. 28 | 29 | Args: 30 | encoder_hidden_size (`int`): The hidden size of the prompt encoder. 31 | prefix_projection (`bool`): Whether to project the prefix embeddings. 32 | """ 33 | 34 | encoder_hidden_size: int = field( 35 | default=None, 36 | metadata={"help": "The hidden size of the encoder"}, 37 | ) 38 | prefix_projection: bool = field( 39 | default=False, 40 | metadata={"help": "Whether to project the prefix tokens"}, 41 | ) 42 | 43 | def __post_init__(self): 44 | self.peft_type = PeftType.PREFIX_TUNING 45 | 46 | 47 | # Based on https://github.com/THUDM/P-tuning-v2/blob/main/model/prefix_encoder.py 48 | # with some refactor 49 | class PrefixEncoder(torch.nn.Module): 50 | r""" 51 | The torch.nn model to encode the prefix 52 | 53 | Args: 54 | config ([`PrefixTuningConfig`]): The configuration of the prefix encoder. 
55 | 56 | Example:: 57 | 58 | >>> from peft import PrefixEncoder, PrefixTuningConfig >>> config = PrefixTuningConfig( 59 | peft_type="PREFIX_TUNING", task_type="SEQ_2_SEQ_LM", num_virtual_tokens=20, token_dim=768, 60 | num_transformer_submodules=1, num_attention_heads=12, num_layers=12, encoder_hidden_size=768 61 | ) 62 | >>> prefix_encoder = PrefixEncoder(config) 63 | 64 | 65 | **Attributes**: 66 | - **embedding** (`torch.nn.Embedding`) -- 67 | The embedding layer of the prefix encoder. 68 | - **transform** (`torch.nn.Sequential`) -- The 69 | two-layer MLP to transform the prefix embeddings if `prefix_projection` is `True`. 70 | - **prefix_projection** (`bool`) -- Whether to project the prefix embeddings. 71 | 72 | Input shape: (batch_size, num_virtual_tokens) 73 | 74 | Output shape: (batch_size, num_virtual_tokens, 2*layers*hidden) 75 | """ 76 | 77 | def __init__(self, config): 78 | super().__init__() 79 | self.prefix_projection = config.prefix_projection 80 | token_dim = config.token_dim 81 | num_layers = config.num_layers 82 | encoder_hidden_size = config.encoder_hidden_size 83 | num_virtual_tokens = config.num_virtual_tokens 84 | if self.prefix_projection and not config.inference_mode: 85 | # Use a two-layer MLP to encode the prefix 86 | self.embedding = torch.nn.Embedding(num_virtual_tokens, token_dim) 87 | self.transform = torch.nn.Sequential( 88 | torch.nn.Linear(token_dim, encoder_hidden_size), 89 | torch.nn.Tanh(), 90 | torch.nn.Linear(encoder_hidden_size, num_layers * 2 * token_dim), 91 | ) 92 | else: 93 | self.embedding = torch.nn.Embedding(num_virtual_tokens, num_layers * 2 * token_dim) 94 | 95 | def forward(self, prefix: torch.Tensor): 96 | if self.prefix_projection: 97 | prefix_tokens = self.embedding(prefix) 98 | past_key_values = self.transform(prefix_tokens) 99 | else: 100 | past_key_values = self.embedding(prefix) 101 | return past_key_values 102 | -------------------------------------------------------------------------------- 
/llava/peft/utils/__init__.py: -------------------------------------------------------------------------------- 1 | # flake8: noqa 2 | # There's no way to ignore "F401 '...' imported but unused" warnings in this 3 | # module, but to preserve other warnings. So, don't check this module at all 4 | 5 | # coding=utf-8 6 | # Copyright 2023-present the HuggingFace Inc. team. 7 | # 8 | # Licensed under the Apache License, Version 2.0 (the "License"); 9 | # you may not use this file except in compliance with the License. 10 | # You may obtain a copy of the License at 11 | # 12 | # http://www.apache.org/licenses/LICENSE-2.0 13 | # 14 | # Unless required by applicable law or agreed to in writing, software 15 | # distributed under the License is distributed on an "AS IS" BASIS, 16 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 17 | # See the License for the specific language governing permissions and 18 | # limitations under the License. 19 | 20 | from .adapters_utils import CONFIG_NAME, WEIGHTS_NAME 21 | from .config import PeftConfig, PeftType, PromptLearningConfig, TaskType 22 | from .other import ( 23 | TRANSFORMERS_MODELS_TO_PREFIX_TUNING_POSTPROCESS_MAPPING, 24 | _set_trainable, 25 | bloom_model_postprocess_past_key_value, 26 | # prepare_model_for_int8_training, 27 | shift_tokens_right, 28 | transpose, 29 | ) 30 | from .save_and_load import get_peft_model_state_dict, set_peft_model_state_dict 31 | -------------------------------------------------------------------------------- /llava/peft/utils/__pycache__/__init__.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DCDmllm/HealthGPT/c044a13254b76c5eec8c2e6c55e3324318c27940/llava/peft/utils/__pycache__/__init__.cpython-310.pyc -------------------------------------------------------------------------------- /llava/peft/utils/__pycache__/__init__.cpython-39.pyc: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/DCDmllm/HealthGPT/c044a13254b76c5eec8c2e6c55e3324318c27940/llava/peft/utils/__pycache__/__init__.cpython-39.pyc -------------------------------------------------------------------------------- /llava/peft/utils/__pycache__/adapters_utils.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DCDmllm/HealthGPT/c044a13254b76c5eec8c2e6c55e3324318c27940/llava/peft/utils/__pycache__/adapters_utils.cpython-310.pyc -------------------------------------------------------------------------------- /llava/peft/utils/__pycache__/adapters_utils.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DCDmllm/HealthGPT/c044a13254b76c5eec8c2e6c55e3324318c27940/llava/peft/utils/__pycache__/adapters_utils.cpython-39.pyc -------------------------------------------------------------------------------- /llava/peft/utils/__pycache__/config.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DCDmllm/HealthGPT/c044a13254b76c5eec8c2e6c55e3324318c27940/llava/peft/utils/__pycache__/config.cpython-310.pyc -------------------------------------------------------------------------------- /llava/peft/utils/__pycache__/config.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DCDmllm/HealthGPT/c044a13254b76c5eec8c2e6c55e3324318c27940/llava/peft/utils/__pycache__/config.cpython-39.pyc -------------------------------------------------------------------------------- /llava/peft/utils/__pycache__/other.cpython-310.pyc: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/DCDmllm/HealthGPT/c044a13254b76c5eec8c2e6c55e3324318c27940/llava/peft/utils/__pycache__/other.cpython-310.pyc -------------------------------------------------------------------------------- /llava/peft/utils/__pycache__/other.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DCDmllm/HealthGPT/c044a13254b76c5eec8c2e6c55e3324318c27940/llava/peft/utils/__pycache__/other.cpython-39.pyc -------------------------------------------------------------------------------- /llava/peft/utils/__pycache__/save_and_load.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DCDmllm/HealthGPT/c044a13254b76c5eec8c2e6c55e3324318c27940/llava/peft/utils/__pycache__/save_and_load.cpython-310.pyc -------------------------------------------------------------------------------- /llava/peft/utils/__pycache__/save_and_load.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DCDmllm/HealthGPT/c044a13254b76c5eec8c2e6c55e3324318c27940/llava/peft/utils/__pycache__/save_and_load.cpython-39.pyc -------------------------------------------------------------------------------- /llava/peft/utils/adapters_utils.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2023-present the HuggingFace Inc. team. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 
6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | WEIGHTS_NAME = "adapter_model.bin" 16 | CONFIG_NAME = "adapter_config.json" 17 | 18 | # TODO: add automapping and superclass here? 19 | -------------------------------------------------------------------------------- /llava/peft/utils/other.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2023-present the HuggingFace Inc. team. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
15 | 16 | import torch 17 | 18 | 19 | # needed for prefix-tuning of bloom model 20 | def bloom_model_postprocess_past_key_value(past_key_values): 21 | past_key_values = torch.cat(past_key_values) 22 | total_layers, batch_size, num_attention_heads, num_virtual_tokens, head_dim = past_key_values.shape 23 | keys = past_key_values[: total_layers // 2] 24 | keys = keys.transpose(2, 3).reshape( 25 | total_layers // 2, batch_size * num_attention_heads, head_dim, num_virtual_tokens 26 | ) 27 | values = past_key_values[total_layers // 2 :] 28 | values = values.reshape(total_layers // 2, batch_size * num_attention_heads, num_virtual_tokens, head_dim) 29 | 30 | return tuple(zip(keys, values)) 31 | 32 | 33 | TRANSFORMERS_MODELS_TO_PREFIX_TUNING_POSTPROCESS_MAPPING = { 34 | "bloom": bloom_model_postprocess_past_key_value, 35 | } 36 | 37 | 38 | # copied from transformers.models.bart.modeling_bart 39 | def shift_tokens_right(input_ids: torch.Tensor, pad_token_id: int, decoder_start_token_id: int): 40 | """ 41 | Shift input ids one token to the right. 42 | 43 | Args: 44 | input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): input ids 45 | pad_token_id (`int`): The id of the `padding` token. 46 | decoder_start_token_id (`int`): The id of the `start` token. 
47 | """ 48 | shifted_input_ids = input_ids.new_zeros(input_ids.shape) 49 | shifted_input_ids[:, 1:] = input_ids[:, :-1].clone() 50 | shifted_input_ids[:, 0] = decoder_start_token_id 51 | 52 | if pad_token_id is None: 53 | raise ValueError("self.model.config.pad_token_id has to be defined.") 54 | # replace possible -100 values in labels by `pad_token_id` 55 | shifted_input_ids.masked_fill_(shifted_input_ids == -100, pad_token_id) 56 | 57 | return shifted_input_ids 58 | 59 | 60 | def _set_trainable(model): 61 | if model.modules_to_save is not None: 62 | for name, param in model.named_parameters(): 63 | if any(module_name in name for module_name in model.modules_to_save): 64 | param.requires_grad = True 65 | 66 | 67 | def fsdp_auto_wrap_policy(model): 68 | import functools 69 | import os 70 | 71 | from accelerate import FullyShardedDataParallelPlugin 72 | from torch.distributed.fsdp.wrap import _or_policy, lambda_auto_wrap_policy, transformer_auto_wrap_policy 73 | 74 | from ..tuners import PrefixEncoder, PromptEmbedding, PromptEncoder 75 | 76 | def lambda_policy_fn(module): 77 | if ( 78 | len(list(module.named_children())) == 0 79 | and getattr(module, "weight", None) is not None 80 | and module.weight.requires_grad 81 | ): 82 | return True 83 | return False 84 | 85 | lambda_policy = functools.partial(lambda_auto_wrap_policy, lambda_fn=lambda_policy_fn) 86 | transformer_wrap_policy = functools.partial( 87 | transformer_auto_wrap_policy, 88 | transformer_layer_cls=( 89 | PrefixEncoder, 90 | PromptEncoder, 91 | PromptEmbedding, 92 | FullyShardedDataParallelPlugin.get_module_class_from_name( 93 | model, os.environ.get("FSDP_TRANSFORMER_CLS_TO_WRAP", "") 94 | ), 95 | ), 96 | ) 97 | 98 | auto_wrap_policy = functools.partial(_or_policy, policies=[lambda_policy, transformer_wrap_policy]) 99 | return auto_wrap_policy 100 | 101 | 102 | def transpose(weight, fan_in_fan_out): 103 | return weight.T if fan_in_fan_out else weight 104 | 
-------------------------------------------------------------------------------- /llava/peft/utils/save_and_load.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2023-present the HuggingFace Inc. team. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | from .config import PeftType 17 | 18 | 19 | def get_peft_model_state_dict(model, state_dict=None): 20 | """ 21 | Get the state dict of the Peft model. 22 | 23 | Args: 24 | model ([`PeftModel`]): The Peft model. When using torch.nn.DistributedDataParallel, DeepSpeed or FSDP, 25 | the model should be the underlying model/unwrapped model (i.e. model.module). 26 | state_dict (`dict`, *optional*, defaults to `None`): 27 | The state dict of the model. If not provided, the state dict of the model 28 | will be used. 
29 | """ 30 | if state_dict is None: 31 | state_dict = model.state_dict() 32 | if model.peft_config.peft_type == PeftType.LORA: 33 | # to_return = lora_state_dict(model, bias=model.peft_config.bias) 34 | # adapted from `https://github.com/microsoft/LoRA/blob/main/loralib/utils.py` 35 | # to directly with the state dict which is necessary when using DeepSpeed or FSDP 36 | bias = model.peft_config.bias 37 | if bias == "none": 38 | to_return = {k: state_dict[k] for k in state_dict if "lora_" in k} 39 | elif bias == "all": 40 | to_return = {k: state_dict[k] for k in state_dict if "lora_" in k or "bias" in k} 41 | elif bias == "lora_only": 42 | to_return = {} 43 | for k in state_dict: 44 | if "lora_" in k: 45 | to_return[k] = state_dict[k] 46 | bias_name = k.split("lora_")[0] + "bias" 47 | if bias_name in state_dict: 48 | to_return[bias_name] = state_dict[bias_name] 49 | else: 50 | raise NotImplementedError 51 | else: 52 | to_return = {} 53 | if model.peft_config.inference_mode: 54 | prompt_embeddings = model.prompt_encoder.embedding.weight 55 | else: 56 | prompt_embeddings = model.get_prompt_embedding_to_save() 57 | to_return["prompt_embeddings"] = prompt_embeddings 58 | if model.modules_to_save is not None: 59 | for key, value in state_dict.items(): 60 | if any(module_name in key for module_name in model.modules_to_save): 61 | to_return[key] = value 62 | return to_return 63 | 64 | 65 | def set_peft_model_state_dict(model, peft_model_state_dict): 66 | """ 67 | Set the state dict of the Peft model. 68 | 69 | Args: 70 | model ([`PeftModel`]): The Peft model. 71 | peft_model_state_dict (`dict`): The state dict of the Peft model. 
72 | """ 73 | 74 | for name, param in model.named_parameters(): 75 | if name in peft_model_state_dict.keys(): 76 | print(f"Loading LoRA in lora_path, {name}...") 77 | 78 | model.load_state_dict(peft_model_state_dict, strict=False) 79 | return model 80 | -------------------------------------------------------------------------------- /llava/serve/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DCDmllm/HealthGPT/c044a13254b76c5eec8c2e6c55e3324318c27940/llava/serve/__init__.py -------------------------------------------------------------------------------- /llava/serve/examples/extreme_ironing.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DCDmllm/HealthGPT/c044a13254b76c5eec8c2e6c55e3324318c27940/llava/serve/examples/extreme_ironing.jpg -------------------------------------------------------------------------------- /llava/serve/examples/waterview.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DCDmllm/HealthGPT/c044a13254b76c5eec8c2e6c55e3324318c27940/llava/serve/examples/waterview.jpg -------------------------------------------------------------------------------- /llava/serve/register_worker.py: -------------------------------------------------------------------------------- 1 | """ 2 | Manually register workers. 
3 | 4 | Usage: 5 | python3 -m fastchat.serve.register_worker --controller http://localhost:21001 --worker-name http://localhost:21002 6 | """ 7 | 8 | import argparse 9 | 10 | import requests 11 | 12 | if __name__ == "__main__": 13 | parser = argparse.ArgumentParser() 14 | parser.add_argument("--controller-address", type=str) 15 | parser.add_argument("--worker-name", type=str) 16 | parser.add_argument("--check-heart-beat", action="store_true") 17 | args = parser.parse_args() 18 | 19 | url = args.controller_address + "/register_worker" 20 | data = { 21 | "worker_name": args.worker_name, 22 | "check_heart_beat": args.check_heart_beat, 23 | "worker_status": None, 24 | } 25 | r = requests.post(url, json=data) 26 | assert r.status_code == 200 27 | -------------------------------------------------------------------------------- /llava/serve/test_message.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import json 3 | 4 | import requests 5 | 6 | from llava.conversation import default_conversation 7 | 8 | 9 | def main(): 10 | if args.worker_address: 11 | worker_addr = args.worker_address 12 | else: 13 | controller_addr = args.controller_address 14 | ret = requests.post(controller_addr + "/refresh_all_workers") 15 | ret = requests.post(controller_addr + "/list_models") 16 | models = ret.json()["models"] 17 | models.sort() 18 | print(f"Models: {models}") 19 | 20 | ret = requests.post(controller_addr + "/get_worker_address", 21 | json={"model": args.model_name}) 22 | worker_addr = ret.json()["address"] 23 | print(f"worker_addr: {worker_addr}") 24 | 25 | if worker_addr == "": 26 | return 27 | 28 | conv = default_conversation.copy() 29 | conv.append_message(conv.roles[0], args.message) 30 | prompt = conv.get_prompt() 31 | 32 | headers = {"User-Agent": "LLaVA Client"} 33 | pload = { 34 | "model": args.model_name, 35 | "prompt": prompt, 36 | "max_new_tokens": args.max_new_tokens, 37 | "temperature": 0.7, 38 | "stop": 
conv.sep, 39 | } 40 | response = requests.post(worker_addr + "/worker_generate_stream", headers=headers, 41 | json=pload, stream=True) 42 | 43 | print(prompt.replace(conv.sep, "\n"), end="") 44 | for chunk in response.iter_lines(chunk_size=8192, decode_unicode=False, delimiter=b"\0"): 45 | if chunk: 46 | data = json.loads(chunk.decode("utf-8")) 47 | output = data["text"].split(conv.sep)[-1] 48 | print(output, end="\r") 49 | print("") 50 | 51 | 52 | if __name__ == "__main__": 53 | parser = argparse.ArgumentParser() 54 | parser.add_argument("--controller-address", type=str, default="http://localhost:21001") 55 | parser.add_argument("--worker-address", type=str) 56 | parser.add_argument("--model-name", type=str, default="facebook/opt-350m") 57 | parser.add_argument("--max-new-tokens", type=int, default=32) 58 | parser.add_argument("--message", type=str, default= 59 | "Tell me a story with more than 1000 words.") 60 | args = parser.parse_args() 61 | 62 | main() 63 | -------------------------------------------------------------------------------- /llava/utils.py: -------------------------------------------------------------------------------- 1 | import datetime 2 | import logging 3 | import logging.handlers 4 | import os 5 | import sys 6 | 7 | import requests 8 | 9 | from llava.constants import LOGDIR 10 | 11 | server_error_msg = "**NETWORK ERROR DUE TO HIGH TRAFFIC. PLEASE REGENERATE OR REFRESH THIS PAGE.**" 12 | moderation_msg = "YOUR INPUT VIOLATES OUR CONTENT MODERATION GUIDELINES. PLEASE TRY AGAIN." 
13 | 14 | handler = None 15 | 16 | 17 | def build_logger(logger_name, logger_filename): 18 | global handler 19 | 20 | formatter = logging.Formatter( 21 | fmt="%(asctime)s | %(levelname)s | %(name)s | %(message)s", 22 | datefmt="%Y-%m-%d %H:%M:%S", 23 | ) 24 | 25 | # Set the format of root handlers 26 | if not logging.getLogger().handlers: 27 | logging.basicConfig(level=logging.INFO) 28 | logging.getLogger().handlers[0].setFormatter(formatter) 29 | 30 | # Redirect stdout and stderr to loggers 31 | stdout_logger = logging.getLogger("stdout") 32 | stdout_logger.setLevel(logging.INFO) 33 | sl = StreamToLogger(stdout_logger, logging.INFO) 34 | sys.stdout = sl 35 | 36 | stderr_logger = logging.getLogger("stderr") 37 | stderr_logger.setLevel(logging.ERROR) 38 | sl = StreamToLogger(stderr_logger, logging.ERROR) 39 | sys.stderr = sl 40 | 41 | # Get logger 42 | logger = logging.getLogger(logger_name) 43 | logger.setLevel(logging.INFO) 44 | 45 | # Add a file handler for all loggers 46 | if handler is None: 47 | os.makedirs(LOGDIR, exist_ok=True) 48 | filename = os.path.join(LOGDIR, logger_filename) 49 | handler = logging.handlers.TimedRotatingFileHandler( 50 | filename, when='D', utc=True, encoding='UTF-8') 51 | handler.setFormatter(formatter) 52 | 53 | for name, item in logging.root.manager.loggerDict.items(): 54 | if isinstance(item, logging.Logger): 55 | item.addHandler(handler) 56 | 57 | return logger 58 | 59 | 60 | class StreamToLogger(object): 61 | """ 62 | Fake file-like stream object that redirects writes to a logger instance. 
63 | """ 64 | def __init__(self, logger, log_level=logging.INFO): 65 | self.terminal = sys.stdout 66 | self.logger = logger 67 | self.log_level = log_level 68 | self.linebuf = '' 69 | 70 | def __getattr__(self, attr): 71 | return getattr(self.terminal, attr) 72 | 73 | def write(self, buf): 74 | temp_linebuf = self.linebuf + buf 75 | self.linebuf = '' 76 | for line in temp_linebuf.splitlines(True): 77 | # From the io.TextIOWrapper docs: 78 | # On output, if newline is None, any '\n' characters written 79 | # are translated to the system default line separator. 80 | # By default sys.stdout.write() expects '\n' newlines and then 81 | # translates them so this is still cross platform. 82 | if line[-1] == '\n': 83 | self.logger.log(self.log_level, line.rstrip()) 84 | else: 85 | self.linebuf += line 86 | 87 | def flush(self): 88 | if self.linebuf != '': 89 | self.logger.log(self.log_level, self.linebuf.rstrip()) 90 | self.linebuf = '' 91 | 92 | 93 | def disable_torch_init(): 94 | """ 95 | Disable the redundant torch default initialization to accelerate model creation. 96 | """ 97 | import torch 98 | setattr(torch.nn.Linear, "reset_parameters", lambda self: None) 99 | setattr(torch.nn.LayerNorm, "reset_parameters", lambda self: None) 100 | 101 | 102 | def violates_moderation(text): 103 | """ 104 | Check whether the text violates OpenAI moderation API. 
105 | """ 106 | url = "https://api.openai.com/v1/moderations" 107 | headers = {"Content-Type": "application/json", 108 | "Authorization": "Bearer " + os.environ["OPENAI_API_KEY"]} 109 | text = text.replace("\n", "") 110 | data = "{" + '"input": ' + f'"{text}"' + "}" 111 | data = data.encode("utf-8") 112 | try: 113 | ret = requests.post(url, headers=headers, data=data, timeout=5) 114 | flagged = ret.json()["results"][0]["flagged"] 115 | except requests.exceptions.RequestException as e: 116 | flagged = False 117 | except KeyError as e: 118 | flagged = False 119 | 120 | return flagged 121 | 122 | 123 | def pretty_print_semaphore(semaphore): 124 | if semaphore is None: 125 | return "None" 126 | return f"Semaphore(value={semaphore._value}, locked={semaphore.locked()})" 127 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | torch==2.0.1 2 | accelerate==0.27.0 3 | bitsandbytes==0.41.0 4 | deepspeed==0.9.5 5 | einops-exts==0.0.4 6 | einops==0.6.1 7 | gradio==3.35.2 8 | gradio_client==0.2.9 9 | httpx==0.24.0 10 | markdown2==2.4.10 11 | numpy==1.26.0 12 | peft==0.4.0 13 | scikit-learn==1.2.2 14 | sentencepiece==0.1.99 15 | shortuuid==1.0.11 16 | timm==0.6.13 17 | torchvision==0.15.2 18 | transformers==4.41 19 | wandb==0.15.12 20 | wavedrom==2.0.3.post3 21 | Pygments==2.16.1 22 | omegaconf 23 | pytorch_lightning==2.1 24 | scikit-image 25 | opencv-python 26 | lpips -------------------------------------------------------------------------------- /requirements_qwen_2_5.txt: -------------------------------------------------------------------------------- 1 | torch==2.4.0 2 | accelerate==0.27.0 3 | bitsandbytes==0.41.0 4 | deepspeed==0.14.4 5 | einops-exts==0.0.4 6 | einops==0.6.1 7 | gradio==3.35.2 8 | gradio_client==0.2.9 9 | httpx==0.24.0 10 | markdown2==2.4.10 11 | numpy==1.26.0 12 | peft==0.4.0 13 | scikit-learn==1.2.2 14 | 
sentencepiece==0.1.99 15 | shortuuid==1.0.11 16 | timm==0.6.13 17 | torchvision==0.19.0 18 | transformers==4.49.0 19 | wandb==0.15.12 20 | wavedrom==2.0.3.post3 21 | Pygments==2.16.1 -------------------------------------------------------------------------------- /scripts/convert_gqa_for_eval.py: -------------------------------------------------------------------------------- 1 | import os 2 | import json 3 | import argparse 4 | 5 | parser = argparse.ArgumentParser() 6 | parser.add_argument("--src", type=str) 7 | parser.add_argument("--dst", type=str) 8 | args = parser.parse_args() 9 | 10 | all_answers = [] 11 | for line_idx, line in enumerate(open(args.src)): 12 | res = json.loads(line) 13 | question_id = res['question_id'] 14 | text = res['text'].rstrip('.').lower() 15 | all_answers.append({"questionId": question_id, "prediction": text}) 16 | 17 | with open(args.dst, 'w') as f: 18 | json.dump(all_answers, f) 19 | -------------------------------------------------------------------------------- /scripts/convert_mmbench_for_submission.py: -------------------------------------------------------------------------------- 1 | import os 2 | import json 3 | import argparse 4 | import pandas as pd 5 | 6 | def get_args(): 7 | parser = argparse.ArgumentParser() 8 | parser.add_argument("--annotation-file", type=str, required=True) 9 | parser.add_argument("--result-dir", type=str, required=True) 10 | parser.add_argument("--upload-dir", type=str, required=True) 11 | parser.add_argument("--experiment", type=str, required=True) 12 | 13 | return parser.parse_args() 14 | 15 | if __name__ == "__main__": 16 | args = get_args() 17 | 18 | df = pd.read_table(args.annotation_file) 19 | 20 | cur_df = df.copy() 21 | cur_df = cur_df.drop(columns=['hint', 'category', 'source', 'image', 'comment', 'l2-category']) 22 | cur_df.insert(6, 'prediction', None) 23 | for pred in open(os.path.join(args.result_dir, f"{args.experiment}.jsonl")): 24 | pred = json.loads(pred) 25 | cur_df.loc[df['index'] 
== pred['question_id'], 'prediction'] = pred['text'] 26 | 27 | cur_df.to_excel(os.path.join(args.upload_dir, f"{args.experiment}.xlsx"), index=False, engine='openpyxl') 28 | -------------------------------------------------------------------------------- /scripts/convert_mmvet_for_eval.py: -------------------------------------------------------------------------------- 1 | import os 2 | import json 3 | import argparse 4 | 5 | parser = argparse.ArgumentParser() 6 | parser.add_argument("--src", type=str) 7 | parser.add_argument("--dst", type=str) 8 | args = parser.parse_args() 9 | 10 | cur_result = {} 11 | 12 | for line in open(args.src): 13 | data = json.loads(line) 14 | qid = data['question_id'] 15 | cur_result[f'v1_{qid}'] = data['text'] 16 | 17 | with open(args.dst, 'w') as f: 18 | json.dump(cur_result, f, indent=2) 19 | -------------------------------------------------------------------------------- /scripts/convert_seed_for_submission.py: -------------------------------------------------------------------------------- 1 | import os 2 | import json 3 | import argparse 4 | 5 | 6 | def get_args(): 7 | parser = argparse.ArgumentParser() 8 | parser.add_argument("--annotation-file", type=str) 9 | parser.add_argument("--result-file", type=str) 10 | parser.add_argument("--result-upload-file", type=str) 11 | return parser.parse_args() 12 | 13 | 14 | def eval_single(result_file, eval_only_type=None): 15 | results = {} 16 | for line in open(result_file): 17 | row = json.loads(line) 18 | results[row['question_id']] = row 19 | 20 | type_counts = {} 21 | correct_counts = {} 22 | for question_data in data['questions']: 23 | if eval_only_type is not None and question_data['data_type'] != eval_only_type: continue 24 | data_type = question_data['question_type_id'] 25 | type_counts[data_type] = type_counts.get(data_type, 0) + 1 26 | try: 27 | question_id = int(question_data['question_id']) 28 | except: 29 | question_id = question_data['question_id'] 30 | if question_id not in 
results: 31 | correct_counts[data_type] = correct_counts.get(data_type, 0) 32 | continue 33 | row = results[question_id] 34 | if row['text'] == question_data['answer']: 35 | correct_counts[data_type] = correct_counts.get(data_type, 0) + 1 36 | 37 | total_count = 0 38 | total_correct = 0 39 | for data_type in sorted(type_counts.keys()): 40 | accuracy = correct_counts[data_type] / type_counts[data_type] * 100 41 | if eval_only_type is None: 42 | print(f"{ques_type_id_to_name[data_type]}: {accuracy:.2f}%") 43 | 44 | total_count += type_counts[data_type] 45 | total_correct += correct_counts[data_type] 46 | 47 | total_accuracy = total_correct / total_count * 100 48 | if eval_only_type is None: 49 | print(f"Total accuracy: {total_accuracy:.2f}%") 50 | else: 51 | print(f"{eval_only_type} accuracy: {total_accuracy:.2f}%") 52 | 53 | return results 54 | 55 | if __name__ == "__main__": 56 | args = get_args() 57 | data = json.load(open(args.annotation_file)) 58 | ques_type_id_to_name = {id:n for n,id in data['question_type'].items()} 59 | 60 | results = eval_single(args.result_file) 61 | eval_single(args.result_file, eval_only_type='image') 62 | eval_single(args.result_file, eval_only_type='video') 63 | 64 | with open(args.result_upload_file, 'w') as fp: 65 | for question in data['questions']: 66 | qid = question['question_id'] 67 | if qid in results: 68 | result = results[qid] 69 | else: 70 | result = results[int(qid)] 71 | fp.write(json.dumps({ 72 | 'question_id': qid, 73 | 'prediction': result['text'] 74 | }) + '\n') 75 | -------------------------------------------------------------------------------- /scripts/convert_sqa_to_llava.py: -------------------------------------------------------------------------------- 1 | import json 2 | import os 3 | import fire 4 | import re 5 | from convert_sqa_to_llava_base_prompt import build_prompt_chatbot 6 | 7 | 8 | def convert_to_llava(base_dir, split, prompt_format="QCM-LEA"): 9 | split_indices = json.load(open(os.path.join(base_dir, 
"pid_splits.json")))[split] 10 | problems = json.load(open(os.path.join(base_dir, "problems.json"))) 11 | 12 | split_problems = build_prompt_chatbot( 13 | problems, split_indices, prompt_format, 14 | use_caption=False, is_test=False) 15 | 16 | target_format = [] 17 | for prob_id, (input, output) in split_problems.items(): 18 | if input.startswith('Question: '): 19 | input = input.replace('Question: ', '') 20 | if output.startswith('Answer: '): 21 | output = output.replace('Answer: ', '') 22 | 23 | raw_prob_data = problems[prob_id] 24 | if raw_prob_data['image'] is None: 25 | target_format.append({ 26 | "id": prob_id, 27 | "conversations": [ 28 | {'from': 'human', 'value': f"{input}"}, 29 | {'from': 'gpt', 'value': f"{output}"}, 30 | ], 31 | }) 32 | 33 | else: 34 | target_format.append({ 35 | "id": prob_id, 36 | "image": os.path.join(prob_id, raw_prob_data['image']), 37 | "conversations": [ 38 | {'from': 'human', 'value': f"{input}\n"}, 39 | {'from': 'gpt', 'value': f"{output}"}, 40 | ], 41 | }) 42 | 43 | print(f'Number of samples: {len(target_format)}') 44 | 45 | with open(os.path.join(base_dir, f"llava_{split}_{prompt_format}.json"), "w") as f: 46 | json.dump(target_format, f, indent=2) 47 | 48 | 49 | def convert_to_jsonl(base_dir, split, prompt_format="QCM-LEPA"): 50 | split_indices = json.load(open(os.path.join(base_dir, "pid_splits.json")))[split] 51 | problems = json.load(open(os.path.join(base_dir, "problems.json"))) 52 | 53 | split_problems = build_prompt_chatbot( 54 | problems, split_indices, prompt_format, 55 | use_caption=False, is_test=False) 56 | 57 | writer = open(os.path.join(base_dir, f"scienceqa_{split}_{prompt_format}.jsonl"), "w") 58 | for prob_id, (input, output) in split_problems.items(): 59 | if input.startswith('Question: '): 60 | input = input.replace('Question: ', '') 61 | if output.startswith('Answer: '): 62 | output = output.replace('Answer: ', '') 63 | 64 | raw_prob_data = problems[prob_id] 65 | if raw_prob_data['image'] is None: 66 | 
data = { 67 | "id": prob_id, 68 | "instruction": f"{input}", 69 | "output": f"{output}", 70 | } 71 | 72 | else: 73 | data = { 74 | "id": prob_id, 75 | "image": os.path.join(prob_id, raw_prob_data['image']), 76 | "instruction": f"{input}\n", 77 | "output": f"{output}", 78 | } 79 | writer.write(json.dumps(data) + '\n') 80 | writer.close() 81 | 82 | 83 | def main(task, **kwargs): 84 | globals()[task](**kwargs) 85 | 86 | 87 | if __name__ == "__main__": 88 | fire.Fire(main) 89 | -------------------------------------------------------------------------------- /scripts/convert_vizwiz_for_submission.py: -------------------------------------------------------------------------------- 1 | import os 2 | import argparse 3 | import json 4 | 5 | from llava.eval.m4c_evaluator import EvalAIAnswerProcessor 6 | 7 | 8 | def parse_args(): 9 | parser = argparse.ArgumentParser() 10 | parser.add_argument('--annotation-file', type=str, required=True) 11 | parser.add_argument('--result-file', type=str, required=True) 12 | parser.add_argument('--result-upload-file', type=str, required=True) 13 | return parser.parse_args() 14 | 15 | 16 | if __name__ == '__main__': 17 | 18 | args = parse_args() 19 | 20 | os.makedirs(os.path.dirname(args.result_upload_file), exist_ok=True) 21 | 22 | results = [] 23 | error_line = 0 24 | for line_idx, line in enumerate(open(args.result_file)): 25 | try: 26 | results.append(json.loads(line)) 27 | except: 28 | error_line += 1 29 | results = {x['question_id']: x['text'] for x in results} 30 | test_split = [json.loads(line) for line in open(args.annotation_file)] 31 | split_ids = set([x['question_id'] for x in test_split]) 32 | 33 | print(f'total results: {len(results)}, total split: {len(test_split)}, error_line: {error_line}') 34 | 35 | all_answers = [] 36 | 37 | answer_processor = EvalAIAnswerProcessor() 38 | 39 | for x in test_split: 40 | assert x['question_id'] in results 41 | all_answers.append({ 42 | 'image': x['image'], 43 | 'answer': 
answer_processor(results[x['question_id']]) 44 | }) 45 | 46 | with open(args.result_upload_file, 'w') as f: 47 | json.dump(all_answers, f) 48 | -------------------------------------------------------------------------------- /scripts/convert_vqav2_for_submission.py: -------------------------------------------------------------------------------- 1 | import os 2 | import argparse 3 | import json 4 | 5 | from llava.eval.m4c_evaluator import EvalAIAnswerProcessor 6 | 7 | 8 | def parse_args(): 9 | parser = argparse.ArgumentParser() 10 | parser.add_argument('--dir', type=str, default="./playground/data/eval/vqav2") 11 | parser.add_argument('--ckpt', type=str, required=True) 12 | parser.add_argument('--split', type=str, required=True) 13 | return parser.parse_args() 14 | 15 | 16 | if __name__ == '__main__': 17 | 18 | args = parse_args() 19 | 20 | src = os.path.join(args.dir, 'answers', args.split, args.ckpt, 'merge.jsonl') 21 | test_split = os.path.join(args.dir, 'llava_vqav2_mscoco_test2015.jsonl') 22 | dst = os.path.join(args.dir, 'answers_upload', args.split, f'{args.ckpt}.json') 23 | os.makedirs(os.path.dirname(dst), exist_ok=True) 24 | 25 | results = [] 26 | error_line = 0 27 | for line_idx, line in enumerate(open(src)): 28 | try: 29 | results.append(json.loads(line)) 30 | except: 31 | error_line += 1 32 | 33 | results = {x['question_id']: x['text'] for x in results} 34 | test_split = [json.loads(line) for line in open(test_split)] 35 | split_ids = set([x['question_id'] for x in test_split]) 36 | 37 | print(f'total results: {len(results)}, total split: {len(test_split)}, error_line: {error_line}') 38 | 39 | all_answers = [] 40 | 41 | answer_processor = EvalAIAnswerProcessor() 42 | 43 | for x in test_split: 44 | if x['question_id'] not in results: 45 | all_answers.append({ 46 | 'question_id': x['question_id'], 47 | 'answer': '' 48 | }) 49 | else: 50 | all_answers.append({ 51 | 'question_id': x['question_id'], 52 | 'answer': answer_processor(results[x['question_id']]) 
53 | }) 54 | 55 | with open(dst, 'w') as f: 56 | json.dump(all_answers, open(dst, 'w')) 57 | -------------------------------------------------------------------------------- /scripts/extract_mm_projector.py: -------------------------------------------------------------------------------- 1 | """ 2 | This is just a utility that I use to extract the projector for quantized models. 3 | It is NOT necessary at all to train, or run inference/serve demos. 4 | Use this script ONLY if you fully understand its implications. 5 | """ 6 | 7 | 8 | import os 9 | import argparse 10 | import torch 11 | import json 12 | from collections import defaultdict 13 | 14 | 15 | def parse_args(): 16 | parser = argparse.ArgumentParser(description='Extract MMProjector weights') 17 | parser.add_argument('--model-path', type=str, help='model folder') 18 | parser.add_argument('--output', type=str, help='output file') 19 | args = parser.parse_args() 20 | return args 21 | 22 | 23 | if __name__ == '__main__': 24 | args = parse_args() 25 | 26 | keys_to_match = ['mm_projector'] 27 | ckpt_to_key = defaultdict(list) 28 | try: 29 | model_indices = json.load(open(os.path.join(args.model_path, 'pytorch_model.bin.index.json'))) 30 | for k, v in model_indices['weight_map'].items(): 31 | if any(key_match in k for key_match in keys_to_match): 32 | ckpt_to_key[v].append(k) 33 | except FileNotFoundError: 34 | # Smaller models or model checkpoints saved by DeepSpeed. 
35 | v = 'pytorch_model.bin' 36 | for k in torch.load(os.path.join(args.model_path, v), map_location='cpu').keys(): 37 | if any(key_match in k for key_match in keys_to_match): 38 | ckpt_to_key[v].append(k) 39 | 40 | loaded_weights = {} 41 | 42 | for ckpt_name, weight_keys in ckpt_to_key.items(): 43 | ckpt = torch.load(os.path.join(args.model_path, ckpt_name), map_location='cpu') 44 | for k in weight_keys: 45 | loaded_weights[k] = ckpt[k] 46 | 47 | torch.save(loaded_weights, args.output) 48 | -------------------------------------------------------------------------------- /scripts/finetune.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | # IMPORTANT: this is the training script for the original LLaVA, NOT FOR LLaVA V1.5! 4 | 5 | # Uncomment and set the following variables correspondingly to run this script: 6 | 7 | ################## VICUNA ################## 8 | # PROMPT_VERSION=v1 9 | # MODEL_VERSION="vicuna-v1-3-7b" 10 | ################## VICUNA ################## 11 | 12 | ################## LLaMA-2 ################## 13 | # PROMPT_VERSION="llava_llama_2" 14 | # MODEL_VERSION="llama-2-7b-chat" 15 | ################## LLaMA-2 ################## 16 | 17 | deepspeed llava/train/train_mem.py \ 18 | --deepspeed ./scripts/zero2.json \ 19 | --model_name_or_path ./checkpoints/$MODEL_VERSION \ 20 | --version $PROMPT_VERSION \ 21 | --data_path ./playground/data/llava_instruct_80k.json \ 22 | --image_folder /path/to/coco/train2017 \ 23 | --vision_tower openai/clip-vit-large-patch14 \ 24 | --pretrain_mm_mlp_adapter ./checkpoints/llava-$MODEL_VERSION-pretrain/mm_projector.bin \ 25 | --mm_vision_select_layer -2 \ 26 | --mm_use_im_start_end False \ 27 | --mm_use_im_patch_token False \ 28 | --bf16 True \ 29 | --output_dir ./checkpoints/llava-$MODEL_VERSION-finetune \ 30 | --num_train_epochs 1 \ 31 | --per_device_train_batch_size 16 \ 32 | --per_device_eval_batch_size 4 \ 33 | --gradient_accumulation_steps 1 \ 34 | 
--evaluation_strategy "no" \ 35 | --save_strategy "steps" \ 36 | --save_steps 50000 \ 37 | --save_total_limit 1 \ 38 | --learning_rate 2e-5 \ 39 | --weight_decay 0. \ 40 | --warmup_ratio 0.03 \ 41 | --lr_scheduler_type "cosine" \ 42 | --logging_steps 1 \ 43 | --tf32 True \ 44 | --model_max_length 2048 \ 45 | --gradient_checkpointing True \ 46 | --dataloader_num_workers 4 \ 47 | --lazy_preprocess True \ 48 | --report_to wandb 49 | -------------------------------------------------------------------------------- /scripts/finetune_full_schedule.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | # IMPORTANT: this is the training script for the original LLaVA, NOT FOR LLaVA V1.5! 4 | 5 | # Uncomment and set the following variables correspondingly to run this script: 6 | 7 | ################## VICUNA ################## 8 | # PROMPT_VERSION=v1 9 | # MODEL_VERSION="vicuna-v1-3-7b" 10 | ################## VICUNA ################## 11 | 12 | ################## LLaMA-2 ################## 13 | # PROMPT_VERSION="llava_llama_2" 14 | # MODEL_VERSION="llama-2-7b-chat" 15 | ################## LLaMA-2 ################## 16 | 17 | deepspeed llava/train/train_mem.py \ 18 | --deepspeed ./scripts/zero2.json \ 19 | --model_name_or_path ./checkpoints/$MODEL_VERSION \ 20 | --version $PROMPT_VERSION \ 21 | --data_path ./playground/data/llava_instruct_158k.json \ 22 | --image_folder /path/to/coco/train2017 \ 23 | --vision_tower openai/clip-vit-large-patch14 \ 24 | --pretrain_mm_mlp_adapter ./checkpoints/llava-$MODEL_VERSION-pretrain/mm_projector.bin \ 25 | --mm_vision_select_layer -2 \ 26 | --mm_use_im_start_end False \ 27 | --mm_use_im_patch_token False \ 28 | --bf16 True \ 29 | --output_dir ./checkpoints/llava-$MODEL_VERSION-finetune \ 30 | --num_train_epochs 3 \ 31 | --per_device_train_batch_size 16 \ 32 | --per_device_eval_batch_size 4 \ 33 | --gradient_accumulation_steps 1 \ 34 | --evaluation_strategy "no" \ 35 | --save_strategy 
"steps" \ 36 | --save_steps 50000 \ 37 | --save_total_limit 1 \ 38 | --learning_rate 2e-5 \ 39 | --weight_decay 0. \ 40 | --warmup_ratio 0.03 \ 41 | --lr_scheduler_type "cosine" \ 42 | --logging_steps 1 \ 43 | --tf32 True \ 44 | --model_max_length 2048 \ 45 | --gradient_checkpointing True \ 46 | --dataloader_num_workers 4 \ 47 | --lazy_preprocess True \ 48 | --report_to wandb 49 | -------------------------------------------------------------------------------- /scripts/finetune_lora.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | # IMPORTANT: this is the training script for the original LLaVA, NOT FOR LLaVA V1.5! 4 | 5 | # Uncomment and set the following variables correspondingly to run this script: 6 | 7 | ################## VICUNA ################## 8 | # PROMPT_VERSION=v1 9 | # MODEL_VERSION="vicuna-v1-3-7b" 10 | ################## VICUNA ################## 11 | 12 | ################## LLaMA-2 ################## 13 | # PROMPT_VERSION="llava_llama_2" 14 | # MODEL_VERSION="llama-2-7b-chat" 15 | ################## LLaMA-2 ################## 16 | 17 | deepspeed llava/train/train_mem.py \ 18 | --deepspeed ./scripts/zero2.json \ 19 | --lora_enable True \ 20 | --model_name_or_path ./checkpoints/$MODEL_VERSION \ 21 | --version $PROMPT_VERSION \ 22 | --data_path ./playground/data/llava_instruct_80k.json \ 23 | --image_folder /path/to/coco/train2017 \ 24 | --vision_tower openai/clip-vit-large-patch14 \ 25 | --pretrain_mm_mlp_adapter ./checkpoints/llava-$MODEL_VERSION-pretrain/mm_projector.bin \ 26 | --mm_vision_select_layer -2 \ 27 | --mm_use_im_start_end False \ 28 | --mm_use_im_patch_token False \ 29 | --bf16 True \ 30 | --output_dir ./checkpoints/llava-$MODEL_VERSION-finetune_lora \ 31 | --num_train_epochs 1 \ 32 | --per_device_train_batch_size 16 \ 33 | --per_device_eval_batch_size 4 \ 34 | --gradient_accumulation_steps 1 \ 35 | --evaluation_strategy "no" \ 36 | --save_strategy "steps" \ 37 | --save_steps 
50000 \ 38 | --save_total_limit 1 \ 39 | --learning_rate 2e-5 \ 40 | --weight_decay 0. \ 41 | --warmup_ratio 0.03 \ 42 | --lr_scheduler_type "cosine" \ 43 | --logging_steps 1 \ 44 | --tf32 True \ 45 | --model_max_length 2048 \ 46 | --gradient_checkpointing True \ 47 | --lazy_preprocess True \ 48 | --dataloader_num_workers 4 \ 49 | --report_to wandb 50 | -------------------------------------------------------------------------------- /scripts/finetune_qlora.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | # IMPORTANT: this is the training script for the original LLaVA, NOT FOR LLaVA V1.5! 4 | 5 | # Uncomment and set the following variables correspondingly to run this script: 6 | 7 | ################## VICUNA ################## 8 | # PROMPT_VERSION=v1 9 | # MODEL_VERSION="vicuna-v1-3-7b" 10 | ################## VICUNA ################## 11 | 12 | ################## LLaMA-2 ################## 13 | # PROMPT_VERSION="llava_llama_2" 14 | # MODEL_VERSION="llama-2-7b-chat" 15 | ################## LLaMA-2 ################## 16 | 17 | deepspeed llava/train/train_mem.py \ 18 | --deepspeed ./scripts/zero2.json \ 19 | --lora_enable True \ 20 | --bits 4 \ 21 | --model_name_or_path ./checkpoints/$MODEL_VERSION \ 22 | --version $PROMPT_VERSION \ 23 | --data_path ./playground/data/llava_instruct_80k.json \ 24 | --image_folder /path/to/coco/train2017 \ 25 | --vision_tower openai/clip-vit-large-patch14 \ 26 | --pretrain_mm_mlp_adapter ./checkpoints/llava-$MODEL_VERSION-pretrain/mm_projector.bin \ 27 | --mm_vision_select_layer -2 \ 28 | --mm_use_im_start_end False \ 29 | --mm_use_im_patch_token False \ 30 | --bf16 True \ 31 | --output_dir ./checkpoints/llava-$MODEL_VERSION-finetune_lora \ 32 | --num_train_epochs 1 \ 33 | --per_device_train_batch_size 16 \ 34 | --per_device_eval_batch_size 4 \ 35 | --gradient_accumulation_steps 1 \ 36 | --evaluation_strategy "no" \ 37 | --save_strategy "steps" \ 38 | --save_steps 50000 \ 39 | 
--save_total_limit 1 \ 40 | --learning_rate 2e-5 \ 41 | --weight_decay 0. \ 42 | --warmup_ratio 0.03 \ 43 | --lr_scheduler_type "cosine" \ 44 | --logging_steps 1 \ 45 | --tf32 True \ 46 | --model_max_length 2048 \ 47 | --gradient_checkpointing True \ 48 | --lazy_preprocess True \ 49 | --dataloader_num_workers 4 \ 50 | --report_to wandb 51 | -------------------------------------------------------------------------------- /scripts/finetune_sqa.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | # IMPORTANT: this is the training script for the original LLaVA, NOT FOR LLaVA V1.5! 4 | 5 | deepspeed llava/train/train_mem.py \ 6 | --deepspeed ./scripts/zero2.json \ 7 | --model_name_or_path lmsys/vicuna-13b-v1.3 \ 8 | --version $PROMPT_VERSION \ 9 | --data_path /Data/ScienceQA/data/scienceqa/llava_train_QCM-LEA.json \ 10 | --image_folder /Data/ScienceQA/data/scienceqa/images/train \ 11 | --vision_tower openai/clip-vit-large-patch14 \ 12 | --pretrain_mm_mlp_adapter ./checkpoints/huggingface/liuhaotian/llava-pretrain-vicuna-13b-v1.3/mm_projector.bin \ 13 | --mm_vision_select_layer -2 \ 14 | --mm_use_im_start_end False \ 15 | --mm_use_im_patch_token False \ 16 | --bf16 True \ 17 | --output_dir ./checkpoints/llava-vicuna-13b-v1.3-pretrain_lcs558k_plain-ScienceQA_QCM_LEA-12e \ 18 | --num_train_epochs 12 \ 19 | --per_device_train_batch_size 16 \ 20 | --per_device_eval_batch_size 4 \ 21 | --gradient_accumulation_steps 1 \ 22 | --evaluation_strategy "no" \ 23 | --save_strategy "steps" \ 24 | --save_steps 50000 \ 25 | --save_total_limit 1 \ 26 | --learning_rate 2e-5 \ 27 | --weight_decay 0. 
\ 28 | --warmup_ratio 0.03 \ 29 | --lr_scheduler_type "cosine" \ 30 | --logging_steps 1 \ 31 | --tf32 True \ 32 | --model_max_length 2048 \ 33 | --gradient_checkpointing True \ 34 | --dataloader_num_workers 4 \ 35 | --lazy_preprocess True \ 36 | --report_to wandb 37 | -------------------------------------------------------------------------------- /scripts/merge_lora_weights.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | from llava.model.builder import load_pretrained_model 3 | from llava.mm_utils import get_model_name_from_path 4 | 5 | 6 | def merge_lora(args): 7 | model_name = get_model_name_from_path(args.model_path) 8 | tokenizer, model, image_processor, context_len = load_pretrained_model(args.model_path, args.model_base, model_name, device_map='cpu') 9 | 10 | model.save_pretrained(args.save_model_path) 11 | tokenizer.save_pretrained(args.save_model_path) 12 | 13 | 14 | if __name__ == "__main__": 15 | parser = argparse.ArgumentParser() 16 | parser.add_argument("--model-path", type=str, required=True) 17 | parser.add_argument("--model-base", type=str, required=True) 18 | parser.add_argument("--save-model-path", type=str, required=True) 19 | 20 | args = parser.parse_args() 21 | 22 | merge_lora(args) 23 | -------------------------------------------------------------------------------- /scripts/pretrain.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | # IMPORTANT: this is the training script for the original LLaVA, NOT FOR LLaVA V1.5! 
4 | 5 | # Uncomment and set the following variables correspondingly to run this script: 6 | 7 | # MODEL_VERSION=vicuna-v1-3-7b 8 | # MODEL_VERSION=llama-2-7b-chat 9 | 10 | ########### DO NOT CHANGE ########### 11 | ########### USE THIS FOR BOTH ########### 12 | PROMPT_VERSION=plain 13 | ########### DO NOT CHANGE ########### 14 | 15 | deepspeed llava/train/train_mem.py \ 16 | --deepspeed ./scripts/zero2.json \ 17 | --model_name_or_path ./checkpoints/$MODEL_VERSION \ 18 | --version $PROMPT_VERSION \ 19 | --data_path /path/to/pretrain_data.json \ 20 | --image_folder /path/to/images \ 21 | --vision_tower openai/clip-vit-large-patch14 \ 22 | --tune_mm_mlp_adapter True \ 23 | --mm_vision_select_layer -2 \ 24 | --mm_use_im_start_end False \ 25 | --mm_use_im_patch_token False \ 26 | --bf16 True \ 27 | --output_dir ./checkpoints/llava-$MODEL_VERSION-pretrain \ 28 | --num_train_epochs 1 \ 29 | --per_device_train_batch_size 16 \ 30 | --per_device_eval_batch_size 4 \ 31 | --gradient_accumulation_steps 1 \ 32 | --evaluation_strategy "no" \ 33 | --save_strategy "steps" \ 34 | --save_steps 24000 \ 35 | --save_total_limit 1 \ 36 | --learning_rate 2e-3 \ 37 | --weight_decay 0. 
\ 38 | --warmup_ratio 0.03 \ 39 | --lr_scheduler_type "cosine" \ 40 | --logging_steps 1 \ 41 | --tf32 True \ 42 | --model_max_length 2048 \ 43 | --gradient_checkpointing True \ 44 | --dataloader_num_workers 4 \ 45 | --lazy_preprocess True \ 46 | --report_to wandb 47 | -------------------------------------------------------------------------------- /scripts/pretrain_xformers.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | # Uncomment and set the following variables correspondingly to run this script: 4 | 5 | # MODEL_VERSION=vicuna-v1-3-7b 6 | # MODEL_VERSION=llama-2-7b-chat 7 | 8 | ########### DO NOT CHANGE ########### 9 | ########### USE THIS FOR BOTH ########### 10 | PROMPT_VERSION=plain 11 | ########### DO NOT CHANGE ########### 12 | 13 | deepspeed llava/train/train_xformers.py \ 14 | --deepspeed ./scripts/zero2.json \ 15 | --model_name_or_path ./checkpoints/$MODEL_VERSION \ 16 | --version $PROMPT_VERSION \ 17 | --data_path /path/to/pretrain_data.json \ 18 | --image_folder /path/to/images \ 19 | --vision_tower openai/clip-vit-large-patch14 \ 20 | --tune_mm_mlp_adapter True \ 21 | --mm_vision_select_layer -2 \ 22 | --mm_use_im_start_end False \ 23 | --mm_use_im_patch_token False \ 24 | --bf16 False \ 25 | --output_dir ./checkpoints/llava-$MODEL_VERSION-pretrain \ 26 | --num_train_epochs 1 \ 27 | --per_device_train_batch_size 4 \ 28 | --per_device_eval_batch_size 4 \ 29 | --gradient_accumulation_steps 4 \ 30 | --evaluation_strategy "no" \ 31 | --save_strategy "steps" \ 32 | --save_steps 24000 \ 33 | --save_total_limit 1 \ 34 | --learning_rate 2e-3 \ 35 | --weight_decay 0. 
\ 36 | --warmup_ratio 0.03 \ 37 | --lr_scheduler_type "cosine" \ 38 | --logging_steps 1 \ 39 | --tf32 False \ 40 | --model_max_length 2048 \ 41 | --gradient_checkpointing True \ 42 | --dataloader_num_workers 4 \ 43 | --lazy_preprocess True \ 44 | --report_to wandb 45 | -------------------------------------------------------------------------------- /scripts/sqa_eval_batch.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | CHUNKS=8 4 | for IDX in {0..7}; do 5 | CUDA_VISIBLE_DEVICES=$IDX python -m llava.eval.model_vqa_science \ 6 | --model-path liuhaotian/llava-lcs558k-scienceqa-vicuna-13b-v1.3 \ 7 | --question-file ~/haotian/datasets/ScienceQA/data/scienceqa/llava_test_QCM-LEA.json \ 8 | --image-folder ~/haotian/datasets/ScienceQA/data/scienceqa/images/test \ 9 | --answers-file ./test_llava-13b-chunk$CHUNKS_$IDX.jsonl \ 10 | --num-chunks $CHUNKS \ 11 | --chunk-idx $IDX \ 12 | --conv-mode llava_v1 & 13 | done 14 | -------------------------------------------------------------------------------- /scripts/sqa_eval_gather.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | CHUNKS=8 4 | output_file="test_llava-13b.jsonl" 5 | 6 | # Clear out the output file if it exists. 7 | > "$output_file" 8 | 9 | # Loop through the indices and concatenate each file. 
10 | for idx in $(seq 0 $((CHUNKS-1))); do 11 | cat "./test_llava-13b-chunk${idx}.jsonl" >> "$output_file" 12 | done 13 | 14 | python llava/eval/eval_science_qa.py \ 15 | --base-dir ~/haotian/datasets/ScienceQA/data/scienceqa \ 16 | --result-file ./test_llava-13b.jsonl \ 17 | --output-file ./test_llava-13b_output.json \ 18 | --output-result ./test_llava-13b_result.json 19 | -------------------------------------------------------------------------------- /scripts/upload_pypi.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | # Step 0: Clean up 4 | rm -rf dist 5 | 6 | # Step 1: Change the package name to "llava-torch" 7 | sed -i 's/name = "llava"/name = "llava-torch"/' pyproject.toml 8 | 9 | # Step 2: Build the package 10 | python -m build 11 | 12 | # Step 3: Revert the changes in pyproject.toml to the original 13 | sed -i 's/name = "llava-torch"/name = "llava"/' pyproject.toml 14 | 15 | # Step 4: Upload to PyPI 16 | python -m twine upload dist/* 17 | -------------------------------------------------------------------------------- /scripts/v1_5/eval/gqa.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | gpu_list="${CUDA_VISIBLE_DEVICES:-0}" 4 | IFS=',' read -ra GPULIST <<< "$gpu_list" 5 | 6 | CHUNKS=${#GPULIST[@]} 7 | 8 | CKPT="llava-v1.5-13b" 9 | SPLIT="llava_gqa_testdev_balanced" 10 | GQADIR="./playground/data/eval/gqa/data" 11 | 12 | for IDX in $(seq 0 $((CHUNKS-1))); do 13 | CUDA_VISIBLE_DEVICES=${GPULIST[$IDX]} python -m llava.eval.model_vqa_loader \ 14 | --model-path liuhaotian/llava-v1.5-13b \ 15 | --question-file ./playground/data/eval/gqa/$SPLIT.jsonl \ 16 | --image-folder ./playground/data/eval/gqa/data/images \ 17 | --answers-file ./playground/data/eval/gqa/answers/$SPLIT/$CKPT/${CHUNKS}_${IDX}.jsonl \ 18 | --num-chunks $CHUNKS \ 19 | --chunk-idx $IDX \ 20 | --temperature 0 \ 21 | --conv-mode vicuna_v1 & 22 | done 23 | 24 | wait 25 | 26 | 
output_file=./playground/data/eval/gqa/answers/$SPLIT/$CKPT/merge.jsonl 27 | 28 | # Clear out the output file if it exists. 29 | > "$output_file" 30 | 31 | # Loop through the indices and concatenate each file. 32 | for IDX in $(seq 0 $((CHUNKS-1))); do 33 | cat ./playground/data/eval/gqa/answers/$SPLIT/$CKPT/${CHUNKS}_${IDX}.jsonl >> "$output_file" 34 | done 35 | 36 | python scripts/convert_gqa_for_eval.py --src $output_file --dst $GQADIR/testdev_balanced_predictions.json 37 | 38 | cd $GQADIR 39 | python eval/eval.py --tier testdev_balanced 40 | -------------------------------------------------------------------------------- /scripts/v1_5/eval/llavabench.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | python -m llava.eval.model_vqa \ 4 | --model-path liuhaotian/llava-v1.5-13b \ 5 | --question-file ./playground/data/eval/llava-bench-in-the-wild/questions.jsonl \ 6 | --image-folder ./playground/data/eval/llava-bench-in-the-wild/images \ 7 | --answers-file ./playground/data/eval/llava-bench-in-the-wild/answers/llava-v1.5-13b.jsonl \ 8 | --temperature 0 \ 9 | --conv-mode vicuna_v1 10 | 11 | mkdir -p playground/data/eval/llava-bench-in-the-wild/reviews 12 | 13 | python llava/eval/eval_gpt_review_bench.py \ 14 | --question playground/data/eval/llava-bench-in-the-wild/questions.jsonl \ 15 | --context playground/data/eval/llava-bench-in-the-wild/context.jsonl \ 16 | --rule llava/eval/table/rule.json \ 17 | --answer-list \ 18 | playground/data/eval/llava-bench-in-the-wild/answers_gpt4.jsonl \ 19 | playground/data/eval/llava-bench-in-the-wild/answers/llava-v1.5-13b.jsonl \ 20 | --output \ 21 | playground/data/eval/llava-bench-in-the-wild/reviews/llava-v1.5-13b.jsonl 22 | 23 | python llava/eval/summarize_gpt_review.py -f playground/data/eval/llava-bench-in-the-wild/reviews/llava-v1.5-13b.jsonl 24 | -------------------------------------------------------------------------------- /scripts/v1_5/eval/mmbench.sh: 
-------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | SPLIT="mmbench_dev_20230712" 4 | 5 | python -m llava.eval.model_vqa_mmbench \ 6 | --model-path liuhaotian/llava-v1.5-13b \ 7 | --question-file ./playground/data/eval/mmbench/$SPLIT.tsv \ 8 | --answers-file ./playground/data/eval/mmbench/answers/$SPLIT/llava-v1.5-13b.jsonl \ 9 | --single-pred-prompt \ 10 | --temperature 0 \ 11 | --conv-mode vicuna_v1 12 | 13 | mkdir -p playground/data/eval/mmbench/answers_upload/$SPLIT 14 | 15 | python scripts/convert_mmbench_for_submission.py \ 16 | --annotation-file ./playground/data/eval/mmbench/$SPLIT.tsv \ 17 | --result-dir ./playground/data/eval/mmbench/answers/$SPLIT \ 18 | --upload-dir ./playground/data/eval/mmbench/answers_upload/$SPLIT \ 19 | --experiment llava-v1.5-13b 20 | -------------------------------------------------------------------------------- /scripts/v1_5/eval/mmbench_cn.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | SPLIT="mmbench_dev_cn_20231003" 4 | 5 | python -m llava.eval.model_vqa_mmbench \ 6 | --model-path liuhaotian/llava-v1.5-13b \ 7 | --question-file ./playground/data/eval/mmbench_cn/$SPLIT.tsv \ 8 | --answers-file ./playground/data/eval/mmbench_cn/answers/$SPLIT/llava-v1.5-13b.jsonl \ 9 | --lang cn \ 10 | --single-pred-prompt \ 11 | --temperature 0 \ 12 | --conv-mode vicuna_v1 13 | 14 | mkdir -p playground/data/eval/mmbench/answers_upload/$SPLIT 15 | 16 | python scripts/convert_mmbench_for_submission.py \ 17 | --annotation-file ./playground/data/eval/mmbench_cn/$SPLIT.tsv \ 18 | --result-dir ./playground/data/eval/mmbench_cn/answers/$SPLIT \ 19 | --upload-dir ./playground/data/eval/mmbench_cn/answers_upload/$SPLIT \ 20 | --experiment llava-v1.5-13b 21 | -------------------------------------------------------------------------------- /scripts/v1_5/eval/mme.sh: 
-------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | python -m llava.eval.model_vqa_loader \ 4 | --model-path liuhaotian/llava-v1.5-13b \ 5 | --question-file ./playground/data/eval/MME/llava_mme.jsonl \ 6 | --image-folder ./playground/data/eval/MME/MME_Benchmark_release_version \ 7 | --answers-file ./playground/data/eval/MME/answers/llava-v1.5-13b.jsonl \ 8 | --temperature 0 \ 9 | --conv-mode vicuna_v1 10 | 11 | cd ./playground/data/eval/MME || exit 1  # abort instead of running the converters from the wrong directory 12 | 13 | python convert_answer_to_mme.py --experiment llava-v1.5-13b 14 | 15 | cd eval_tool || exit 1  # calculation.py must run from inside eval_tool 16 | 17 | python calculation.py --results_dir answers/llava-v1.5-13b 18 | -------------------------------------------------------------------------------- /scripts/v1_5/eval/mmvet.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | python -m llava.eval.model_vqa \ 4 | --model-path liuhaotian/llava-v1.5-13b \ 5 | --question-file ./playground/data/eval/mm-vet/llava-mm-vet.jsonl \ 6 | --image-folder ./playground/data/eval/mm-vet/images \ 7 | --answers-file ./playground/data/eval/mm-vet/answers/llava-v1.5-13b.jsonl \ 8 | --temperature 0 \ 9 | --conv-mode vicuna_v1 10 | 11 | mkdir -p ./playground/data/eval/mm-vet/results 12 | 13 | python scripts/convert_mmvet_for_eval.py \ 14 | --src ./playground/data/eval/mm-vet/answers/llava-v1.5-13b.jsonl \ 15 | --dst ./playground/data/eval/mm-vet/results/llava-v1.5-13b.json 16 | 17 | -------------------------------------------------------------------------------- /scripts/v1_5/eval/pope.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | python -m llava.eval.model_vqa_loader \ 4 | --model-path liuhaotian/llava-v1.5-13b \ 5 | --question-file ./playground/data/eval/pope/llava_pope_test.jsonl \ 6 | --image-folder ./playground/data/eval/pope/val2014 \ 7 | --answers-file ./playground/data/eval/pope/answers/llava-v1.5-13b.jsonl \ 8 | --temperature
0 \ 9 | --conv-mode vicuna_v1 10 | 11 | python llava/eval/eval_pope.py \ 12 | --annotation-dir ./playground/data/eval/pope/coco \ 13 | --question-file ./playground/data/eval/pope/llava_pope_test.jsonl \ 14 | --result-file ./playground/data/eval/pope/answers/llava-v1.5-13b.jsonl 15 | -------------------------------------------------------------------------------- /scripts/v1_5/eval/qbench.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | if [ "$1" = "dev" ]; then 4 | echo "Evaluating in 'dev' split." 5 | elif [ "$1" = "test" ]; then 6 | echo "Evaluating in 'test' split." 7 | else 8 | echo "Unknown split, please choose between 'dev' and 'test'." 9 | exit 1 10 | fi 11 | 12 | python -m llava.eval.model_vqa_qbench \ 13 | --model-path liuhaotian/llava-v1.5-13b \ 14 | --image-folder ./playground/data/eval/qbench/images_llvisionqa/ \ 15 | --questions-file ./playground/data/eval/qbench/llvisionqa_$1.json \ 16 | --answers-file ./playground/data/eval/qbench/llvisionqa_$1_answers.jsonl \ 17 | --conv-mode llava_v1 \ 18 | --lang en 19 | -------------------------------------------------------------------------------- /scripts/v1_5/eval/qbench_zh.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | if [ "$1" = "dev" ]; then 4 | ZH_SPLIT="验证集" 5 | echo "Evaluating in 'dev' split." 6 | elif [ "$1" = "test" ]; then 7 | ZH_SPLIT="测试集" 8 | echo "Evaluating in 'test' split." 9 | else 10 | echo "Unknown split, please choose between 'dev' and 'test'."
11 | exit 1 12 | fi 13 | 14 | python -m llava.eval.model_vqa_qbench \ 15 | --model-path liuhaotian/llava-v1.5-13b \ 16 | --image-folder ./playground/data/eval/qbench/images_llvisionqa/ \ 17 | --questions-file ./playground/data/eval/qbench/质衡-问答-$ZH_SPLIT.json \ 18 | --answers-file ./playground/data/eval/qbench/llvisionqa_zh_$1_answers.jsonl \ 19 | --conv-mode llava_v1 \ 20 | --lang zh 21 | -------------------------------------------------------------------------------- /scripts/v1_5/eval/seed.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | gpu_list="${CUDA_VISIBLE_DEVICES:-0}" 4 | IFS=',' read -ra GPULIST <<< "$gpu_list" 5 | 6 | CHUNKS=${#GPULIST[@]} 7 | 8 | CKPT="llava-v1.5-13b" 9 | 10 | for IDX in $(seq 0 $((CHUNKS-1))); do 11 | CUDA_VISIBLE_DEVICES=${GPULIST[$IDX]} python -m llava.eval.model_vqa_loader \ 12 | --model-path liuhaotian/llava-v1.5-13b \ 13 | --question-file ./playground/data/eval/seed_bench/llava-seed-bench.jsonl \ 14 | --image-folder ./playground/data/eval/seed_bench \ 15 | --answers-file ./playground/data/eval/seed_bench/answers/$CKPT/${CHUNKS}_${IDX}.jsonl \ 16 | --num-chunks $CHUNKS \ 17 | --chunk-idx $IDX \ 18 | --temperature 0 \ 19 | --conv-mode vicuna_v1 & 20 | done 21 | 22 | wait 23 | 24 | output_file=./playground/data/eval/seed_bench/answers/$CKPT/merge.jsonl 25 | 26 | # Clear out the output file if it exists. 27 | > "$output_file" 28 | 29 | # Loop through the indices and concatenate each file.
30 | for IDX in $(seq 0 $((CHUNKS-1))); do 31 | cat ./playground/data/eval/seed_bench/answers/$CKPT/${CHUNKS}_${IDX}.jsonl >> "$output_file" 32 | done 33 | 34 | # Evaluate 35 | python scripts/convert_seed_for_submission.py \ 36 | --annotation-file ./playground/data/eval/seed_bench/SEED-Bench.json \ 37 | --result-file $output_file \ 38 | --result-upload-file ./playground/data/eval/seed_bench/answers_upload/llava-v1.5-13b.jsonl 39 | 40 | -------------------------------------------------------------------------------- /scripts/v1_5/eval/sqa.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | python -m llava.eval.model_vqa_science \ 4 | --model-path liuhaotian/llava-v1.5-13b \ 5 | --question-file ./playground/data/eval/scienceqa/llava_test_CQM-A.json \ 6 | --image-folder ./playground/data/eval/scienceqa/images/test \ 7 | --answers-file ./playground/data/eval/scienceqa/answers/llava-v1.5-13b.jsonl \ 8 | --single-pred-prompt \ 9 | --temperature 0 \ 10 | --conv-mode vicuna_v1 11 | 12 | python llava/eval/eval_science_qa.py \ 13 | --base-dir ./playground/data/eval/scienceqa \ 14 | --result-file ./playground/data/eval/scienceqa/answers/llava-v1.5-13b.jsonl \ 15 | --output-file ./playground/data/eval/scienceqa/answers/llava-v1.5-13b_output.jsonl \ 16 | --output-result ./playground/data/eval/scienceqa/answers/llava-v1.5-13b_result.json 17 | -------------------------------------------------------------------------------- /scripts/v1_5/eval/textvqa.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | python -m llava.eval.model_vqa_loader \ 4 | --model-path liuhaotian/llava-v1.5-13b \ 5 | --question-file ./playground/data/eval/textvqa/llava_textvqa_val_v051_ocr.jsonl \ 6 | --image-folder ./playground/data/eval/textvqa/train_images \ 7 | --answers-file ./playground/data/eval/textvqa/answers/llava-v1.5-13b.jsonl \ 8 | --temperature 0 \ 9 | --conv-mode vicuna_v1 10 | 11 |
python -m llava.eval.eval_textvqa \ 12 | --annotation-file ./playground/data/eval/textvqa/TextVQA_0.5.1_val.json \ 13 | --result-file ./playground/data/eval/textvqa/answers/llava-v1.5-13b.jsonl 14 | -------------------------------------------------------------------------------- /scripts/v1_5/eval/vizwiz.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | python -m llava.eval.model_vqa_loader \ 4 | --model-path liuhaotian/llava-v1.5-13b \ 5 | --question-file ./playground/data/eval/vizwiz/llava_test.jsonl \ 6 | --image-folder ./playground/data/eval/vizwiz/test \ 7 | --answers-file ./playground/data/eval/vizwiz/answers/llava-v1.5-13b.jsonl \ 8 | --temperature 0 \ 9 | --conv-mode vicuna_v1 10 | 11 | python scripts/convert_vizwiz_for_submission.py \ 12 | --annotation-file ./playground/data/eval/vizwiz/llava_test.jsonl \ 13 | --result-file ./playground/data/eval/vizwiz/answers/llava-v1.5-13b.jsonl \ 14 | --result-upload-file ./playground/data/eval/vizwiz/answers_upload/llava-v1.5-13b.json 15 | -------------------------------------------------------------------------------- /scripts/v1_5/eval/vqav2.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | gpu_list="${CUDA_VISIBLE_DEVICES:-0}" 4 | IFS=',' read -ra GPULIST <<< "$gpu_list" 5 | 6 | CHUNKS=${#GPULIST[@]} 7 | 8 | CKPT="llava-v1.5-13b" 9 | SPLIT="llava_vqav2_mscoco_test-dev2015" 10 | 11 | for IDX in $(seq 0 $((CHUNKS-1))); do 12 | CUDA_VISIBLE_DEVICES=${GPULIST[$IDX]} python -m llava.eval.model_vqa_loader \ 13 | --model-path liuhaotian/llava-v1.5-13b \ 14 | --question-file ./playground/data/eval/vqav2/$SPLIT.jsonl \ 15 | --image-folder ./playground/data/eval/vqav2/test2015 \ 16 | --answers-file ./playground/data/eval/vqav2/answers/$SPLIT/$CKPT/${CHUNKS}_${IDX}.jsonl \ 17 | --num-chunks $CHUNKS \ 18 | --chunk-idx $IDX \ 19 | --temperature 0 \ 20 | --conv-mode vicuna_v1 & 21 | done 22 | 23 | wait 24 | 25 |
output_file=./playground/data/eval/vqav2/answers/$SPLIT/$CKPT/merge.jsonl 26 | 27 | # Clear out the output file if it exists. 28 | > "$output_file" 29 | 30 | # Loop through the indices and concatenate each file. 31 | for IDX in $(seq 0 $((CHUNKS-1))); do 32 | cat ./playground/data/eval/vqav2/answers/$SPLIT/$CKPT/${CHUNKS}_${IDX}.jsonl >> "$output_file" 33 | done 34 | 35 | python scripts/convert_vqav2_for_submission.py --split $SPLIT --ckpt $CKPT 36 | 37 | -------------------------------------------------------------------------------- /scripts/v1_5/finetune.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | deepspeed llava/train/train_mem.py \ 4 | --deepspeed ./scripts/zero3.json \ 5 | --model_name_or_path lmsys/vicuna-13b-v1.5 \ 6 | --version v1 \ 7 | --data_path ./playground/data/llava_v1_5_mix665k.json \ 8 | --image_folder ./playground/data \ 9 | --vision_tower openai/clip-vit-large-patch14-336 \ 10 | --pretrain_mm_mlp_adapter ./checkpoints/llava-v1.5-13b-pretrain/mm_projector.bin \ 11 | --mm_projector_type mlp2x_gelu \ 12 | --mm_vision_select_layer -2 \ 13 | --mm_use_im_start_end False \ 14 | --mm_use_im_patch_token False \ 15 | --image_aspect_ratio pad \ 16 | --group_by_modality_length True \ 17 | --bf16 True \ 18 | --output_dir ./checkpoints/llava-v1.5-13b \ 19 | --num_train_epochs 1 \ 20 | --per_device_train_batch_size 16 \ 21 | --per_device_eval_batch_size 4 \ 22 | --gradient_accumulation_steps 1 \ 23 | --evaluation_strategy "no" \ 24 | --save_strategy "steps" \ 25 | --save_steps 50000 \ 26 | --save_total_limit 1 \ 27 | --learning_rate 2e-5 \ 28 | --weight_decay 0.
\ 29 | --warmup_ratio 0.03 \ 30 | --lr_scheduler_type "cosine" \ 31 | --logging_steps 1 \ 32 | --tf32 True \ 33 | --model_max_length 2048 \ 34 | --gradient_checkpointing True \ 35 | --dataloader_num_workers 4 \ 36 | --lazy_preprocess True \ 37 | --report_to wandb 38 | -------------------------------------------------------------------------------- /scripts/v1_5/finetune_lora.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | deepspeed llava/train/train_mem.py \ 4 | --lora_enable True --lora_r 128 --lora_alpha 256 --mm_projector_lr 2e-5 \ 5 | --deepspeed ./scripts/zero3.json \ 6 | --model_name_or_path lmsys/vicuna-13b-v1.5 \ 7 | --version v1 \ 8 | --data_path ./playground/data/llava_v1_5_mix665k.json \ 9 | --image_folder ./playground/data \ 10 | --vision_tower openai/clip-vit-large-patch14-336 \ 11 | --pretrain_mm_mlp_adapter ./checkpoints/llava-v1.5-13b-pretrain/mm_projector.bin \ 12 | --mm_projector_type mlp2x_gelu \ 13 | --mm_vision_select_layer -2 \ 14 | --mm_use_im_start_end False \ 15 | --mm_use_im_patch_token False \ 16 | --image_aspect_ratio pad \ 17 | --group_by_modality_length True \ 18 | --bf16 True \ 19 | --output_dir ./checkpoints/llava-v1.5-13b-lora \ 20 | --num_train_epochs 1 \ 21 | --per_device_train_batch_size 16 \ 22 | --per_device_eval_batch_size 4 \ 23 | --gradient_accumulation_steps 1 \ 24 | --evaluation_strategy "no" \ 25 | --save_strategy "steps" \ 26 | --save_steps 50000 \ 27 | --save_total_limit 1 \ 28 | --learning_rate 2e-4 \ 29 | --weight_decay 0.
\ 30 | --warmup_ratio 0.03 \ 31 | --lr_scheduler_type "cosine" \ 32 | --logging_steps 1 \ 33 | --tf32 True \ 34 | --model_max_length 2048 \ 35 | --gradient_checkpointing True \ 36 | --dataloader_num_workers 4 \ 37 | --lazy_preprocess True \ 38 | --report_to wandb 39 | -------------------------------------------------------------------------------- /scripts/v1_5/finetune_task.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | deepspeed llava/train/train_mem.py \ 4 | --deepspeed ./scripts/zero3.json \ 5 | --model_name_or_path liuhaotian/llava-v1.5-13b \ 6 | --version v1 \ 7 | --data_path ./playground/data/llava_v1_5_mix665k.json \ 8 | --image_folder ./playground/data \ 9 | --vision_tower openai/clip-vit-large-patch14-336 \ 10 | --mm_projector_type mlp2x_gelu \ 11 | --mm_vision_select_layer -2 \ 12 | --mm_use_im_start_end False \ 13 | --mm_use_im_patch_token False \ 14 | --image_aspect_ratio pad \ 15 | --group_by_modality_length True \ 16 | --bf16 True \ 17 | --output_dir ./checkpoints/llava-v1.5-13b-task \ 18 | --num_train_epochs 1 \ 19 | --per_device_train_batch_size 16 \ 20 | --per_device_eval_batch_size 4 \ 21 | --gradient_accumulation_steps 1 \ 22 | --evaluation_strategy "no" \ 23 | --save_strategy "steps" \ 24 | --save_steps 50000 \ 25 | --save_total_limit 1 \ 26 | --learning_rate 2e-5 \ 27 | --weight_decay 0.
\ 28 | --warmup_ratio 0.03 \ 29 | --lr_scheduler_type "cosine" \ 30 | --logging_steps 1 \ 31 | --tf32 True \ 32 | --model_max_length 2048 \ 33 | --gradient_checkpointing True \ 34 | --dataloader_num_workers 4 \ 35 | --lazy_preprocess True \ 36 | --report_to wandb 37 | -------------------------------------------------------------------------------- /scripts/v1_5/finetune_task_lora.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | deepspeed llava/train/train_mem.py \ 4 | --lora_enable True --lora_r 128 --lora_alpha 256 --mm_projector_lr 2e-5 \ 5 | --deepspeed ./scripts/zero3.json \ 6 | --model_name_or_path liuhaotian/llava-v1.5-13b \ 7 | --version v1 \ 8 | --data_path ./playground/data/llava_v1_5_mix665k.json \ 9 | --image_folder ./playground/data \ 10 | --vision_tower openai/clip-vit-large-patch14-336 \ 11 | --mm_projector_type mlp2x_gelu \ 12 | --mm_vision_select_layer -2 \ 13 | --mm_use_im_start_end False \ 14 | --mm_use_im_patch_token False \ 15 | --image_aspect_ratio pad \ 16 | --group_by_modality_length True \ 17 | --bf16 True \ 18 | --output_dir ./checkpoints/llava-v1.5-13b-task-lora \ 19 | --num_train_epochs 1 \ 20 | --per_device_train_batch_size 16 \ 21 | --per_device_eval_batch_size 4 \ 22 | --gradient_accumulation_steps 1 \ 23 | --evaluation_strategy "no" \ 24 | --save_strategy "steps" \ 25 | --save_steps 50000 \ 26 | --save_total_limit 1 \ 27 | --learning_rate 2e-4 \ 28 | --weight_decay 0.
\ 29 | --warmup_ratio 0.03 \ 30 | --lr_scheduler_type "cosine" \ 31 | --logging_steps 1 \ 32 | --tf32 True \ 33 | --model_max_length 2048 \ 34 | --gradient_checkpointing True \ 35 | --dataloader_num_workers 4 \ 36 | --lazy_preprocess True \ 37 | --report_to wandb 38 | -------------------------------------------------------------------------------- /scripts/v1_5/pretrain.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | deepspeed llava/train/train_mem.py \ 4 | --deepspeed ./scripts/zero2.json \ 5 | --model_name_or_path lmsys/vicuna-13b-v1.5 \ 6 | --version plain \ 7 | --data_path ./playground/data/LLaVA-Pretrain/blip_laion_cc_sbu_558k.json \ 8 | --image_folder ./playground/data/LLaVA-Pretrain/images \ 9 | --vision_tower openai/clip-vit-large-patch14-336 \ 10 | --mm_projector_type mlp2x_gelu \ 11 | --tune_mm_mlp_adapter True \ 12 | --mm_vision_select_layer -2 \ 13 | --mm_use_im_start_end False \ 14 | --mm_use_im_patch_token False \ 15 | --bf16 True \ 16 | --output_dir ./checkpoints/llava-v1.5-13b-pretrain \ 17 | --num_train_epochs 1 \ 18 | --per_device_train_batch_size 32 \ 19 | --per_device_eval_batch_size 4 \ 20 | --gradient_accumulation_steps 1 \ 21 | --evaluation_strategy "no" \ 22 | --save_strategy "steps" \ 23 | --save_steps 24000 \ 24 | --save_total_limit 1 \ 25 | --learning_rate 1e-3 \ 26 | --weight_decay 0.
\ 27 | --warmup_ratio 0.03 \ 28 | --lr_scheduler_type "cosine" \ 29 | --logging_steps 1 \ 30 | --tf32 True \ 31 | --model_max_length 2048 \ 32 | --gradient_checkpointing True \ 33 | --dataloader_num_workers 4 \ 34 | --lazy_preprocess True \ 35 | --report_to wandb 36 | -------------------------------------------------------------------------------- /scripts/zero2.json: -------------------------------------------------------------------------------- 1 | { 2 | "fp16": { 3 | "enabled": "auto", 4 | "loss_scale": 0, 5 | "loss_scale_window": 1000, 6 | "initial_scale_power": 16, 7 | "hysteresis": 2, 8 | "min_loss_scale": 1 9 | }, 10 | "bf16": { 11 | "enabled": "auto" 12 | }, 13 | "train_micro_batch_size_per_gpu": "auto", 14 | "train_batch_size": "auto", 15 | "gradient_accumulation_steps": "auto", 16 | "zero_optimization": { 17 | "stage": 2, 18 | "overlap_comm": true, 19 | "contiguous_gradients": true, 20 | "sub_group_size": 1e9, 21 | "reduce_bucket_size": "auto" 22 | } 23 | } -------------------------------------------------------------------------------- /scripts/zero3.json: -------------------------------------------------------------------------------- 1 | { 2 | "fp16": { 3 | "enabled": "auto", 4 | "loss_scale": 0, 5 | "loss_scale_window": 1000, 6 | "initial_scale_power": 16, 7 | "hysteresis": 2, 8 | "min_loss_scale": 1 9 | }, 10 | "bf16": { 11 | "enabled": "auto" 12 | }, 13 | "train_micro_batch_size_per_gpu": "auto", 14 | "train_batch_size": "auto", 15 | "gradient_accumulation_steps": "auto", 16 | "zero_optimization": { 17 | "stage": 3, 18 | "overlap_comm": true, 19 | "contiguous_gradients": true, 20 | "sub_group_size": 1e9, 21 | "reduce_bucket_size": "auto", 22 | "stage3_prefetch_bucket_size": "auto", 23 | "stage3_param_persistence_threshold": "auto", 24 | "stage3_max_live_parameters": 1e9, 25 | "stage3_max_reuse_distance": 1e9, 26 | "stage3_gather_16bit_weights_on_model_save": true 27 | } 28 | }
-------------------------------------------------------------------------------- /scripts/zero3_offload.json: -------------------------------------------------------------------------------- 1 | { 2 | "fp16": { 3 | "enabled": "auto", 4 | "loss_scale": 0, 5 | "loss_scale_window": 1000, 6 | "initial_scale_power": 16, 7 | "hysteresis": 2, 8 | "min_loss_scale": 1 9 | }, 10 | "bf16": { 11 | "enabled": "auto" 12 | }, 13 | "optimizer": { 14 | "type": "AdamW", 15 | "params": { 16 | "lr": "auto", 17 | "betas": "auto", 18 | "eps": "auto", 19 | "weight_decay": "auto" 20 | } 21 | }, 22 | "scheduler": { 23 | "type": "WarmupLR", 24 | "params": { 25 | "warmup_min_lr": "auto", 26 | "warmup_max_lr": "auto", 27 | "warmup_num_steps": "auto" 28 | } 29 | }, 30 | "zero_optimization": { 31 | "stage": 3, 32 | "offload_optimizer": { 33 | "device": "cpu", 34 | "pin_memory": true 35 | }, 36 | "offload_param": { 37 | "device": "cpu", 38 | "pin_memory": true 39 | }, 40 | "overlap_comm": true, 41 | "contiguous_gradients": true, 42 | "sub_group_size": 1e9, 43 | "reduce_bucket_size": "auto", 44 | "stage3_prefetch_bucket_size": "auto", 45 | "stage3_param_persistence_threshold": "auto", 46 | "stage3_max_live_parameters": 1e9, 47 | "stage3_max_reuse_distance": 1e9, 48 | "gather_16bit_weights_on_model_save": true 49 | }, 50 | "gradient_accumulation_steps": "auto", 51 | "gradient_clipping": "auto", 52 | "train_batch_size": "auto", 53 | "train_micro_batch_size_per_gpu": "auto", 54 | "steps_per_print": 1e5, 55 | "wall_clock_breakdown": false 56 | } -------------------------------------------------------------------------------- /taming_transformers/License.txt: -------------------------------------------------------------------------------- 1 | Copyright (c) 2020 Patrick Esser and Robin Rombach and Björn Ommer 2 | 3 | Permission is hereby granted, free of charge, to any person obtaining a copy 4 | of this software and associated documentation files (the "Software"), to deal 5 | in the Software without
restriction, including without limitation the rights 6 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 7 | copies of the Software, and to permit persons to whom the Software is 8 | furnished to do so, subject to the following conditions: 9 | 10 | The above copyright notice and this permission notice shall be included in all 11 | copies or substantial portions of the Software. 12 | 13 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, 14 | EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF 15 | MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 16 | IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, 17 | DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR 18 | OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE 19 | OR OTHER DEALINGS IN THE SOFTWARE. 20 | -------------------------------------------------------------------------------- /taming_transformers/ckpt/model.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | base_learning_rate: 4.5e-06 3 | target: taming.models.vqgan.GumbelVQ 4 | params: 5 | kl_weight: 1.0e-08 6 | embed_dim: 256 7 | n_embed: 8192 8 | monitor: val/rec_loss 9 | temperature_scheduler_config: 10 | target: taming.lr_scheduler.LambdaWarmUpCosineScheduler 11 | params: 12 | warm_up_steps: 0 13 | max_decay_steps: 1000001 14 | lr_start: 0.9 15 | lr_max: 0.9 16 | lr_min: 1.0e-06 17 | ddconfig: 18 | double_z: false 19 | z_channels: 256 20 | resolution: 256 21 | in_channels: 3 22 | out_ch: 3 23 | ch: 128 24 | ch_mult: 25 | - 1 26 | - 1 27 | - 2 28 | - 4 29 | num_res_blocks: 2 30 | attn_resolutions: 31 | - 32 32 | dropout: 0.0 33 | lossconfig: 34 | target: taming.modules.losses.vqperceptual.DummyLoss 35 | -------------------------------------------------------------------------------- /taming_transformers/environment.yaml:
-------------------------------------------------------------------------------- 1 | name: taming 2 | channels: 3 | - pytorch 4 | - defaults 5 | dependencies: 6 | - python=3.8.5 7 | - pip=20.3 8 | - cudatoolkit=10.2 9 | - pytorch=1.7.0 10 | - torchvision=0.8.1 11 | - numpy=1.19.2 12 | - pip: 13 | - albumentations==0.4.3 14 | - opencv-python==4.1.2.30 15 | - pudb==2019.2 16 | - imageio==2.9.0 17 | - imageio-ffmpeg==0.4.2 18 | - pytorch-lightning==1.0.8 19 | - omegaconf==2.0.0 20 | - test-tube>=0.7.5 21 | - streamlit>=0.73.1 22 | - einops==0.3.0 23 | - more-itertools>=8.0.0 24 | - transformers==4.3.1 25 | - -e . 26 | -------------------------------------------------------------------------------- /taming_transformers/idx2img.py: -------------------------------------------------------------------------------- 1 | from PIL import Image 2 | import numpy as np 3 | import torch 4 | 5 | import yaml 6 | import torch 7 | from omegaconf import OmegaConf 8 | from .taming.models.vqgan import VQModel, GumbelVQ 9 | import requests 10 | import PIL 11 | from PIL import Image 12 | from PIL import ImageDraw, ImageFont 13 | import numpy as np 14 | 15 | import torch 16 | import torch.nn.functional as F 17 | import torchvision.transforms as T 18 | import torchvision.transforms.functional as TF 19 | import pickle 20 | import os, sys 21 | 22 | torch.set_grad_enabled(False) 23 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 24 | 25 | def preprocess_vqgan(x): 26 | x = 2.*x - 1. 27 | return x 28 | 29 | def custom_to_pil(x, save_path): 30 | x = x.detach().cpu() 31 | x = torch.clamp(x, -1., 1.) 32 | x = (x + 1.)/2. 33 | x = x[0] 34 | x = x.permute(1,2,0).numpy() 35 | x = (255*x).astype(np.uint8) 36 | x = Image.fromarray(x) 37 | if not x.mode == "RGB": 38 | x = x.convert("RGB") 39 | x.save(save_path) 40 | return x 41 | 42 | def sample2img(x): 43 | x = x.detach().cpu() 44 | x = torch.clamp(x, -1., 1.) 45 | x = (x + 1.)/2. 
46 | x = x[0] 47 | x = x.permute(1,2,0).numpy() 48 | x = (255*x).astype(np.uint8) 49 | x = Image.fromarray(x) 50 | if not x.mode == "RGB": 51 | x = x.convert("RGB") 52 | return x 53 | 54 | def load_config(config_path, display=False): 55 | config = OmegaConf.load(config_path) 56 | if display: 57 | print(yaml.dump(OmegaConf.to_container(config))) 58 | return config 59 | 60 | def load_vqgan(config, ckpt_path=None, is_gumbel=False): 61 | if is_gumbel: 62 | model = GumbelVQ(**config.model.params) 63 | else: 64 | model = VQModel(**config.model.params) 65 | if ckpt_path is not None: 66 | sd = torch.load(ckpt_path, map_location="cpu")["state_dict"] 67 | missing, unexpected = model.load_state_dict(sd, strict=False) 68 | return model.eval() 69 | 70 | dir_path = os.path.dirname(__file__) 71 | 72 | config = load_config(os.path.join(dir_path, 'ckpt/model.yaml'), display=False) 73 | model = load_vqgan(config, ckpt_path=os.path.join(dir_path, 'ckpt/last.ckpt'), is_gumbel=True).to(device) 74 | 75 | @torch.no_grad() 76 | def decode_to_img(index, zshape=torch.randn((1,256,32,32)).shape): 77 | global model 78 | bhwc = (zshape[0],zshape[2],zshape[3],zshape[1]) 79 | quant_z = model.quantize.get_codebook_entry( 80 | index.reshape(-1), shape=bhwc) 81 | x = model.decode(quant_z) 82 | return x 83 | 84 | def preprocess(img, target_image_size=512): 85 | img = TF.resize(img, (target_image_size, target_image_size), interpolation=PIL.Image.LANCZOS) 86 | img = torch.unsqueeze(T.ToTensor()(img), 0) 87 | return img 88 | 89 | @torch.no_grad() 90 | def img2idx(image_path): 91 | global model 92 | image = Image.open((image_path)).convert('RGB') 93 | img = preprocess_vqgan(preprocess(image, target_image_size=256).to(model.device)) 94 | 95 | z, _, [_, _, indices] = model.encode(img) 96 | return z, indices 97 | 98 | @torch.no_grad() 99 | def idx2img(idx_tensor, save_path): 100 | x = decode_to_img(idx_tensor) 101 | custom_to_pil(x, save_path) 102 | 103 | 104 | 
-------------------------------------------------------------------------------- /taming_transformers/scripts/extract_depth.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import numpy as np 4 | from tqdm import trange 5 | from PIL import Image 6 | 7 | 8 | def get_state(gpu): 9 | import torch 10 | midas = torch.hub.load("intel-isl/MiDaS", "MiDaS") 11 | if gpu: 12 | midas.cuda() 13 | midas.eval() 14 | 15 | midas_transforms = torch.hub.load("intel-isl/MiDaS", "transforms") 16 | transform = midas_transforms.default_transform 17 | 18 | state = {"model": midas, 19 | "transform": transform} 20 | return state 21 | 22 | 23 | def depth_to_rgba(x): 24 | assert x.dtype == np.float32 25 | assert len(x.shape) == 2 26 | y = x.copy() 27 | y.dtype = np.uint8 28 | y = y.reshape(x.shape+(4,)) 29 | return np.ascontiguousarray(y) 30 | 31 | 32 | def rgba_to_depth(x): 33 | assert x.dtype == np.uint8 34 | assert len(x.shape) == 3 and x.shape[2] == 4 35 | y = x.copy() 36 | y.dtype = np.float32 37 | y = y.reshape(x.shape[:2]) 38 | return np.ascontiguousarray(y) 39 | 40 | 41 | def run(x, state): 42 | model = state["model"] 43 | transform = state["transform"] 44 | hw = x.shape[:2] 45 | with torch.no_grad(): 46 | prediction = model(transform((x + 1.0) * 127.5).cuda()) 47 | prediction = torch.nn.functional.interpolate( 48 | prediction.unsqueeze(1), 49 | size=hw, 50 | mode="bicubic", 51 | align_corners=False, 52 | ).squeeze() 53 | output = prediction.cpu().numpy() 54 | return output 55 | 56 | 57 | def get_filename(relpath, level=-2): 58 | # save class folder structure and filename: 59 | fn = relpath.split(os.sep)[level:] 60 | folder = fn[-2] 61 | file = fn[-1].split('.')[0] 62 | return folder, file 63 | 64 | 65 | def save_depth(dataset, path, debug=False): 66 | os.makedirs(path) 67 | N = len(dset) 68 | if debug: 69 | N = 10 70 | state = get_state(gpu=True) 71 | for idx in trange(N, desc="Data"): 72 | ex = dataset[idx] 73 | 
image, relpath = ex["image"], ex["relpath"] 74 | folder, filename = get_filename(relpath) 75 | # prepare 76 | folderabspath = os.path.join(path, folder) 77 | os.makedirs(folderabspath, exist_ok=True) 78 | savepath = os.path.join(folderabspath, filename) 79 | # run model 80 | xout = run(image, state) 81 | I = depth_to_rgba(xout) 82 | Image.fromarray(I).save("{}.png".format(savepath)) 83 | 84 | 85 | if __name__ == "__main__": 86 | from taming.data.imagenet import ImageNetTrain, ImageNetValidation 87 | out = "data/imagenet_depth" 88 | if not os.path.exists(out): 89 | print("Please create a folder or symlink '{}' to extract depth data ".format(out) + 90 | "(be prepared that the output size will be larger than ImageNet itself).") 91 | exit(1) 92 | 93 | # go 94 | dset = ImageNetValidation() 95 | abspath = os.path.join(out, "val") 96 | if os.path.exists(abspath): 97 | print("{} exists - not doing anything.".format(abspath)) 98 | else: 99 | print("preparing {}".format(abspath)) 100 | save_depth(dset, abspath) 101 | print("done with validation split") 102 | 103 | dset = ImageNetTrain() 104 | abspath = os.path.join(out, "train") 105 | if os.path.exists(abspath): 106 | print("{} exists - not doing anything.".format(abspath)) 107 | else: 108 | print("preparing {}".format(abspath)) 109 | save_depth(dset, abspath) 110 | print("done with train split") 111 | 112 | print("done done.") 113 | -------------------------------------------------------------------------------- /taming_transformers/scripts/extract_segmentation.py: -------------------------------------------------------------------------------- 1 | import sys, os 2 | import numpy as np 3 | import scipy 4 | import torch 5 | import torch.nn as nn 6 | from scipy import ndimage 7 | from tqdm import tqdm, trange 8 | from PIL import Image 9 | import torch.hub 10 | import torchvision 11 | import torch.nn.functional as F 12 | 13 | # download deeplabv2_resnet101_msc-cocostuff164k-100000.pth from 14 | # 
https://github.com/kazuto1011/deeplab-pytorch/releases/download/v1.0/deeplabv2_resnet101_msc-cocostuff164k-100000.pth 15 | # and put the path here 16 | CKPT_PATH = "TODO" 17 | 18 | rescale = lambda x: (x + 1.) / 2. 19 | 20 | def rescale_bgr(x): 21 | x = (x+1)*127.5 22 | x = torch.flip(x, dims=[0]) 23 | return x 24 | 25 | 26 | class COCOStuffSegmenter(nn.Module): 27 | def __init__(self, config): 28 | super().__init__() 29 | self.config = config 30 | self.n_labels = 182 31 | model = torch.hub.load("kazuto1011/deeplab-pytorch", "deeplabv2_resnet101", n_classes=self.n_labels) 32 | ckpt_path = CKPT_PATH 33 | model.load_state_dict(torch.load(ckpt_path)) 34 | self.model = model 35 | 36 | normalize = torchvision.transforms.Normalize(mean=self.mean, std=self.std) 37 | self.image_transform = torchvision.transforms.Compose([ 38 | torchvision.transforms.Lambda(lambda image: torch.stack( 39 | [normalize(rescale_bgr(x)) for x in image])) 40 | ]) 41 | 42 | def forward(self, x, upsample=None): 43 | x = self._pre_process(x) 44 | x = self.model(x) 45 | if upsample is not None: 46 | x = torch.nn.functional.upsample_bilinear(x, size=upsample) 47 | return x 48 | 49 | def _pre_process(self, x): 50 | x = self.image_transform(x) 51 | return x 52 | 53 | @property 54 | def mean(self): 55 | # bgr 56 | return [104.008, 116.669, 122.675] 57 | 58 | @property 59 | def std(self): 60 | return [1.0, 1.0, 1.0] 61 | 62 | @property 63 | def input_size(self): 64 | return [3, 224, 224] 65 | 66 | 67 | def run_model(img, model): 68 | model = model.eval() 69 | with torch.no_grad(): 70 | segmentation = model(img, upsample=(img.shape[2], img.shape[3])) 71 | segmentation = torch.argmax(segmentation, dim=1, keepdim=True) 72 | return segmentation.detach().cpu() 73 | 74 | 75 | def get_input(batch, k): 76 | x = batch[k] 77 | if len(x.shape) == 3: 78 | x = x[..., None] 79 | x = x.permute(0, 3, 1, 2).to(memory_format=torch.contiguous_format) 80 | return x.float() 81 | 82 | 83 | def save_segmentation(segmentation, 
path): 84 | # --> class label to uint8, save as png 85 | os.makedirs(os.path.dirname(path), exist_ok=True) 86 | assert len(segmentation.shape)==4 87 | assert segmentation.shape[0]==1 88 | for seg in segmentation: 89 | seg = seg.permute(1,2,0).numpy().squeeze().astype(np.uint8) 90 | seg = Image.fromarray(seg) 91 | seg.save(path) 92 | 93 | 94 | def iterate_dataset(dataloader, destpath, model): 95 | os.makedirs(destpath, exist_ok=True) 96 | num_processed = 0 97 | for i, batch in tqdm(enumerate(dataloader), desc="Data"): 98 | try: 99 | img = get_input(batch, "image") 100 | img = img.cuda() 101 | seg = run_model(img, model) 102 | 103 | path = batch["relative_file_path_"][0] 104 | path = os.path.splitext(path)[0] 105 | 106 | path = os.path.join(destpath, path + ".png") 107 | save_segmentation(seg, path) 108 | num_processed += 1 109 | except Exception as e: 110 | print(e) 111 | print("but anyhow..") 112 | 113 | print("Processed {} files. Bye.".format(num_processed)) 114 | 115 | 116 | from taming.data.sflckr import Examples 117 | from torch.utils.data import DataLoader 118 | 119 | if __name__ == "__main__": 120 | dest = sys.argv[1] 121 | batchsize = 1 122 | print("Running with batch-size {}, saving to {}...".format(batchsize, dest)) 123 | 124 | model = COCOStuffSegmenter({}).cuda() 125 | print("Instantiated model.") 126 | 127 | dataset = Examples() 128 | dloader = DataLoader(dataset, batch_size=batchsize) 129 | iterate_dataset(dataloader=dloader, destpath=dest, model=model) 130 | print("done.") 131 | -------------------------------------------------------------------------------- /taming_transformers/scripts/extract_submodel.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import sys 3 | 4 | if __name__ == "__main__": 5 | inpath = sys.argv[1] 6 | outpath = sys.argv[2] 7 | submodel = "cond_stage_model" 8 | if len(sys.argv) > 3: 9 | submodel = sys.argv[3] 10 | 11 | print("Extracting {} from {} to {}.".format(submodel, 
inpath, outpath)) 12 | 13 | sd = torch.load(inpath, map_location="cpu") 14 | new_sd = {"state_dict": dict((k.split(".", 1)[-1],v) 15 | for k,v in sd["state_dict"].items() 16 | if k.startswith("cond_stage_model"))} 17 | torch.save(new_sd, outpath) 18 | -------------------------------------------------------------------------------- /taming_transformers/setup.py: -------------------------------------------------------------------------------- 1 | from setuptools import setup, find_packages 2 | 3 | setup( 4 | name='taming-transformers', 5 | version='0.0.1', 6 | description='Taming Transformers for High-Resolution Image Synthesis', 7 | packages=find_packages(), 8 | install_requires=[ 9 | 'torch', 10 | 'numpy', 11 | 'tqdm', 12 | ], 13 | ) 14 | -------------------------------------------------------------------------------- /taming_transformers/taming/data/base.py: -------------------------------------------------------------------------------- 1 | import bisect 2 | import numpy as np 3 | import albumentations 4 | from PIL import Image 5 | from torch.utils.data import Dataset, ConcatDataset 6 | 7 | 8 | class ConcatDatasetWithIndex(ConcatDataset): 9 | """Modified from original pytorch code to return dataset idx""" 10 | def __getitem__(self, idx): 11 | if idx < 0: 12 | if -idx > len(self): 13 | raise ValueError("absolute value of index should not exceed dataset length") 14 | idx = len(self) + idx 15 | dataset_idx = bisect.bisect_right(self.cumulative_sizes, idx) 16 | if dataset_idx == 0: 17 | sample_idx = idx 18 | else: 19 | sample_idx = idx - self.cumulative_sizes[dataset_idx - 1] 20 | return self.datasets[dataset_idx][sample_idx], dataset_idx 21 | 22 | 23 | class ImagePaths(Dataset): 24 | def __init__(self, paths, size=None, random_crop=False, labels=None): 25 | self.size = size 26 | self.random_crop = random_crop 27 | 28 | self.labels = dict() if labels is None else labels 29 | self.labels["file_path_"] = paths 30 | self._length = len(paths) 31 | 32 | if self.size is 
not None and self.size > 0: 33 | self.rescaler = albumentations.SmallestMaxSize(max_size = self.size) 34 | if not self.random_crop: 35 | self.cropper = albumentations.CenterCrop(height=self.size,width=self.size) 36 | else: 37 | self.cropper = albumentations.RandomCrop(height=self.size,width=self.size) 38 | self.preprocessor = albumentations.Compose([self.rescaler, self.cropper]) 39 | else: 40 | self.preprocessor = lambda **kwargs: kwargs 41 | 42 | def __len__(self): 43 | return self._length 44 | 45 | def preprocess_image(self, image_path): 46 | image = Image.open(image_path) 47 | if not image.mode == "RGB": 48 | image = image.convert("RGB") 49 | image = np.array(image).astype(np.uint8) 50 | image = self.preprocessor(image=image)["image"] 51 | image = (image/127.5 - 1.0).astype(np.float32) 52 | return image 53 | 54 | def __getitem__(self, i): 55 | example = dict() 56 | example["image"] = self.preprocess_image(self.labels["file_path_"][i]) 57 | for k in self.labels: 58 | example[k] = self.labels[k][i] 59 | return example 60 | 61 | 62 | class NumpyPaths(ImagePaths): 63 | def preprocess_image(self, image_path): 64 | image = np.load(image_path).squeeze(0) # 3 x 1024 x 1024 65 | image = np.transpose(image, (1,2,0)) 66 | image = Image.fromarray(image, mode="RGB") 67 | image = np.array(image).astype(np.uint8) 68 | image = self.preprocessor(image=image)["image"] 69 | image = (image/127.5 - 1.0).astype(np.float32) 70 | return image 71 | -------------------------------------------------------------------------------- /taming_transformers/taming/data/conditional_builder/objects_bbox.py: -------------------------------------------------------------------------------- 1 | from itertools import cycle 2 | from typing import List, Tuple, Callable, Optional 3 | 4 | from PIL import Image as pil_image, ImageDraw as pil_img_draw, ImageFont 5 | from more_itertools.recipes import grouper 6 | from taming.data.image_transforms import convert_pil_to_tensor 7 | from torch import LongTensor, 
Tensor 8 | 9 | from taming.data.helper_types import BoundingBox, Annotation 10 | from taming.data.conditional_builder.objects_center_points import ObjectsCenterPointsConditionalBuilder 11 | from taming.data.conditional_builder.utils import COLOR_PALETTE, WHITE, GRAY_75, BLACK, additional_parameters_string, \ 12 | pad_list, get_plot_font_size, absolute_bbox 13 | 14 | 15 | class ObjectsBoundingBoxConditionalBuilder(ObjectsCenterPointsConditionalBuilder): 16 | @property 17 | def object_descriptor_length(self) -> int: 18 | return 3 19 | 20 | def _make_object_descriptors(self, annotations: List[Annotation]) -> List[Tuple[int, ...]]: 21 | object_triples = [ 22 | (self.object_representation(ann), *self.token_pair_from_bbox(ann.bbox)) 23 | for ann in annotations 24 | ] 25 | empty_triple = (self.none, self.none, self.none) 26 | object_triples = pad_list(object_triples, empty_triple, self.no_max_objects) 27 | return object_triples 28 | 29 | def inverse_build(self, conditional: LongTensor) -> Tuple[List[Tuple[int, BoundingBox]], Optional[BoundingBox]]: 30 | conditional_list = conditional.tolist() 31 | crop_coordinates = None 32 | if self.encode_crop: 33 | crop_coordinates = self.bbox_from_token_pair(conditional_list[-2], conditional_list[-1]) 34 | conditional_list = conditional_list[:-2] 35 | object_triples = grouper(conditional_list, 3) 36 | assert conditional.shape[0] == self.embedding_dim 37 | return [ 38 | (object_triple[0], self.bbox_from_token_pair(object_triple[1], object_triple[2])) 39 | for object_triple in object_triples if object_triple[0] != self.none 40 | ], crop_coordinates 41 | 42 | def plot(self, conditional: LongTensor, label_for_category_no: Callable[[int], str], figure_size: Tuple[int, int], 43 | line_width: int = 3, font_size: Optional[int] = None) -> Tensor: 44 | plot = pil_image.new('RGB', figure_size, WHITE) 45 | draw = pil_img_draw.Draw(plot) 46 | font = ImageFont.truetype( 47 | "/usr/share/fonts/truetype/lato/Lato-Regular.ttf", 48 | 
size=get_plot_font_size(font_size, figure_size) 49 | ) 50 | width, height = plot.size 51 | description, crop_coordinates = self.inverse_build(conditional) 52 | for (representation, bbox), color in zip(description, cycle(COLOR_PALETTE)): 53 | annotation = self.representation_to_annotation(representation) 54 | class_label = label_for_category_no(annotation.category_no) + ' ' + additional_parameters_string(annotation) 55 | bbox = absolute_bbox(bbox, width, height) 56 | draw.rectangle(bbox, outline=color, width=line_width) 57 | draw.text((bbox[0] + line_width, bbox[1] + line_width), class_label, anchor='la', fill=BLACK, font=font) 58 | if crop_coordinates is not None: 59 | draw.rectangle(absolute_bbox(crop_coordinates, width, height), outline=GRAY_75, width=line_width) 60 | return convert_pil_to_tensor(plot) / 127.5 - 1. 61 | -------------------------------------------------------------------------------- /taming_transformers/taming/data/conditional_builder/utils.py: -------------------------------------------------------------------------------- 1 | import importlib 2 | from typing import List, Any, Tuple, Optional 3 | 4 | from taming.data.helper_types import BoundingBox, Annotation 5 | 6 | # source: seaborn, color palette tab10 7 | COLOR_PALETTE = [(30, 118, 179), (255, 126, 13), (43, 159, 43), (213, 38, 39), (147, 102, 188), 8 | (139, 85, 74), (226, 118, 193), (126, 126, 126), (187, 188, 33), (22, 189, 206)] 9 | BLACK = (0, 0, 0) 10 | GRAY_75 = (63, 63, 63) 11 | GRAY_50 = (127, 127, 127) 12 | GRAY_25 = (191, 191, 191) 13 | WHITE = (255, 255, 255) 14 | FULL_CROP = (0., 0., 1., 1.) 15 | 16 | 17 | def intersection_area(rectangle1: BoundingBox, rectangle2: BoundingBox) -> float: 18 | """ 19 | Give intersection area of two rectangles. 
20 | @param rectangle1: (x0, y0, w, h) of first rectangle 21 | @param rectangle2: (x0, y0, w, h) of second rectangle 22 | """ 23 | rectangle1 = rectangle1[0], rectangle1[1], rectangle1[0] + rectangle1[2], rectangle1[1] + rectangle1[3] 24 | rectangle2 = rectangle2[0], rectangle2[1], rectangle2[0] + rectangle2[2], rectangle2[1] + rectangle2[3] 25 | x_overlap = max(0., min(rectangle1[2], rectangle2[2]) - max(rectangle1[0], rectangle2[0])) 26 | y_overlap = max(0., min(rectangle1[3], rectangle2[3]) - max(rectangle1[1], rectangle2[1])) 27 | return x_overlap * y_overlap 28 | 29 | 30 | def horizontally_flip_bbox(bbox: BoundingBox) -> BoundingBox: 31 | return 1 - (bbox[0] + bbox[2]), bbox[1], bbox[2], bbox[3] 32 | 33 | 34 | def absolute_bbox(relative_bbox: BoundingBox, width: int, height: int) -> Tuple[int, int, int, int]: 35 | bbox = relative_bbox 36 | bbox = bbox[0] * width, bbox[1] * height, (bbox[0] + bbox[2]) * width, (bbox[1] + bbox[3]) * height 37 | return int(bbox[0]), int(bbox[1]), int(bbox[2]), int(bbox[3]) 38 | 39 | 40 | def pad_list(list_: List, pad_element: Any, pad_to_length: int) -> List: 41 | return list_ + [pad_element for _ in range(pad_to_length - len(list_))] 42 | 43 | 44 | def rescale_annotations(annotations: List[Annotation], crop_coordinates: BoundingBox, flip: bool) -> \ 45 | List[Annotation]: 46 | def clamp(x: float): 47 | return max(min(x, 1.), 0.) 
48 | 49 | def rescale_bbox(bbox: BoundingBox) -> BoundingBox: 50 | x0 = clamp((bbox[0] - crop_coordinates[0]) / crop_coordinates[2]) 51 | y0 = clamp((bbox[1] - crop_coordinates[1]) / crop_coordinates[3]) 52 | w = min(bbox[2] / crop_coordinates[2], 1 - x0) 53 | h = min(bbox[3] / crop_coordinates[3], 1 - y0) 54 | if flip: 55 | x0 = 1 - (x0 + w) 56 | return x0, y0, w, h 57 | 58 | return [a._replace(bbox=rescale_bbox(a.bbox)) for a in annotations] 59 | 60 | 61 | def filter_annotations(annotations: List[Annotation], crop_coordinates: BoundingBox) -> List: 62 | return [a for a in annotations if intersection_area(a.bbox, crop_coordinates) > 0.0] 63 | 64 | 65 | def additional_parameters_string(annotation: Annotation, short: bool = True) -> str: 66 | sl = slice(1) if short else slice(None) 67 | string = '' 68 | if not (annotation.is_group_of or annotation.is_occluded or annotation.is_depiction or annotation.is_inside): 69 | return string 70 | if annotation.is_group_of: 71 | string += 'group'[sl] + ',' 72 | if annotation.is_occluded: 73 | string += 'occluded'[sl] + ',' 74 | if annotation.is_depiction: 75 | string += 'depiction'[sl] + ',' 76 | if annotation.is_inside: 77 | string += 'inside'[sl] 78 | return '(' + string.strip(",") + ')' 79 | 80 | 81 | def get_plot_font_size(font_size: Optional[int], figure_size: Tuple[int, int]) -> int: 82 | if font_size is None: 83 | font_size = 10 84 | if max(figure_size) >= 256: 85 | font_size = 12 86 | if max(figure_size) >= 512: 87 | font_size = 15 88 | return font_size 89 | 90 | 91 | def get_circle_size(figure_size: Tuple[int, int]) -> int: 92 | circle_size = 2 93 | if max(figure_size) >= 256: 94 | circle_size = 3 95 | if max(figure_size) >= 512: 96 | circle_size = 4 97 | return circle_size 98 | 99 | 100 | def load_object_from_string(object_string: str) -> Any: 101 | """ 102 | Source: https://stackoverflow.com/a/10773699 103 | """ 104 | module_name, class_name = object_string.rsplit(".", 1) 105 | return 
getattr(importlib.import_module(module_name), class_name) 106 | -------------------------------------------------------------------------------- /taming_transformers/taming/data/custom.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | import albumentations 4 | from torch.utils.data import Dataset 5 | 6 | from taming.data.base import ImagePaths, NumpyPaths, ConcatDatasetWithIndex 7 | 8 | 9 | class CustomBase(Dataset): 10 | def __init__(self, *args, **kwargs): 11 | super().__init__() 12 | self.data = None 13 | 14 | def __len__(self): 15 | return len(self.data) 16 | 17 | def __getitem__(self, i): 18 | example = self.data[i] 19 | return example 20 | 21 | 22 | 23 | class CustomTrain(CustomBase): 24 | def __init__(self, size, training_images_list_file): 25 | super().__init__() 26 | with open(training_images_list_file, "r") as f: 27 | paths = f.read().splitlines() 28 | self.data = ImagePaths(paths=paths, size=size, random_crop=False) 29 | 30 | 31 | class CustomTest(CustomBase): 32 | def __init__(self, size, test_images_list_file): 33 | super().__init__() 34 | with open(test_images_list_file, "r") as f: 35 | paths = f.read().splitlines() 36 | self.data = ImagePaths(paths=paths, size=size, random_crop=False) 37 | 38 | 39 | -------------------------------------------------------------------------------- /taming_transformers/taming/data/helper_types.py: -------------------------------------------------------------------------------- 1 | from typing import Dict, Tuple, Optional, NamedTuple, Union 2 | from PIL.Image import Image as pil_image 3 | from torch import Tensor 4 | 5 | try: 6 | from typing import Literal 7 | except ImportError: 8 | from typing_extensions import Literal 9 | 10 | Image = Union[Tensor, pil_image] 11 | BoundingBox = Tuple[float, float, float, float] # x0, y0, w, h 12 | CropMethodType = Literal['none', 'random', 'center', 'random-2d'] 13 | SplitType = Literal['train', 'validation', 
'test'] 14 | 15 | 16 | class ImageDescription(NamedTuple): 17 | id: int 18 | file_name: str 19 | original_size: Tuple[int, int] # w, h 20 | url: Optional[str] = None 21 | license: Optional[int] = None 22 | coco_url: Optional[str] = None 23 | date_captured: Optional[str] = None 24 | flickr_url: Optional[str] = None 25 | flickr_id: Optional[str] = None 26 | coco_id: Optional[str] = None 27 | 28 | 29 | class Category(NamedTuple): 30 | id: str 31 | super_category: Optional[str] 32 | name: str 33 | 34 | 35 | class Annotation(NamedTuple): 36 | area: float 37 | image_id: str 38 | bbox: BoundingBox 39 | category_no: int 40 | category_id: str 41 | id: Optional[int] = None 42 | source: Optional[str] = None 43 | confidence: Optional[float] = None 44 | is_group_of: Optional[bool] = None 45 | is_truncated: Optional[bool] = None 46 | is_occluded: Optional[bool] = None 47 | is_depiction: Optional[bool] = None 48 | is_inside: Optional[bool] = None 49 | segmentation: Optional[Dict] = None 50 | -------------------------------------------------------------------------------- /taming_transformers/taming/data/sflckr.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | import cv2 4 | import albumentations 5 | from PIL import Image 6 | from torch.utils.data import Dataset 7 | 8 | 9 | class SegmentationBase(Dataset): 10 | def __init__(self, 11 | data_csv, data_root, segmentation_root, 12 | size=None, random_crop=False, interpolation="bicubic", 13 | n_labels=182, shift_segmentation=False, 14 | ): 15 | self.n_labels = n_labels 16 | self.shift_segmentation = shift_segmentation 17 | self.data_csv = data_csv 18 | self.data_root = data_root 19 | self.segmentation_root = segmentation_root 20 | with open(self.data_csv, "r") as f: 21 | self.image_paths = f.read().splitlines() 22 | self._length = len(self.image_paths) 23 | self.labels = { 24 | "relative_file_path_": [l for l in self.image_paths], 25 | "file_path_": 
[os.path.join(self.data_root, l) 26 | for l in self.image_paths], 27 | "segmentation_path_": [os.path.join(self.segmentation_root, l.replace(".jpg", ".png")) 28 | for l in self.image_paths] 29 | } 30 | 31 | size = None if size is not None and size<=0 else size 32 | self.size = size 33 | if self.size is not None: 34 | self.interpolation = interpolation 35 | self.interpolation = { 36 | "nearest": cv2.INTER_NEAREST, 37 | "bilinear": cv2.INTER_LINEAR, 38 | "bicubic": cv2.INTER_CUBIC, 39 | "area": cv2.INTER_AREA, 40 | "lanczos": cv2.INTER_LANCZOS4}[self.interpolation] 41 | self.image_rescaler = albumentations.SmallestMaxSize(max_size=self.size, 42 | interpolation=self.interpolation) 43 | self.segmentation_rescaler = albumentations.SmallestMaxSize(max_size=self.size, 44 | interpolation=cv2.INTER_NEAREST) 45 | self.center_crop = not random_crop 46 | if self.center_crop: 47 | self.cropper = albumentations.CenterCrop(height=self.size, width=self.size) 48 | else: 49 | self.cropper = albumentations.RandomCrop(height=self.size, width=self.size) 50 | self.preprocessor = self.cropper 51 | 52 | def __len__(self): 53 | return self._length 54 | 55 | def __getitem__(self, i): 56 | example = dict((k, self.labels[k][i]) for k in self.labels) 57 | image = Image.open(example["file_path_"]) 58 | if not image.mode == "RGB": 59 | image = image.convert("RGB") 60 | image = np.array(image).astype(np.uint8) 61 | if self.size is not None: 62 | image = self.image_rescaler(image=image)["image"] 63 | segmentation = Image.open(example["segmentation_path_"]) 64 | assert segmentation.mode == "L", segmentation.mode 65 | segmentation = np.array(segmentation).astype(np.uint8) 66 | if self.shift_segmentation: 67 | # used to support segmentations containing unlabeled==255 label 68 | segmentation = segmentation+1 69 | if self.size is not None: 70 | segmentation = self.segmentation_rescaler(image=segmentation)["image"] 71 | if self.size is not None: 72 | processed = self.preprocessor(image=image, 73 | 
mask=segmentation 74 | ) 75 | else: 76 | processed = {"image": image, 77 | "mask": segmentation 78 | } 79 | example["image"] = (processed["image"]/127.5 - 1.0).astype(np.float32) 80 | segmentation = processed["mask"] 81 | onehot = np.eye(self.n_labels)[segmentation] 82 | example["segmentation"] = onehot 83 | return example 84 | 85 | 86 | class Examples(SegmentationBase): 87 | def __init__(self, size=None, random_crop=False, interpolation="bicubic"): 88 | super().__init__(data_csv="data/sflckr_examples.txt", 89 | data_root="data/sflckr_images", 90 | segmentation_root="data/sflckr_segmentations", 91 | size=size, random_crop=random_crop, interpolation=interpolation) 92 | -------------------------------------------------------------------------------- /taming_transformers/taming/lr_scheduler.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | 4 | class LambdaWarmUpCosineScheduler: 5 | """ 6 | note: use with a base_lr of 1.0 7 | """ 8 | def __init__(self, warm_up_steps, lr_min, lr_max, lr_start, max_decay_steps, verbosity_interval=0): 9 | self.lr_warm_up_steps = warm_up_steps 10 | self.lr_start = lr_start 11 | self.lr_min = lr_min 12 | self.lr_max = lr_max 13 | self.lr_max_decay_steps = max_decay_steps 14 | self.last_lr = 0. 
15 | self.verbosity_interval = verbosity_interval 16 | 17 | def schedule(self, n): 18 | if self.verbosity_interval > 0: 19 | if n % self.verbosity_interval == 0: print(f"current step: {n}, recent lr-multiplier: {self.last_lr}") 20 | if n < self.lr_warm_up_steps: 21 | lr = (self.lr_max - self.lr_start) / self.lr_warm_up_steps * n + self.lr_start 22 | self.last_lr = lr 23 | return lr 24 | else: 25 | t = (n - self.lr_warm_up_steps) / (self.lr_max_decay_steps - self.lr_warm_up_steps) 26 | t = min(t, 1.0) 27 | lr = self.lr_min + 0.5 * (self.lr_max - self.lr_min) * ( 28 | 1 + np.cos(t * np.pi)) 29 | self.last_lr = lr 30 | return lr 31 | 32 | def __call__(self, n): 33 | return self.schedule(n) 34 | 35 | -------------------------------------------------------------------------------- /taming_transformers/taming/models/__pycache__/vqgan.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DCDmllm/HealthGPT/c044a13254b76c5eec8c2e6c55e3324318c27940/taming_transformers/taming/models/__pycache__/vqgan.cpython-310.pyc -------------------------------------------------------------------------------- /taming_transformers/taming/models/dummy_cond_stage.py: -------------------------------------------------------------------------------- 1 | from torch import Tensor 2 | 3 | 4 | class DummyCondStage: 5 | def __init__(self, conditional_key): 6 | self.conditional_key = conditional_key 7 | self.train = None 8 | 9 | def eval(self): 10 | return self 11 | 12 | @staticmethod 13 | def encode(c: Tensor): 14 | return c, None, (None, None, c) 15 | 16 | @staticmethod 17 | def decode(c: Tensor): 18 | return c 19 | 20 | @staticmethod 21 | def to_rgb(c: Tensor): 22 | return c 23 | -------------------------------------------------------------------------------- /taming_transformers/taming/modules/__pycache__/util.cpython-310.pyc: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/DCDmllm/HealthGPT/c044a13254b76c5eec8c2e6c55e3324318c27940/taming_transformers/taming/modules/__pycache__/util.cpython-310.pyc -------------------------------------------------------------------------------- /taming_transformers/taming/modules/diffusionmodules/__pycache__/model.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DCDmllm/HealthGPT/c044a13254b76c5eec8c2e6c55e3324318c27940/taming_transformers/taming/modules/diffusionmodules/__pycache__/model.cpython-310.pyc -------------------------------------------------------------------------------- /taming_transformers/taming/modules/discriminator/__pycache__/model.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DCDmllm/HealthGPT/c044a13254b76c5eec8c2e6c55e3324318c27940/taming_transformers/taming/modules/discriminator/__pycache__/model.cpython-310.pyc -------------------------------------------------------------------------------- /taming_transformers/taming/modules/discriminator/model.py: -------------------------------------------------------------------------------- 1 | import functools 2 | import torch.nn as nn 3 | 4 | 5 | from taming.modules.util import ActNorm 6 | 7 | 8 | def weights_init(m): 9 | classname = m.__class__.__name__ 10 | if classname.find('Conv') != -1: 11 | nn.init.normal_(m.weight.data, 0.0, 0.02) 12 | elif classname.find('BatchNorm') != -1: 13 | nn.init.normal_(m.weight.data, 1.0, 0.02) 14 | nn.init.constant_(m.bias.data, 0) 15 | 16 | 17 | class NLayerDiscriminator(nn.Module): 18 | """Defines a PatchGAN discriminator as in Pix2Pix 19 | --> see https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix/blob/master/models/networks.py 20 | """ 21 | def __init__(self, input_nc=3, ndf=64, n_layers=3, use_actnorm=False): 22 | """Construct a PatchGAN discriminator 23 | Parameters: 24 | input_nc (int) -- the 
number of channels in input images 25 | ndf (int) -- the number of filters in the last conv layer 26 | n_layers (int) -- the number of conv layers in the discriminator 27 | norm_layer -- normalization layer 28 | """ 29 | super(NLayerDiscriminator, self).__init__() 30 | if not use_actnorm: 31 | norm_layer = nn.BatchNorm2d 32 | else: 33 | norm_layer = ActNorm 34 | if type(norm_layer) == functools.partial: # no need to use bias as BatchNorm2d has affine parameters 35 | use_bias = norm_layer.func != nn.BatchNorm2d 36 | else: 37 | use_bias = norm_layer != nn.BatchNorm2d 38 | 39 | kw = 4 40 | padw = 1 41 | sequence = [nn.Conv2d(input_nc, ndf, kernel_size=kw, stride=2, padding=padw), nn.LeakyReLU(0.2, True)] 42 | nf_mult = 1 43 | nf_mult_prev = 1 44 | for n in range(1, n_layers): # gradually increase the number of filters 45 | nf_mult_prev = nf_mult 46 | nf_mult = min(2 ** n, 8) 47 | sequence += [ 48 | nn.Conv2d(ndf * nf_mult_prev, ndf * nf_mult, kernel_size=kw, stride=2, padding=padw, bias=use_bias), 49 | norm_layer(ndf * nf_mult), 50 | nn.LeakyReLU(0.2, True) 51 | ] 52 | 53 | nf_mult_prev = nf_mult 54 | nf_mult = min(2 ** n_layers, 8) 55 | sequence += [ 56 | nn.Conv2d(ndf * nf_mult_prev, ndf * nf_mult, kernel_size=kw, stride=1, padding=padw, bias=use_bias), 57 | norm_layer(ndf * nf_mult), 58 | nn.LeakyReLU(0.2, True) 59 | ] 60 | 61 | sequence += [ 62 | nn.Conv2d(ndf * nf_mult, 1, kernel_size=kw, stride=1, padding=padw)] # output 1 channel prediction map 63 | self.main = nn.Sequential(*sequence) 64 | 65 | def forward(self, input): 66 | """Standard forward.""" 67 | return self.main(input) 68 | -------------------------------------------------------------------------------- /taming_transformers/taming/modules/losses/__init__.py: -------------------------------------------------------------------------------- 1 | from taming.modules.losses.vqperceptual import DummyLoss 2 | 3 | -------------------------------------------------------------------------------- 
/taming_transformers/taming/modules/losses/__pycache__/__init__.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DCDmllm/HealthGPT/c044a13254b76c5eec8c2e6c55e3324318c27940/taming_transformers/taming/modules/losses/__pycache__/__init__.cpython-310.pyc -------------------------------------------------------------------------------- /taming_transformers/taming/modules/losses/__pycache__/lpips.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DCDmllm/HealthGPT/c044a13254b76c5eec8c2e6c55e3324318c27940/taming_transformers/taming/modules/losses/__pycache__/lpips.cpython-310.pyc -------------------------------------------------------------------------------- /taming_transformers/taming/modules/losses/__pycache__/vqperceptual.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DCDmllm/HealthGPT/c044a13254b76c5eec8c2e6c55e3324318c27940/taming_transformers/taming/modules/losses/__pycache__/vqperceptual.cpython-310.pyc -------------------------------------------------------------------------------- /taming_transformers/taming/modules/losses/segmentation.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import torch.nn.functional as F 3 | 4 | 5 | class BCELoss(nn.Module): 6 | def forward(self, prediction, target): 7 | loss = F.binary_cross_entropy_with_logits(prediction,target) 8 | return loss, {} 9 | 10 | 11 | class BCELossWithQuant(nn.Module): 12 | def __init__(self, codebook_weight=1.): 13 | super().__init__() 14 | self.codebook_weight = codebook_weight 15 | 16 | def forward(self, qloss, target, prediction, split): 17 | bce_loss = F.binary_cross_entropy_with_logits(prediction,target) 18 | loss = bce_loss + self.codebook_weight*qloss 19 | return loss, 
{"{}/total_loss".format(split): loss.clone().detach().mean(), 20 | "{}/bce_loss".format(split): bce_loss.detach().mean(), 21 | "{}/quant_loss".format(split): qloss.detach().mean() 22 | } 23 | -------------------------------------------------------------------------------- /taming_transformers/taming/modules/misc/coord.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | class CoordStage(object): 4 | def __init__(self, n_embed, down_factor): 5 | self.n_embed = n_embed 6 | self.down_factor = down_factor 7 | 8 | def eval(self): 9 | return self 10 | 11 | def encode(self, c): 12 | """fake vqmodel interface""" 13 | assert 0.0 <= c.min() and c.max() <= 1.0 14 | b,ch,h,w = c.shape 15 | assert ch == 1 16 | 17 | c = torch.nn.functional.interpolate(c, scale_factor=1/self.down_factor, 18 | mode="area") 19 | c = c.clamp(0.0, 1.0) 20 | c = self.n_embed*c 21 | c_quant = c.round() 22 | c_ind = c_quant.to(dtype=torch.long) 23 | 24 | info = None, None, c_ind 25 | return c_quant, None, info 26 | 27 | def decode(self, c): 28 | c = c/self.n_embed 29 | c = torch.nn.functional.interpolate(c, scale_factor=self.down_factor, 30 | mode="nearest") 31 | return c 32 | -------------------------------------------------------------------------------- /taming_transformers/taming/modules/util.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | 5 | def count_params(model): 6 | total_params = sum(p.numel() for p in model.parameters()) 7 | return total_params 8 | 9 | 10 | class ActNorm(nn.Module): 11 | def __init__(self, num_features, logdet=False, affine=True, 12 | allow_reverse_init=False): 13 | assert affine 14 | super().__init__() 15 | self.logdet = logdet 16 | self.loc = nn.Parameter(torch.zeros(1, num_features, 1, 1)) 17 | self.scale = nn.Parameter(torch.ones(1, num_features, 1, 1)) 18 | self.allow_reverse_init = allow_reverse_init 19 | 20 | 
self.register_buffer('initialized', torch.tensor(0, dtype=torch.uint8)) 21 | 22 | def initialize(self, input): 23 | with torch.no_grad(): 24 | flatten = input.permute(1, 0, 2, 3).contiguous().view(input.shape[1], -1) 25 | mean = ( 26 | flatten.mean(1) 27 | .unsqueeze(1) 28 | .unsqueeze(2) 29 | .unsqueeze(3) 30 | .permute(1, 0, 2, 3) 31 | ) 32 | std = ( 33 | flatten.std(1) 34 | .unsqueeze(1) 35 | .unsqueeze(2) 36 | .unsqueeze(3) 37 | .permute(1, 0, 2, 3) 38 | ) 39 | 40 | self.loc.data.copy_(-mean) 41 | self.scale.data.copy_(1 / (std + 1e-6)) 42 | 43 | def forward(self, input, reverse=False): 44 | if reverse: 45 | return self.reverse(input) 46 | if len(input.shape) == 2: 47 | input = input[:,:,None,None] 48 | squeeze = True 49 | else: 50 | squeeze = False 51 | 52 | _, _, height, width = input.shape 53 | 54 | if self.training and self.initialized.item() == 0: 55 | self.initialize(input) 56 | self.initialized.fill_(1) 57 | 58 | h = self.scale * (input + self.loc) 59 | 60 | if squeeze: 61 | h = h.squeeze(-1).squeeze(-1) 62 | 63 | if self.logdet: 64 | log_abs = torch.log(torch.abs(self.scale)) 65 | logdet = height*width*torch.sum(log_abs) 66 | logdet = logdet * torch.ones(input.shape[0]).to(input) 67 | return h, logdet 68 | 69 | return h 70 | 71 | def reverse(self, output): 72 | if self.training and self.initialized.item() == 0: 73 | if not self.allow_reverse_init: 74 | raise RuntimeError( 75 | "Initializing ActNorm in reverse direction is " 76 | "disabled by default. Use allow_reverse_init=True to enable." 
77 | ) 78 | else: 79 | self.initialize(output) 80 | self.initialized.fill_(1) 81 | 82 | if len(output.shape) == 2: 83 | output = output[:,:,None,None] 84 | squeeze = True 85 | else: 86 | squeeze = False 87 | 88 | h = output / self.scale - self.loc 89 | 90 | if squeeze: 91 | h = h.squeeze(-1).squeeze(-1) 92 | return h 93 | 94 | 95 | class AbstractEncoder(nn.Module): 96 | def __init__(self): 97 | super().__init__() 98 | 99 | def encode(self, *args, **kwargs): 100 | raise NotImplementedError 101 | 102 | 103 | class Labelator(AbstractEncoder): 104 | """Net2Net Interface for Class-Conditional Model""" 105 | def __init__(self, n_classes, quantize_interface=True): 106 | super().__init__() 107 | self.n_classes = n_classes 108 | self.quantize_interface = quantize_interface 109 | 110 | def encode(self, c): 111 | c = c[:,None] 112 | if self.quantize_interface: 113 | return c, None, [None, None, c.long()] 114 | return c 115 | 116 | 117 | class SOSProvider(AbstractEncoder): 118 | # for unconditional training 119 | def __init__(self, sos_token, quantize_interface=True): 120 | super().__init__() 121 | self.sos_token = sos_token 122 | self.quantize_interface = quantize_interface 123 | 124 | def encode(self, x): 125 | # get batch size from data and replicate sos_token 126 | c = torch.ones(x.shape[0], 1)*self.sos_token 127 | c = c.long().to(x.device) 128 | if self.quantize_interface: 129 | return c, None, [None, None, c] 130 | return c 131 | -------------------------------------------------------------------------------- /taming_transformers/taming/modules/vqvae/__pycache__/quantize.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DCDmllm/HealthGPT/c044a13254b76c5eec8c2e6c55e3324318c27940/taming_transformers/taming/modules/vqvae/__pycache__/quantize.cpython-310.pyc --------------------------------------------------------------------------------