├── .gitignore
├── LICENSE
├── README.md
├── app.py
├── config.py
├── dockerignore.txt
├── editorconfig.txt
├── gitattributes.txt
├── gitignore.txt
├── images
├── Framework.png
├── HealthGPT.png
├── chatUI.jpg
└── intro.png
├── llava
├── __init__.py
├── __pycache__
│ ├── __init__.cpython-310.pyc
│ ├── constants.cpython-310.pyc
│ ├── conversation.cpython-310.pyc
│ └── mm_utils.cpython-310.pyc
├── constants.py
├── conversation.py
├── demo
│ ├── __init__.py
│ ├── __pycache__
│ │ └── utils.cpython-310.pyc
│ ├── com_infer.py
│ ├── com_infer.sh
│ ├── com_infer_phi4.py
│ ├── com_infer_phi4.sh
│ ├── com_infer_qwen2_5.py
│ ├── com_infer_qwen2_5.sh
│ ├── gen_infer.py
│ ├── gen_infer.sh
│ └── utils.py
├── eval
│ ├── eval_gpt_review.py
│ ├── eval_gpt_review_bench.py
│ ├── eval_gpt_review_visual.py
│ ├── eval_pope.py
│ ├── eval_science_qa.py
│ ├── eval_science_qa_gpt4.py
│ ├── eval_science_qa_gpt4_requery.py
│ ├── eval_textvqa.py
│ ├── generate_webpage_data_from_table.py
│ ├── m4c_evaluator.py
│ ├── model_qa.py
│ ├── model_vqa.py
│ ├── model_vqa_loader.py
│ ├── model_vqa_mmbench.py
│ ├── model_vqa_science.py
│ ├── qa_baseline_gpt35.py
│ ├── run_llava.py
│ ├── summarize_gpt_review.py
│ ├── table
│ │ ├── answer
│ │ │ ├── answer_alpaca-13b.jsonl
│ │ │ ├── answer_bard.jsonl
│ │ │ ├── answer_gpt35.jsonl
│ │ │ ├── answer_llama-13b.jsonl
│ │ │ └── answer_vicuna-13b.jsonl
│ │ ├── caps_boxes_coco2014_val_80.jsonl
│ │ ├── model.jsonl
│ │ ├── prompt.jsonl
│ │ ├── question.jsonl
│ │ ├── results
│ │ │ ├── test_sqa_llava_13b_v0.json
│ │ │ └── test_sqa_llava_lcs_558k_sqa_12e_vicuna_v1_3_13b.json
│ │ ├── review
│ │ │ ├── review_alpaca-13b_vicuna-13b.jsonl
│ │ │ ├── review_bard_vicuna-13b.jsonl
│ │ │ ├── review_gpt35_vicuna-13b.jsonl
│ │ │ └── review_llama-13b_vicuna-13b.jsonl
│ │ ├── reviewer.jsonl
│ │ └── rule.json
│ └── webpage
│ │ ├── figures
│ │ │ ├── alpaca.png
│ │ │ ├── bard.jpg
│ │ │ ├── chatgpt.svg
│ │ │ ├── llama.jpg
│ │ │ ├── swords_FILL0_wght300_GRAD0_opsz48.svg
│ │ │ └── vicuna.jpeg
│ │ ├── index.html
│ │ ├── script.js
│ │ └── styles.css
├── mm_utils.py
├── model
│ ├── __init__.py
│ ├── __pycache__
│ │ ├── __init__.cpython-310.pyc
│ │ └── llava_arch.cpython-310.pyc
│ ├── apply_delta.py
│ ├── builder.py
│ ├── consolidate.py
│ ├── language_model
│ │ ├── __pycache__
│ │ │ ├── llava_llama.cpython-310.pyc
│ │ │ ├── llava_mistral.cpython-310.pyc
│ │ │ ├── llava_mpt.cpython-310.pyc
│ │ │ └── llava_phi3.cpython-310.pyc
│ │ ├── llava_llama.py
│ │ ├── llava_mistral.py
│ │ ├── llava_mpt.py
│ │ ├── llava_phi3.py
│ │ └── llava_qwen.py
│ ├── llava_arch.py
│ ├── make_delta.py
│ ├── multimodal_encoder
│ │ ├── __pycache__
│ │ │ ├── builder.cpython-310.pyc
│ │ │ └── clip_encoder.cpython-310.pyc
│ │ ├── builder.py
│ │ └── clip_encoder.py
│ ├── multimodal_projector
│ │ ├── __pycache__
│ │ │ └── builder.cpython-310.pyc
│ │ └── builder.py
│ └── utils.py
├── peft
│ ├── __init__.py
│ ├── __pycache__
│ │ ├── __init__.cpython-310.pyc
│ │ ├── __init__.cpython-39.pyc
│ │ ├── mapping.cpython-310.pyc
│ │ ├── mapping.cpython-39.pyc
│ │ ├── peft_model.cpython-310.pyc
│ │ └── peft_model.cpython-39.pyc
│ ├── mapping.py
│ ├── peft_model.py
│ ├── tuners
│ │ ├── __init__.py
│ │ ├── __pycache__
│ │ │ ├── __init__.cpython-310.pyc
│ │ │ ├── __init__.cpython-39.pyc
│ │ │ ├── lora.cpython-310.pyc
│ │ │ ├── lora.cpython-310.pyc.139831930933664
│ │ │ ├── lora.cpython-310.pyc.140023258402944
│ │ │ ├── lora.cpython-310.pyc.140104265109632
│ │ │ ├── lora.cpython-310.pyc.140108281172096
│ │ │ ├── lora.cpython-310.pyc.140160632344704
│ │ │ ├── lora.cpython-310.pyc.140480966316160
│ │ │ ├── lora.cpython-310.pyc.140577961184688
│ │ │ ├── lora.cpython-39.pyc
│ │ │ ├── lora_moe.cpython-310.pyc
│ │ │ ├── p_tuning.cpython-310.pyc
│ │ │ ├── p_tuning.cpython-39.pyc
│ │ │ ├── prefix_tuning.cpython-310.pyc
│ │ │ ├── prefix_tuning.cpython-39.pyc
│ │ │ ├── prompt_tuning.cpython-310.pyc
│ │ │ └── prompt_tuning.cpython-39.pyc
│ │ ├── lora.py
│ │ ├── p_tuning.py
│ │ ├── prefix_tuning.py
│ │ └── prompt_tuning.py
│ └── utils
│ │ ├── __init__.py
│ │ ├── __pycache__
│ │ │ ├── __init__.cpython-310.pyc
│ │ │ ├── __init__.cpython-39.pyc
│ │ │ ├── adapters_utils.cpython-310.pyc
│ │ │ ├── adapters_utils.cpython-39.pyc
│ │ │ ├── config.cpython-310.pyc
│ │ │ ├── config.cpython-39.pyc
│ │ │ ├── other.cpython-310.pyc
│ │ │ ├── other.cpython-39.pyc
│ │ │ ├── save_and_load.cpython-310.pyc
│ │ │ └── save_and_load.cpython-39.pyc
│ │ ├── adapters_utils.py
│ │ ├── config.py
│ │ ├── other.py
│ │ └── save_and_load.py
├── serve
│ ├── __init__.py
│ ├── cli.py
│ ├── controller.py
│ ├── examples
│ │ ├── extreme_ironing.jpg
│ │ └── waterview.jpg
│ ├── gradio_web_server.py
│ ├── model_worker.py
│ ├── register_worker.py
│ ├── sglang_worker.py
│ └── test_message.py
└── utils.py
├── model.py
├── requirements.txt
├── requirements_qwen_2_5.txt
├── scripts
├── convert_gqa_for_eval.py
├── convert_mmbench_for_submission.py
├── convert_mmvet_for_eval.py
├── convert_seed_for_submission.py
├── convert_sqa_to_llava.py
├── convert_sqa_to_llava_base_prompt.py
├── convert_vizwiz_for_submission.py
├── convert_vqav2_for_submission.py
├── extract_mm_projector.py
├── finetune.sh
├── finetune_full_schedule.sh
├── finetune_lora.sh
├── finetune_qlora.sh
├── finetune_sqa.sh
├── merge_lora_weights.py
├── pretrain.sh
├── pretrain_xformers.sh
├── sqa_eval_batch.sh
├── sqa_eval_gather.sh
├── upload_pypi.sh
├── v1_5
│ ├── eval
│ │ ├── gqa.sh
│ │ ├── llavabench.sh
│ │ ├── mmbench.sh
│ │ ├── mmbench_cn.sh
│ │ ├── mme.sh
│ │ ├── mmvet.sh
│ │ ├── pope.sh
│ │ ├── qbench.sh
│ │ ├── qbench_zh.sh
│ │ ├── seed.sh
│ │ ├── sqa.sh
│ │ ├── textvqa.sh
│ │ ├── vizwiz.sh
│ │ └── vqav2.sh
│ ├── finetune.sh
│ ├── finetune_lora.sh
│ ├── finetune_task.sh
│ ├── finetune_task_lora.sh
│ └── pretrain.sh
├── zero2.json
├── zero3.json
└── zero3_offload.json
└── taming_transformers
├── License.txt
├── ckpt
└── model.yaml
├── environment.yaml
├── idx2img.py
├── main.py
├── scripts
├── extract_depth.py
├── extract_segmentation.py
├── extract_submodel.py
├── make_samples.py
├── make_scene_samples.py
├── sample_conditional.py
└── sample_fast.py
├── setup.py
└── taming
├── data
├── ade20k.py
├── annotated_objects_coco.py
├── annotated_objects_dataset.py
├── annotated_objects_open_images.py
├── base.py
├── coco.py
├── conditional_builder
│ ├── objects_bbox.py
│ ├── objects_center_points.py
│ └── utils.py
├── custom.py
├── faceshq.py
├── helper_types.py
├── image_transforms.py
├── imagenet.py
├── open_images_helper.py
├── sflckr.py
└── utils.py
├── lr_scheduler.py
├── models
├── __pycache__
│ └── vqgan.cpython-310.pyc
├── cond_transformer.py
├── dummy_cond_stage.py
└── vqgan.py
├── modules
├── __pycache__
│ └── util.cpython-310.pyc
├── diffusionmodules
│ ├── __pycache__
│ │ └── model.cpython-310.pyc
│ └── model.py
├── discriminator
│ ├── __pycache__
│ │ └── model.cpython-310.pyc
│ └── model.py
├── losses
│ ├── __init__.py
│ ├── __pycache__
│ │ ├── __init__.cpython-310.pyc
│ │ ├── lpips.cpython-310.pyc
│ │ └── vqperceptual.cpython-310.pyc
│ ├── lpips.py
│ ├── segmentation.py
│ └── vqperceptual.py
├── misc
│ └── coord.py
├── transformer
│ ├── mingpt.py
│ └── permuter.py
├── util.py
└── vqvae
│ ├── __pycache__
│ │ └── quantize.cpython-310.pyc
│ └── quantize.py
└── util.py
/app.py:
--------------------------------------------------------------------------------
1 | import traceback
2 |
3 | from model import HealthGPT, HealthGPT_Agent
4 | from config import HealthGPTConfig_M3_COM, HealthGPTConfig_M3_GEN, HealthGPTConfig_L14_COM
5 |
6 | configs = {
7 | "HealthGPT-M3-COM": HealthGPTConfig_M3_COM(),
8 | "HealthGPT-M3-GEN": HealthGPTConfig_M3_GEN(),
9 | "HealthGPT-L14-COM": HealthGPTConfig_L14_COM()
10 | }
11 |
12 | agent = HealthGPT_Agent(configs=configs, model_name=None)
13 |
14 | # HealthGPT interface
15 | import gradio as gr
16 | from PIL import Image, ImageDraw
17 |
18 | def process_input(option, model_name, text, image):
19 | if not text.strip():
20 | return gr.update(value="⚠️ Please input your question.", visible=True), None, gr.update(visible=True), gr.update(visible=False)
21 | try:
22 | if option == "Analyze Image":
23 | model_name = model_name + "-COM"
24 | try:
25 | agent.load_model(model_name=model_name)
26 | resp = agent.process(option, text, image)
27 | except Exception as e:
28 | agent.load_model(model_name=model_name)
29 | resp = agent.process(option, text, image)
30 | return resp, None, gr.update(visible=True), gr.update(visible=False)
31 |
32 | elif option == "Generate Image":
33 | model_name = model_name + "-GEN"
34 | try:
35 | agent.load_model(model_name=model_name)
36 | resp = agent.process(option, text, image)
37 | except Exception as e:
38 | agent.load_model(model_name=model_name)
39 | resp = agent.process(option, text, image)
40 | return None, resp, gr.update(visible=False), gr.update(visible=True)
41 | except Exception as e:
42 | print(traceback.format_exc())
43 | return gr.update(value=f"⚠️ {e.args[0]}", visible=True), None, gr.update(visible=True), gr.update(visible=False)
44 |
45 |
46 | with gr.Blocks() as demo:
47 | # gr.Markdown("# 🖼️ HealthGPT")
48 | gr.Markdown("<h1 style='text-align: center;'>🖼️ HealthGPT</h1>")
49 |
50 | # Option A / B
51 | with gr.Row():
52 | option = gr.Radio(["Analyze Image", "Generate Image"], label="🔍Choose the task", value="Analyze Image", interactive=True)
53 | model_name = gr.Radio(["HealthGPT-M3", "HealthGPT-L14"], label="🧠Choose the model", value="HealthGPT-M3", interactive=True)
54 |
55 | with gr.Row():
56 | with gr.Column():
57 | gr.Markdown("### 🔹 Input")
58 | text_input = gr.Textbox(label="Question", placeholder="Text here...", lines=3, value="Could you explain what this mass in the MRI means for my health? Is it very serious?")
59 | image_input = gr.Image(type="pil", label="Upload an image...")
60 |
61 | with gr.Column():
62 | gr.Markdown("### 🔹 Output")
63 | process_button = gr.Button("🚀 Process", variant="primary")
64 | text_output = gr.Textbox(label="HealthGPT Answer", visible=True, lines=20)
65 | image_output = gr.Image(label="Generated Image", visible=False)
66 |
67 | process_button.click(
68 | process_input,
69 | inputs=[option, model_name, text_input, image_input],
70 | outputs=[text_output, image_output, text_output, image_output] # use gr.update() instead of plain bools to toggle visibility
71 | )
72 |
73 | gr.Markdown("""### Terms of use
74 | By using this service, users are required to agree to the following terms:
75 | The service is a research preview intended for non-commercial use only. It only provides limited safety measures and may generate offensive content. It must not be used for any illegal, harmful, violent, racist, or sexual purposes. The service may collect user dialogue data for future research.
76 | For an optimal experience, please use desktop computers for this demo, as mobile devices may compromise its quality.""")
77 |
78 | demo.css = """footer {display: none !important;}"""
79 |
80 | # Start Gradio website
81 | demo.launch(server_name="0.0.0.0", server_port=5011, show_api=False)
82 |
--------------------------------------------------------------------------------
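A note on the `outputs` wiring above: `process_input` returns `gr.update(...)` objects so a single click can both fill a component and toggle its visibility. A minimal, self-contained sketch of that pattern (component names below are illustrative, not taken from the repo):

```python
import gradio as gr

def route(task, text):
    # Show the text box for analysis, the image box for generation,
    # by returning gr.update(...) values instead of plain booleans.
    if task == "Analyze Image":
        return gr.update(value=f"answer to: {text}", visible=True), gr.update(visible=False)
    return gr.update(visible=False), gr.update(visible=True)

with gr.Blocks() as demo:
    task = gr.Radio(["Analyze Image", "Generate Image"], value="Analyze Image", label="Task")
    text = gr.Textbox(label="Question")
    btn = gr.Button("Process")
    text_out = gr.Textbox(label="Answer", visible=True)
    img_out = gr.Image(label="Generated Image", visible=False)
    btn.click(route, inputs=[task, text], outputs=[text_out, img_out])

if __name__ == "__main__":
    demo.launch()
```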
/config.py:
--------------------------------------------------------------------------------
1 | # HealthGPT config
2 | class HealthGPTConfig_M3_COM:
3 | model_name_or_path = "./Phi-3-mini-4k-instruct"
4 | dtype = "FP16"
5 | attn_implementation = None
6 | hlora_r = 64
7 | hlora_alpha = 128
8 | hlora_dropout = 0.0
9 | hlora_nums = 4
10 | vq_idx_nums = 8192
11 | instruct_template = "phi3_instruct"
12 | vit_path = "./clip-vit-large-patch14-336/"
13 | hlora_path = "./HealthGPT-M3/com_hlora_weights.bin"
14 | fusion_layer_path = None
15 | do_sample = False
16 | temperature = 0.0
17 | top_p = None
18 | num_beams = 1
19 | max_new_tokens = 2048
20 | task_type = "comprehension"
21 |
22 |
23 | class HealthGPTConfig_M3_GEN:
24 | model_name_or_path = "./Phi-3-mini-4k-instruct"
25 | dtype = "FP16"
26 | attn_implementation = None
27 | hlora_r = 256
28 | hlora_alpha = 512
29 | hlora_dropout = 0.0
30 | hlora_nums = 4
31 | vq_idx_nums = 8192
32 | instruct_template = "phi3_instruct"
33 | vit_path = "./clip-vit-large-patch14-336/"
34 | hlora_path = "./HealthGPT-M3/gen_hlora_weights.bin"
35 | fusion_layer_path = "./HealthGPT-M3/fusion_layer_weights.bin"
36 | do_sample = False
37 | temperature = 0.0
38 | top_p = None
39 | num_beams = 1
40 | max_new_tokens = 2048
41 | save_path = "output.png"
42 | task_type = "generation"
43 |
44 |
45 | class HealthGPTConfig_L14_COM:
46 | model_name_or_path = "./phi-4"
47 | dtype = "FP16"
48 | attn_implementation = None
49 | hlora_r = 32
50 | hlora_alpha = 64
51 | hlora_dropout = 0.0
52 | hlora_nums = 4
53 | vq_idx_nums = 8192
54 | instruct_template = "phi4_instruct"
55 | vit_path = "./clip-vit-large-patch14-336/"
56 | hlora_path = "./HealthGPT-L14/com_hlora_weights_phi4.bin"
57 | fusion_layer_path = None
58 | do_sample = False
59 | temperature = 0.0
60 | top_p = None
61 | num_beams = 1
62 | max_new_tokens = 2048
63 | task_type = "comprehension"
64 |
--------------------------------------------------------------------------------
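For reference, app.py builds the lookup key for these classes by appending a task suffix to the model choice. A minimal sketch of that convention (`pick_config` is a hypothetical helper, not part of the repo):

```python
from config import HealthGPTConfig_M3_COM, HealthGPTConfig_M3_GEN, HealthGPTConfig_L14_COM

CONFIGS = {
    "HealthGPT-M3-COM": HealthGPTConfig_M3_COM(),
    "HealthGPT-M3-GEN": HealthGPTConfig_M3_GEN(),
    "HealthGPT-L14-COM": HealthGPTConfig_L14_COM(),
}

def pick_config(model_name: str, task: str):
    """model_name: 'HealthGPT-M3' or 'HealthGPT-L14'; task: 'Analyze Image' or 'Generate Image'."""
    key = model_name + ("-COM" if task == "Analyze Image" else "-GEN")
    if key not in CONFIGS:  # e.g. no generation config is defined for HealthGPT-L14 here
        raise ValueError(f"no config registered for {key}")
    return CONFIGS[key]

print(pick_config("HealthGPT-M3", "Generate Image").hlora_r)  # 256
```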
/dockerignore.txt:
--------------------------------------------------------------------------------
1 | # The .dockerignore file excludes files from the container build process.
2 | #
3 | # https://docs.docker.com/engine/reference/builder/#dockerignore-file
4 |
5 | # Exclude Git files
6 | .git
7 | .github
8 | .gitignore
9 |
10 | # Exclude Python cache files
11 | __pycache__
12 | .mypy_cache
13 | .pytest_cache
14 | .ruff_cache
15 |
16 | # Exclude Python virtual environment
17 | /venv
18 |
19 | # Exclude some weights
20 | /openai
21 | /liuhaotian
22 |
--------------------------------------------------------------------------------
/editorconfig.txt:
--------------------------------------------------------------------------------
1 | root = true
2 |
3 | # Unix-style newlines with a newline ending every file
4 | [*]
5 | end_of_line = lf
6 | insert_final_newline = true
7 | trim_trailing_whitespace = true
8 | charset = utf-8
9 |
10 | # 4 space indentation
11 | [*.{py,json}]
12 | indent_style = space
13 | indent_size = 4
14 |
15 | # 2 space indentation
16 | [*.{md,sh,yaml,yml}]
17 | indent_style = space
18 | indent_size = 2
--------------------------------------------------------------------------------
/gitattributes.txt:
--------------------------------------------------------------------------------
1 | # https://git-scm.com/docs/gitattributes
2 |
3 | # Set the default behavior, in case people don't have core.autocrlf set.
4 | # https://git-scm.com/docs/gitattributes#_end_of_line_conversion
5 | * text=auto
6 |
7 | # common python attributes, taken from https://github.com/alexkaratarakis/gitattributes/blob/710900479a2bedeec7003d381719521ffbb18bf8/Python.gitattributes
8 | # Source files
9 | # ============
10 | *.pxd text diff=python
11 | *.py text diff=python
12 | *.py3 text diff=python
13 | *.pyw text diff=python
14 | *.pyx text diff=python
15 | *.pyz text diff=python
16 | *.pyi text diff=python
17 |
18 | # Binary files
19 | # ============
20 | *.db binary
21 | *.p binary
22 | *.pkl binary
23 | *.pickle binary
24 | *.pyc binary export-ignore
25 | *.pyo binary export-ignore
26 | *.pyd binary
27 |
28 | # Jupyter notebook
29 | *.ipynb text eol=lf
30 |
--------------------------------------------------------------------------------
/gitignore.txt:
--------------------------------------------------------------------------------
1 | # Python
2 | __pycache__
3 | *.pyc
4 | *.egg-info
5 | dist
6 |
7 | # Log
8 | *.log
9 | *.log.*
10 | *.json
11 | *.jsonl
12 |
13 | # Data
14 | !**/alpaca-data-conversation.json
15 |
16 | # Editor
17 | .idea
18 | *.swp
19 |
20 | # Other
21 | .DS_Store
22 | wandb
23 | output
24 |
25 | checkpoints
26 | ckpts*
27 |
28 | .ipynb_checkpoints
29 | *.ipynb
30 |
31 | # DevContainer
32 | !.devcontainer/*
33 |
34 | # Demo
35 | serve_images/
36 |
--------------------------------------------------------------------------------
/images/Framework.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/DCDmllm/HealthGPT/c044a13254b76c5eec8c2e6c55e3324318c27940/images/Framework.png
--------------------------------------------------------------------------------
/images/HealthGPT.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/DCDmllm/HealthGPT/c044a13254b76c5eec8c2e6c55e3324318c27940/images/HealthGPT.png
--------------------------------------------------------------------------------
/images/chatUI.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/DCDmllm/HealthGPT/c044a13254b76c5eec8c2e6c55e3324318c27940/images/chatUI.jpg
--------------------------------------------------------------------------------
/images/intro.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/DCDmllm/HealthGPT/c044a13254b76c5eec8c2e6c55e3324318c27940/images/intro.png
--------------------------------------------------------------------------------
/llava/__init__.py:
--------------------------------------------------------------------------------
1 | from .model import LlavaPhiForCausalLM
--------------------------------------------------------------------------------
/llava/__pycache__/__init__.cpython-310.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/DCDmllm/HealthGPT/c044a13254b76c5eec8c2e6c55e3324318c27940/llava/__pycache__/__init__.cpython-310.pyc
--------------------------------------------------------------------------------
/llava/__pycache__/constants.cpython-310.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/DCDmllm/HealthGPT/c044a13254b76c5eec8c2e6c55e3324318c27940/llava/__pycache__/constants.cpython-310.pyc
--------------------------------------------------------------------------------
/llava/__pycache__/conversation.cpython-310.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/DCDmllm/HealthGPT/c044a13254b76c5eec8c2e6c55e3324318c27940/llava/__pycache__/conversation.cpython-310.pyc
--------------------------------------------------------------------------------
/llava/__pycache__/mm_utils.cpython-310.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/DCDmllm/HealthGPT/c044a13254b76c5eec8c2e6c55e3324318c27940/llava/__pycache__/mm_utils.cpython-310.pyc
--------------------------------------------------------------------------------
/llava/constants.py:
--------------------------------------------------------------------------------
1 | CONTROLLER_HEART_BEAT_EXPIRATION = 30
2 | WORKER_HEART_BEAT_INTERVAL = 15
3 |
4 | LOGDIR = "."
5 |
6 | # Model Constants
7 | IGNORE_INDEX = -100
8 | IMAGE_TOKEN_INDEX = -200
9 | DEFAULT_IMAGE_TOKEN = "<image>"
10 | DEFAULT_IMAGE_PATCH_TOKEN = "<im_patch>"
11 | DEFAULT_IM_START_TOKEN = "<im_start>"
12 | DEFAULT_IM_END_TOKEN = "<im_end>"
13 | IMAGE_PLACEHOLDER = "<image-placeholder>"
14 |
--------------------------------------------------------------------------------
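How these constants are used downstream (see model_vqa.py later in this dump): the textual image placeholder is spliced into the prompt, then swapped for `IMAGE_TOKEN_INDEX` at tokenization time. A small sketch, under the assumption that the listed checkpoint and the `tokenizer_image_token` helper behave as in stock LLaVA:

```python
from transformers import AutoTokenizer

from llava.constants import DEFAULT_IMAGE_TOKEN, IMAGE_TOKEN_INDEX
from llava.mm_utils import tokenizer_image_token

# checkpoint name taken from the demo scripts in this repo
tokenizer = AutoTokenizer.from_pretrained("microsoft/Phi-3-mini-4k-instruct")

prompt = DEFAULT_IMAGE_TOKEN + "\n" + "What does this scan show?"
ids = tokenizer_image_token(prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors='pt')

# count the image slots in the tokenized prompt (expected: 1 for the single placeholder)
print((ids == IMAGE_TOKEN_INDEX).sum().item())
```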
/llava/demo/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/DCDmllm/HealthGPT/c044a13254b76c5eec8c2e6c55e3324318c27940/llava/demo/__init__.py
--------------------------------------------------------------------------------
/llava/demo/__pycache__/utils.cpython-310.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/DCDmllm/HealthGPT/c044a13254b76c5eec8c2e6c55e3324318c27940/llava/demo/__pycache__/utils.cpython-310.pyc
--------------------------------------------------------------------------------
/llava/demo/com_infer.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | MODEL_NAME_OR_PATH="microsoft/Phi-3-mini-4k-instruct"
4 | VIT_PATH="openai/clip-vit-large-patch14-336/"
5 | HLORA_PATH="com_hlora_weights.bin"
6 | FUSION_LAYER_PATH="fusion_layer_weights.bin"
7 |
8 | python3 com_infer.py \
9 | --model_name_or_path "$MODEL_NAME_OR_PATH" \
10 | --dtype "FP16" \
11 | --hlora_r "64" \
12 | --hlora_alpha "128" \
13 | --hlora_nums "4" \
14 | --vq_idx_nums "8192" \
15 | --instruct_template "phi3_instruct" \
16 | --vit_path "$VIT_PATH" \
17 | --hlora_path "$HLORA_PATH" \
18 | --fusion_layer_path "$FUSION_LAYER_PATH" \
19 | --question "Your question" \
20 | --img_path "path/to/image.jpg" \
21 |
--------------------------------------------------------------------------------
/llava/demo/com_infer_phi4.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | MODEL_NAME_OR_PATH="microsoft/Phi-4"
4 | VIT_PATH="openai/clip-vit-large-patch14-336/"
5 | HLORA_PATH="com_hlora_weights_phi4.bin"
6 |
7 | python3 com_infer_phi4.py \
8 | --model_name_or_path "$MODEL_NAME_OR_PATH" \
9 | --dtype "FP16" \
10 | --hlora_r "32" \
11 | --hlora_alpha "64" \
12 | --hlora_nums "4" \
13 | --vq_idx_nums "8192" \
14 | --instruct_template "phi4_instruct" \
15 | --vit_path "$VIT_PATH" \
16 | --hlora_path "$HLORA_PATH" \
17 | --question "Your question" \
18 | --img_path "path/to/image.jpg"
19 |
--------------------------------------------------------------------------------
/llava/demo/com_infer_qwen2_5.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | MODEL_NAME_OR_PATH="Qwen/Qwen2.5-32B-Instruct"
4 | VIT_PATH="openai/clip-vit-large-patch14-336/"
5 | HLORA_PATH="com_hlora_weights_QWEN_32B.bin"
6 |
7 | python3 com_infer_qwen2_5.py \
8 | --model_name_or_path "$MODEL_NAME_OR_PATH" \
9 | --dtype "FP16" \
10 | --hlora_r "32" \
11 | --hlora_alpha "64" \
12 | --hlora_nums "4" \
13 | --vq_idx_nums "8192" \
14 | --instruct_template "qwen_2" \
15 | --vit_path "$VIT_PATH" \
16 | --hlora_path "$HLORA_PATH" \
17 | --question "Your question" \
18 | --img_path "path/to/image.jpg"
19 |
--------------------------------------------------------------------------------
/llava/demo/gen_infer.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | MODEL_NAME_OR_PATH="microsoft/Phi-3-mini-4k-instruct"
4 | VIT_PATH="openai/clip-vit-large-patch14-336/"
5 | HLORA_PATH="gen_hlora_weights.bin"
6 | FUSION_LAYER_PATH="fusion_layer_weights.bin"
7 |
8 | python3 gen_infer.py \
9 | --model_name_or_path "$MODEL_NAME_OR_PATH" \
10 | --dtype "FP16" \
11 | --hlora_r "256" \
12 | --hlora_alpha "512" \
13 | --hlora_nums "4" \
14 | --vq_idx_nums "8192" \
15 | --instruct_template "phi3_instruct" \
16 | --vit_path "$VIT_PATH" \
17 | --hlora_path "$HLORA_PATH" \
18 | --fusion_layer_path "$FUSION_LAYER_PATH" \
19 | --question "Reconstruct the image." \
20 | --img_path "path/to/image.jpg" \
21 | --save_path "path/to/save.jpg"
22 |
--------------------------------------------------------------------------------
/llava/demo/utils.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import transformers
3 | import tokenizers
4 | import os, sys
5 | from dataclasses import dataclass, field
6 | import argparse
7 | from PIL import Image
8 |
9 | def expand2square(pil_img, background_color):
10 | width, height = pil_img.size
11 | if width == height:
12 | return pil_img
13 | elif width > height:
14 | result = Image.new(pil_img.mode, (width, width), background_color)
15 | result.paste(pil_img, (0, (width - height) // 2))
16 | return result
17 | else:
18 | result = Image.new(pil_img.mode, (height, height), background_color)
19 | result.paste(pil_img, ((height - width) // 2, 0))
20 | return result
21 |
22 | def find_all_linear_names(model):
23 | cls = torch.nn.Linear
24 | lora_module_names = set()
25 | multimodal_keywords = ['mm_projector', 'vision_tower', 'vision_resampler']
26 | for name, module in model.named_modules():
27 | if any(mm_keyword in name for mm_keyword in multimodal_keywords):
28 | continue
29 | if isinstance(module, cls):
30 | names = name.split('.')
31 | lora_module_names.add(names[0] if len(names) == 1 else names[-1])
32 |
33 | if 'lm_head' in lora_module_names: # needed for 16-bit
34 | lora_module_names.remove('lm_head')
35 | return list(lora_module_names)
36 |
37 | def add_special_tokens_and_resize_model(tokenizer, model, vq_idx_nums=8192):
38 | if len(tokenizer.additional_special_tokens) != 0:
39 | return tokenizer.additional_special_tokens
40 | index_tokens = [f"" for i in range(vq_idx_nums)]
41 | special_tokens = {
42 | 'additional_special_tokens': [''] + index_tokens + [''] + ['']
43 | }
44 | num_new_tokens = tokenizer.add_special_tokens(special_tokens)
45 | model.resize_token_embeddings(len(tokenizer))
46 | if num_new_tokens > 0:
47 | input_embeddings = model.get_input_embeddings().weight.data
48 | output_embeddings = model.get_output_embeddings().weight.data
49 |
50 | input_embeddings_avg = input_embeddings[:-num_new_tokens].mean(
51 | dim=0, keepdim=True)
52 | output_embeddings_avg = output_embeddings[:-num_new_tokens].mean(
53 | dim=0, keepdim=True)
54 |
55 | input_embeddings[-num_new_tokens:] = input_embeddings_avg
56 | output_embeddings[-num_new_tokens:] = output_embeddings_avg
57 |
58 | return num_new_tokens
59 |
60 | com_vision_args = argparse.Namespace(
61 | freeze_backbone=False,
62 | mm_patch_merge_type='flat',
63 | mm_projector_type='mlp2x_gelu',
64 | mm_use_im_patch_token=False,
65 | mm_use_im_start_end=False,
66 | mm_vision_select_feature='patch',
67 | mm_vision_select_layer=-2,
68 | model_name_or_path=None,
69 | pretrain_mm_mlp_adapter=None,
70 | tune_mm_mlp_adapter=False,
71 | version=None,
72 | vision_tower=None
73 | )
74 |
75 | gen_vision_args = argparse.Namespace(
76 | freeze_backbone=False,
77 | mm_patch_merge_type='flat',
78 | mm_projector_type='mlp2x_gelu',
79 | mm_use_im_patch_token=False,
80 | mm_use_im_start_end=False,
81 | mm_vision_select_feature='patch',
82 | mm_vision_select_layer=1,
83 | model_name_or_path=None,
84 | pretrain_mm_mlp_adapter=None,
85 | tune_mm_mlp_adapter=False,
86 | version=None,
87 | vision_tower=None
88 | )
89 |
90 | def load_weights(model, hlora_path, fusion_layer_path=None):
91 | hlora_weights = torch.load(hlora_path)
92 | hlora_unexpected_keys = model.load_state_dict(hlora_weights, strict=False)[1]
93 | if hlora_unexpected_keys:
94 | print(f"Warning: Unexpected keys in hlora checkpoint: {hlora_unexpected_keys}")
95 |
96 | if fusion_layer_path:
97 | fusion_layer_weights = torch.load(fusion_layer_path)
98 | fusion_layer_unexpected_keys = model.load_state_dict(fusion_layer_weights, strict=False)[1]
99 | if fusion_layer_unexpected_keys:
100 | print(f"Warning: Unexpected keys in fusion_layer checkpoint: {fusion_layer_unexpected_keys}")
101 |
102 | return model
103 |
104 |
--------------------------------------------------------------------------------
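A minimal usage sketch for the `expand2square` helper above, not part of the repository: pad a non-square image before handing it to the CLIP image processor. The padding color follows the usual LLaVA convention (the processor's per-channel mean scaled to 0-255) and is an assumption here.

```python
from PIL import Image

from llava.demo.utils import expand2square

img = Image.open("path/to/image.jpg").convert("RGB")

# assumed padding color: CLIP image mean (0.481, 0.458, 0.408) scaled to 0-255
background = (122, 116, 104)
square = expand2square(img, background)

assert square.size[0] == square.size[1] == max(img.size)
square.save("padded.jpg")
```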
/llava/eval/eval_gpt_review.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import json
3 | import os
4 |
5 | import openai
6 | import tqdm
7 | import ray
8 | import time
9 |
10 | NUM_SECONDS_TO_SLEEP = 3
11 |
12 | @ray.remote(num_cpus=4)
13 | def get_eval(content: str, max_tokens: int):
14 | while True:
15 | try:
16 | response = openai.ChatCompletion.create(
17 | model='gpt-4',
18 | messages=[{
19 | 'role': 'system',
20 | 'content': 'You are a helpful and precise assistant for checking the quality of the answer.'
21 | }, {
22 | 'role': 'user',
23 | 'content': content,
24 | }],
25 | temperature=0.2, # TODO: figure out which temperature is best for evaluation
26 | max_tokens=max_tokens,
27 | )
28 | break
29 | except openai.error.RateLimitError:
30 | pass
31 | except Exception as e:
32 | print(e)
33 | time.sleep(NUM_SECONDS_TO_SLEEP)
34 |
35 | print('success!')
36 | return response['choices'][0]['message']['content']
37 |
38 |
39 | def parse_score(review):
40 | try:
41 | score_pair = review.split('\n')[0]
42 | score_pair = score_pair.replace(',', ' ')
43 | sp = score_pair.split(' ')
44 | if len(sp) == 2:
45 | return [float(sp[0]), float(sp[1])]
46 | else:
47 | print('error', review)
48 | return [-1, -1]
49 | except Exception as e:
50 | print(e)
51 | print('error', review)
52 | return [-1, -1]
53 |
54 |
55 | if __name__ == '__main__':
56 | parser = argparse.ArgumentParser(description='ChatGPT-based QA evaluation.')
57 | parser.add_argument('-q', '--question')
58 | # parser.add_argument('-a', '--answer')
59 | parser.add_argument('-a', '--answer-list', nargs='+', default=[])
60 | parser.add_argument('-r', '--rule')
61 | parser.add_argument('-o', '--output')
62 | parser.add_argument('--max-tokens', type=int, default=1024, help='maximum number of tokens produced in the output')
63 | args = parser.parse_args()
64 |
65 | ray.init()
66 |
67 | f_q = open(os.path.expanduser(args.question))
68 | f_ans1 = open(os.path.expanduser(args.answer_list[0]))
69 | f_ans2 = open(os.path.expanduser(args.answer_list[1]))
70 | rule_dict = json.load(open(os.path.expanduser(args.rule), 'r'))
71 |
72 | review_file = open(f'{args.output}', 'w')
73 |
74 | js_list = []
75 | handles = []
76 | idx = 0
77 | for ques_js, ans1_js, ans2_js in zip(f_q, f_ans1, f_ans2):
78 | # if idx == 1:
79 | # break
80 |
81 | ques = json.loads(ques_js)
82 | ans1 = json.loads(ans1_js)
83 | ans2 = json.loads(ans2_js)
84 |
85 | category = json.loads(ques_js)['category']
86 | if category in rule_dict:
87 | rule = rule_dict[category]
88 | else:
89 | rule = rule_dict['default']
90 | prompt = rule['prompt']
91 | role = rule['role']
92 | content = (f'[Question]\n{ques["text"]}\n\n'
93 | f'[{role} 1]\n{ans1["text"]}\n\n[End of {role} 1]\n\n'
94 | f'[{role} 2]\n{ans2["text"]}\n\n[End of {role} 2]\n\n'
95 | f'[System]\n{prompt}\n\n')
96 | js_list.append({
97 | 'id': idx+1,
98 | 'question_id': ques['question_id'],
99 | 'answer1_id': ans1['answer_id'],
100 | 'answer2_id': ans2['answer_id'],
101 | 'category': category})
102 | idx += 1
103 | handles.append(get_eval.remote(content, args.max_tokens))
104 | # To avoid the rate limit set by OpenAI
105 | time.sleep(NUM_SECONDS_TO_SLEEP)
106 |
107 | reviews = ray.get(handles)
108 | for idx, review in enumerate(reviews):
109 | scores = parse_score(review)
110 | js_list[idx]['content'] = review
111 | js_list[idx]['tuple'] = scores
112 | review_file.write(json.dumps(js_list[idx]) + '\n')
113 | review_file.close()
114 |
--------------------------------------------------------------------------------
/llava/eval/eval_gpt_review_visual.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import json
3 | import os
4 |
5 | import openai
6 | import time
7 |
8 | NUM_SECONDS_TO_SLEEP = 0.5
9 |
10 |
11 | def get_eval(content: str, max_tokens: int):
12 | while True:
13 | try:
14 | response = openai.ChatCompletion.create(
15 | model='gpt-4-0314',
16 | messages=[{
17 | 'role': 'system',
18 | 'content': 'You are a helpful and precise assistant for checking the quality of the answer.'
19 | }, {
20 | 'role': 'user',
21 | 'content': content,
22 | }],
23 | temperature=0.2, # TODO: figure out which temperature is best for evaluation
24 | max_tokens=max_tokens,
25 | )
26 | break
27 | except openai.error.RateLimitError:
28 | pass
29 | except Exception as e:
30 | print(e)
31 | time.sleep(NUM_SECONDS_TO_SLEEP)
32 |
33 | return response['choices'][0]['message']['content']
34 |
35 |
36 | def parse_score(review):
37 | try:
38 | score_pair = review.split('\n')[0]
39 | score_pair = score_pair.replace(',', ' ')
40 | sp = score_pair.split(' ')
41 | if len(sp) == 2:
42 | return [float(sp[0]), float(sp[1])]
43 | else:
44 | print('error', review)
45 | return [-1, -1]
46 | except Exception as e:
47 | print(e)
48 | print('error', review)
49 | return [-1, -1]
50 |
51 |
52 | if __name__ == '__main__':
53 | parser = argparse.ArgumentParser(description='ChatGPT-based QA evaluation.')
54 | parser.add_argument('-q', '--question')
55 | parser.add_argument('-c', '--context')
56 | parser.add_argument('-a', '--answer-list', nargs='+', default=[])
57 | parser.add_argument('-r', '--rule')
58 | parser.add_argument('-o', '--output')
59 | parser.add_argument('--max-tokens', type=int, default=1024, help='maximum number of tokens produced in the output')
60 | args = parser.parse_args()
61 |
62 | f_q = open(os.path.expanduser(args.question))
63 | f_ans1 = open(os.path.expanduser(args.answer_list[0]))
64 | f_ans2 = open(os.path.expanduser(args.answer_list[1]))
65 | rule_dict = json.load(open(os.path.expanduser(args.rule), 'r'))
66 |
67 | if os.path.isfile(os.path.expanduser(args.output)):
68 | cur_reviews = [json.loads(line) for line in open(os.path.expanduser(args.output))]
69 | else:
70 | cur_reviews = []
71 |
72 | review_file = open(f'{args.output}', 'a')
73 |
74 | context_list = [json.loads(line) for line in open(os.path.expanduser(args.context))]
75 | image_to_context = {context['image']: context for context in context_list}
76 |
77 | handles = []
78 | idx = 0
79 | for ques_js, ans1_js, ans2_js in zip(f_q, f_ans1, f_ans2):
80 | ques = json.loads(ques_js)
81 | ans1 = json.loads(ans1_js)
82 | ans2 = json.loads(ans2_js)
83 |
84 | inst = image_to_context[ques['image']]
85 | cap_str = '\n'.join(inst['captions'])
86 | box_str = '\n'.join([f'{instance["category"]}: {instance["bbox"]}' for instance in inst['instances']])
87 |
88 | category = json.loads(ques_js)['category']
89 | if category in rule_dict:
90 | rule = rule_dict[category]
91 | else:
92 | assert False, f"Visual QA category not found in rule file: {category}."
93 | prompt = rule['prompt']
94 | role = rule['role']
95 | content = (f'[Context]\n{cap_str}\n\n{box_str}\n\n'
96 | f'[Question]\n{ques["text"]}\n\n'
97 | f'[{role} 1]\n{ans1["text"]}\n\n[End of {role} 1]\n\n'
98 | f'[{role} 2]\n{ans2["text"]}\n\n[End of {role} 2]\n\n'
99 | f'[System]\n{prompt}\n\n')
100 | cur_js = {
101 | 'id': idx+1,
102 | 'question_id': ques['question_id'],
103 | 'answer1_id': ans1.get('answer_id', ans1['question_id']),
104 | 'answer2_id': ans2.get('answer_id', ans2['question_id']),
105 | 'category': category
106 | }
107 | if idx >= len(cur_reviews):
108 | review = get_eval(content, args.max_tokens)
109 | scores = parse_score(review)
110 | cur_js['content'] = review
111 | cur_js['tuple'] = scores
112 | review_file.write(json.dumps(cur_js) + '\n')
113 | review_file.flush()
114 | else:
115 | print(f'Skipping {idx} as we already have it.')
116 | idx += 1
117 | print(idx)
118 | review_file.close()
119 |
--------------------------------------------------------------------------------
/llava/eval/eval_pope.py:
--------------------------------------------------------------------------------
1 | import os
2 | import json
3 | import argparse
4 |
5 | def eval_pope(answers, label_file):
6 | label_list = [json.loads(q)['label'] for q in open(label_file, 'r')]
7 |
8 | for answer in answers:
9 | text = answer['text']
10 |
11 | # Only keep the first sentence
12 | if text.find('.') != -1:
13 | text = text.split('.')[0]
14 |
15 | text = text.replace(',', '')
16 | words = text.split(' ')
17 | if 'No' in words or 'not' in words or 'no' in words:
18 | answer['text'] = 'no'
19 | else:
20 | answer['text'] = 'yes'
21 |
22 | for i in range(len(label_list)):
23 | if label_list[i] == 'no':
24 | label_list[i] = 0
25 | else:
26 | label_list[i] = 1
27 |
28 | pred_list = []
29 | for answer in answers:
30 | if answer['text'] == 'no':
31 | pred_list.append(0)
32 | else:
33 | pred_list.append(1)
34 |
35 | pos = 1
36 | neg = 0
37 | yes_ratio = pred_list.count(1) / len(pred_list)
38 |
39 | TP, TN, FP, FN = 0, 0, 0, 0
40 | for pred, label in zip(pred_list, label_list):
41 | if pred == pos and label == pos:
42 | TP += 1
43 | elif pred == pos and label == neg:
44 | FP += 1
45 | elif pred == neg and label == neg:
46 | TN += 1
47 | elif pred == neg and label == pos:
48 | FN += 1
49 |
50 | print('TP\tFP\tTN\tFN\t')
51 | print('{}\t{}\t{}\t{}'.format(TP, FP, TN, FN))
52 |
53 | precision = float(TP) / float(TP + FP)
54 | recall = float(TP) / float(TP + FN)
55 | f1 = 2*precision*recall / (precision + recall)
56 | acc = (TP + TN) / (TP + TN + FP + FN)
57 | print('Accuracy: {}'.format(acc))
58 | print('Precision: {}'.format(precision))
59 | print('Recall: {}'.format(recall))
60 | print('F1 score: {}'.format(f1))
61 | print('Yes ratio: {}'.format(yes_ratio))
62 | print('%.3f, %.3f, %.3f, %.3f, %.3f' % (f1, acc, precision, recall, yes_ratio) )
63 |
64 | if __name__ == "__main__":
65 | parser = argparse.ArgumentParser()
66 | parser.add_argument("--annotation-dir", type=str)
67 | parser.add_argument("--question-file", type=str)
68 | parser.add_argument("--result-file", type=str)
69 | args = parser.parse_args()
70 |
71 | questions = [json.loads(line) for line in open(args.question_file)]
72 | questions = {question['question_id']: question for question in questions}
73 | answers = [json.loads(q) for q in open(args.result_file)]
74 | for file in os.listdir(args.annotation_dir):
75 | assert file.startswith('coco_pope_')
76 | assert file.endswith('.json')
77 | category = file[10:-5]
78 | cur_answers = [x for x in answers if questions[x['question_id']]['category'] == category]
79 | print('Category: {}, # samples: {}'.format(category, len(cur_answers)))
80 | eval_pope(cur_answers, os.path.join(args.annotation_dir, file))
81 | print("====================================")
82 |
--------------------------------------------------------------------------------
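A tiny, self-contained sketch (not part of the repo) of what `eval_pope` expects: `answers` carry the raw model text, and the label file is JSONL with `yes`/`no` labels. The temporary file here stands in for a real `coco_pope_*.json` annotation file.

```python
import json
import os
import tempfile

from llava.eval.eval_pope import eval_pope

answers = [
    {"text": "Yes, there is a dog in the image."},
    {"text": "No. The image shows only a cat."},
]

labels = [{"label": "yes"}, {"label": "no"}]
with tempfile.NamedTemporaryFile("w", suffix=".json", delete=False) as f:
    f.write("\n".join(json.dumps(x) for x in labels))
    label_file = f.name

eval_pope(answers, label_file)  # prints TP/FP/TN/FN, accuracy, precision, recall, F1, yes ratio
os.unlink(label_file)
```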
/llava/eval/eval_science_qa.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import json
3 | import os
4 | import re
5 | import random
6 |
7 |
8 | def get_args():
9 | parser = argparse.ArgumentParser()
10 | parser.add_argument('--base-dir', type=str)
11 | parser.add_argument('--result-file', type=str)
12 | parser.add_argument('--output-file', type=str)
13 | parser.add_argument('--output-result', type=str)
14 | parser.add_argument('--split', type=str, default='test')
15 | parser.add_argument('--options', type=list, default=["A", "B", "C", "D", "E"])
16 | return parser.parse_args()
17 |
18 |
19 | def convert_caps(results):
20 | fakecaps = []
21 | for result in results:
22 | image_id = result['question_id']
23 | caption = result['text']
24 | fakecaps.append({"image_id": int(image_id), "caption": caption})
25 | return fakecaps
26 |
27 |
28 | def get_pred_idx(prediction, choices, options):
29 | """
30 | Get the index (e.g. 2) from the prediction (e.g. 'C')
31 | """
32 | if prediction in options[:len(choices)]:
33 | return options.index(prediction)
34 | else:
35 | return -1
36 | return random.choice(range(len(choices)))
37 |
38 |
39 | if __name__ == "__main__":
40 | args = get_args()
41 |
42 | base_dir = args.base_dir
43 | split_indices = json.load(open(os.path.join(base_dir, "pid_splits.json")))[args.split]
44 | problems = json.load(open(os.path.join(base_dir, "problems.json")))
45 | predictions = [json.loads(line) for line in open(args.result_file)]
46 | predictions = {pred['question_id']: pred for pred in predictions}
47 | split_problems = {idx: problems[idx] for idx in split_indices}
48 |
49 | results = {'correct': [], 'incorrect': []}
50 | sqa_results = {}
51 | sqa_results['acc'] = None
52 | sqa_results['correct'] = None
53 | sqa_results['count'] = None
54 | sqa_results['results'] = {}
55 | sqa_results['outputs'] = {}
56 |
57 | for prob_id, prob in split_problems.items():
58 | if prob_id not in predictions:
59 | pred = {'text': 'FAILED', 'prompt': 'Unknown'}
60 | pred_text = 'FAILED'
61 | else:
62 | pred = predictions[prob_id]
63 | pred_text = pred['text']
64 |
65 | if pred_text in args.options:
66 | answer = pred_text
67 | elif len(pred_text) >= 3 and pred_text[0] in args.options and pred_text[1:3] == ". ":
68 | answer = pred_text[0]
69 | else:
70 | pattern = re.compile(r'The answer is ([A-Z]).')
71 | res = pattern.findall(pred_text)
72 | if len(res) == 1:
73 | answer = res[0] # 'A', 'B', ...
74 | else:
75 | answer = "FAILED"
76 |
77 | pred_idx = get_pred_idx(answer, prob['choices'], args.options)
78 |
79 | analysis = {
80 | 'question_id': prob_id,
81 | 'parsed_ans': answer,
82 | 'ground_truth': args.options[prob['answer']],
83 | 'question': pred['prompt'],
84 | 'pred': pred_text,
85 | 'is_multimodal': '<image>' in pred['prompt'],
86 | }
87 |
88 | sqa_results['results'][prob_id] = get_pred_idx(answer, prob['choices'], args.options)
89 | sqa_results['outputs'][prob_id] = pred_text
90 |
91 | if pred_idx == prob['answer']:
92 | results['correct'].append(analysis)
93 | else:
94 | results['incorrect'].append(analysis)
95 |
96 | correct = len(results['correct'])
97 | total = len(results['correct']) + len(results['incorrect'])
98 |
99 | ###### IMG ######
100 | multimodal_correct = len([x for x in results['correct'] if x['is_multimodal']])
101 | multimodal_incorrect = len([x for x in results['incorrect'] if x['is_multimodal']])
102 | multimodal_total = multimodal_correct + multimodal_incorrect
103 | ###### IMG ######
104 |
105 | print(f'Total: {total}, Correct: {correct}, Accuracy: {correct / total * 100:.2f}%, IMG-Accuracy: {multimodal_correct / multimodal_total * 100:.2f}%')
106 |
107 | sqa_results['acc'] = correct / total * 100
108 | sqa_results['correct'] = correct
109 | sqa_results['count'] = total
110 |
111 | with open(args.output_file, 'w') as f:
112 | json.dump(results, f, indent=2)
113 | with open(args.output_result, 'w') as f:
114 | json.dump(sqa_results, f, indent=2)
115 |
--------------------------------------------------------------------------------
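For clarity, the answer-extraction path above boils down to a regex plus a letter-to-index lookup. A toy sketch (values invented for illustration):

```python
import re

options = ["A", "B", "C", "D", "E"]
choices = ["mercury", "venus", "earth"]  # toy question with three choices

pred_text = "The answer is B. Venus is the second planet."

res = re.compile(r'The answer is ([A-Z]).').findall(pred_text)
answer = res[0] if len(res) == 1 else "FAILED"

# mirrors get_pred_idx: letters outside the available choices map to -1
pred_idx = options.index(answer) if answer in options[:len(choices)] else -1
print(answer, pred_idx)  # B 1
```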
/llava/eval/eval_science_qa_gpt4.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import json
3 | import os
4 | import re
5 | import random
6 | from collections import defaultdict
7 |
8 |
9 | def get_args():
10 | parser = argparse.ArgumentParser()
11 | parser.add_argument('--base-dir', type=str)
12 | parser.add_argument('--gpt4-result', type=str)
13 | parser.add_argument('--our-result', type=str)
14 | parser.add_argument('--split', type=str, default='test')
15 | parser.add_argument('--options', type=list, default=["A", "B", "C", "D", "E"])
16 | return parser.parse_args()
17 |
18 |
19 | def convert_caps(results):
20 | fakecaps = []
21 | for result in results:
22 | image_id = result['question_id']
23 | caption = result['text']
24 | fakecaps.append({"image_id": int(image_id), "caption": caption})
25 | return fakecaps
26 |
27 |
28 | def get_pred_idx(prediction, choices, options):
29 | """
30 | Get the index (e.g. 2) from the prediction (e.g. 'C')
31 | """
32 | if prediction in options[:len(choices)]:
33 | return options.index(prediction)
34 | else:
35 | return random.choice(range(len(choices)))
36 |
37 |
38 | if __name__ == "__main__":
39 | args = get_args()
40 |
41 | base_dir = args.base_dir
42 | split_indices = json.load(open(os.path.join(base_dir, "pid_splits.json")))[args.split]
43 | problems = json.load(open(os.path.join(base_dir, "problems.json")))
44 | our_predictions = [json.loads(line) for line in open(args.our_result)]
45 | our_predictions = {pred['question_id']: pred for pred in our_predictions}
46 | split_problems = {idx: problems[idx] for idx in split_indices}
47 |
48 | gpt4_predictions = json.load(open(args.gpt4_result))['outputs']
49 |
50 | results = defaultdict(lambda: 0)
51 |
52 | for prob_id, prob in split_problems.items():
53 | if prob_id not in our_predictions:
54 | continue
55 | if prob_id not in gpt4_predictions:
56 | continue
57 | our_pred = our_predictions[prob_id]['text']
58 | gpt4_pred = gpt4_predictions[prob_id]
59 |
60 | pattern = re.compile(r'The answer is ([A-Z]).')
61 | our_res = pattern.findall(our_pred)
62 | if len(our_res) == 1:
63 | our_answer = our_res[0] # 'A', 'B', ...
64 | else:
65 | our_answer = "FAILED"
66 | gpt4_res = pattern.findall(gpt4_pred)
67 | if len(gpt4_res) == 1:
68 | gpt4_answer = gpt4_res[0] # 'A', 'B', ...
69 | else:
70 | gpt4_answer = "FAILED"
71 |
72 | our_pred_idx = get_pred_idx(our_answer, prob['choices'], args.options)
73 | gpt4_pred_idx = get_pred_idx(gpt4_answer, prob['choices'], args.options)
74 |
75 | if gpt4_answer == 'FAILED':
76 | results['gpt4_failed'] += 1
77 | # continue
78 | gpt4_pred_idx = our_pred_idx
79 | # if our_pred_idx != prob['answer']:
80 | # print(our_predictions[prob_id]['prompt'])
81 | # print('-----------------')
82 | # print(f'LECTURE: {prob["lecture"]}')
83 | # print(f'SOLUTION: {prob["solution"]}')
84 | # print('=====================')
85 | else:
86 | # continue
87 | pass
88 | # gpt4_pred_idx = our_pred_idx
89 |
90 | if gpt4_pred_idx == prob['answer']:
91 | results['correct'] += 1
92 | else:
93 | results['incorrect'] += 1
94 |
95 |
96 | if gpt4_pred_idx == prob['answer'] or our_pred_idx == prob['answer']:
97 | results['correct_upperbound'] += 1
98 |
99 | correct = results['correct']
100 | total = results['correct'] + results['incorrect']
101 | print(f'Total: {total}, Correct: {correct}, Accuracy: {correct / total * 100:.2f}%')
102 | print(f'Total: {total}, Correct (upper): {results["correct_upperbound"]}, Accuracy: {results["correct_upperbound"] / total * 100:.2f}%')
103 | print(f'Total: {total}, GPT-4 NO-ANS (RANDOM): {results["gpt4_failed"]}, Percentage: {results["gpt4_failed"] / total * 100:.2f}%')
104 |
105 |
--------------------------------------------------------------------------------
/llava/eval/eval_textvqa.py:
--------------------------------------------------------------------------------
1 | import os
2 | import argparse
3 | import json
4 | import re
5 |
6 | from llava.eval.m4c_evaluator import TextVQAAccuracyEvaluator
7 |
8 |
9 | def get_args():
10 | parser = argparse.ArgumentParser()
11 | parser.add_argument('--annotation-file', type=str)
12 | parser.add_argument('--result-file', type=str)
13 | parser.add_argument('--result-dir', type=str)
14 | return parser.parse_args()
15 |
16 |
17 | def prompt_processor(prompt):
18 | if prompt.startswith('OCR tokens: '):
19 | pattern = r"Question: (.*?) Short answer:"
20 | match = re.search(pattern, prompt, re.DOTALL)
21 | question = match.group(1)
22 | elif 'Reference OCR token: ' in prompt and len(prompt.split('\n')) == 3:
23 | if prompt.startswith('Reference OCR token:'):
24 | question = prompt.split('\n')[1]
25 | else:
26 | question = prompt.split('\n')[0]
27 | elif len(prompt.split('\n')) == 2:
28 | question = prompt.split('\n')[0]
29 | else:
30 | assert False
31 |
32 | return question.lower()
33 |
34 |
35 | def eval_single(annotation_file, result_file):
36 | experiment_name = os.path.splitext(os.path.basename(result_file))[0]
37 | print(experiment_name)
38 | annotations = json.load(open(annotation_file))['data']
39 | annotations = {(annotation['image_id'], annotation['question'].lower()): annotation for annotation in annotations}
40 | results = [json.loads(line) for line in open(result_file)]
41 |
42 | pred_list = []
43 | for result in results:
44 | annotation = annotations[(result['question_id'], prompt_processor(result['prompt']))]
45 | pred_list.append({
46 | "pred_answer": result['text'],
47 | "gt_answers": annotation['answers'],
48 | })
49 |
50 | evaluator = TextVQAAccuracyEvaluator()
51 | print('Samples: {}\nAccuracy: {:.2f}%\n'.format(len(pred_list), 100. * evaluator.eval_pred_list(pred_list)))
52 |
53 |
54 | if __name__ == "__main__":
55 | args = get_args()
56 |
57 | if args.result_file is not None:
58 | eval_single(args.annotation_file, args.result_file)
59 |
60 | if args.result_dir is not None:
61 | for result_file in sorted(os.listdir(args.result_dir)):
62 | if not result_file.endswith('.jsonl'):
63 | print(f'Skipping {result_file}')
64 | continue
65 | eval_single(args.annotation_file, os.path.join(args.result_dir, result_file))
66 |
--------------------------------------------------------------------------------
/llava/eval/generate_webpage_data_from_table.py:
--------------------------------------------------------------------------------
1 | """Generate json file for webpage."""
2 | import json
3 | import os
4 | import re
5 |
6 | # models = ['llama', 'alpaca', 'gpt35', 'bard']
7 | models = ['vicuna']
8 |
9 |
10 | def read_jsonl(path: str, key: str=None):
11 | data = []
12 | with open(os.path.expanduser(path)) as f:
13 | for line in f:
14 | if not line:
15 | continue
16 | data.append(json.loads(line))
17 | if key is not None:
18 | data.sort(key=lambda x: x[key])
19 | data = {item[key]: item for item in data}
20 | return data
21 |
22 |
23 | def trim_hanging_lines(s: str, n: int) -> str:
24 | s = s.strip()
25 | for _ in range(n):
26 | s = s.split('\n', 1)[1].strip()
27 | return s
28 |
29 |
30 | if __name__ == '__main__':
31 | questions = read_jsonl('table/question.jsonl', key='question_id')
32 |
33 | # alpaca_answers = read_jsonl('table/answer/answer_alpaca-13b.jsonl', key='question_id')
34 | # bard_answers = read_jsonl('table/answer/answer_bard.jsonl', key='question_id')
35 | # gpt35_answers = read_jsonl('table/answer/answer_gpt35.jsonl', key='question_id')
36 | # llama_answers = read_jsonl('table/answer/answer_llama-13b.jsonl', key='question_id')
37 | vicuna_answers = read_jsonl('table/answer/answer_vicuna-13b.jsonl', key='question_id')
38 | ours_answers = read_jsonl('table/results/llama-13b-hf-alpaca.jsonl', key='question_id')
39 |
40 | review_vicuna = read_jsonl('table/review/review_vicuna-13b_llama-13b-hf-alpaca.jsonl', key='question_id')
41 | # review_alpaca = read_jsonl('table/review/review_alpaca-13b_vicuna-13b.jsonl', key='question_id')
42 | # review_bard = read_jsonl('table/review/review_bard_vicuna-13b.jsonl', key='question_id')
43 | # review_gpt35 = read_jsonl('table/review/review_gpt35_vicuna-13b.jsonl', key='question_id')
44 | # review_llama = read_jsonl('table/review/review_llama-13b_vicuna-13b.jsonl', key='question_id')
45 |
46 | records = []
47 | for qid in questions.keys():
48 | r = {
49 | 'id': qid,
50 | 'category': questions[qid]['category'],
51 | 'question': questions[qid]['text'],
52 | 'answers': {
53 | # 'alpaca': alpaca_answers[qid]['text'],
54 | # 'llama': llama_answers[qid]['text'],
55 | # 'bard': bard_answers[qid]['text'],
56 | # 'gpt35': gpt35_answers[qid]['text'],
57 | 'vicuna': vicuna_answers[qid]['text'],
58 | 'ours': ours_answers[qid]['text'],
59 | },
60 | 'evaluations': {
61 | # 'alpaca': review_alpaca[qid]['text'],
62 | # 'llama': review_llama[qid]['text'],
63 | # 'bard': review_bard[qid]['text'],
64 | 'vicuna': review_vicuna[qid]['content'],
65 | # 'gpt35': review_gpt35[qid]['text'],
66 | },
67 | 'scores': {
68 | 'vicuna': review_vicuna[qid]['tuple'],
69 | # 'alpaca': review_alpaca[qid]['score'],
70 | # 'llama': review_llama[qid]['score'],
71 | # 'bard': review_bard[qid]['score'],
72 | # 'gpt35': review_gpt35[qid]['score'],
73 | },
74 | }
75 |
76 | # cleanup data
77 | cleaned_evals = {}
78 | for k, v in r['evaluations'].items():
79 | v = v.strip()
80 | lines = v.split('\n')
81 | # trim the first line if it's a pair of numbers
82 | if re.match(r'\d+[, ]+\d+', lines[0]):
83 | lines = lines[1:]
84 | v = '\n'.join(lines)
85 | cleaned_evals[k] = v.replace('Assistant 1', "**Assistant 1**").replace('Assistant 2', '**Assistant 2**')
86 |
87 | r['evaluations'] = cleaned_evals
88 | records.append(r)
89 |
90 | # Reorder the records, this is optional
91 | for r in records:
92 | if r['id'] <= 20:
93 | r['id'] += 60
94 | else:
95 | r['id'] -= 20
96 | for r in records:
97 | if r['id'] <= 50:
98 | r['id'] += 10
99 | elif 50 < r['id'] <= 60:
100 | r['id'] -= 50
101 | for r in records:
102 | if r['id'] == 7:
103 | r['id'] = 1
104 | elif r['id'] < 7:
105 | r['id'] += 1
106 |
107 | records.sort(key=lambda x: x['id'])
108 |
109 | # Write to file
110 | with open('webpage/data.json', 'w') as f:
111 | json.dump({'questions': records, 'models': models}, f, indent=2)
112 |
--------------------------------------------------------------------------------
/llava/eval/model_qa.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | from transformers import AutoTokenizer, AutoModelForCausalLM, StoppingCriteria
3 | import torch
4 | import os
5 | import json
6 | from tqdm import tqdm
7 | import shortuuid
8 |
9 | from llava.conversation import default_conversation
10 | from llava.utils import disable_torch_init
11 |
12 |
13 | @torch.inference_mode()
14 | def eval_model(model_name, questions_file, answers_file):
15 | # Model
16 | disable_torch_init()
17 | model_name = os.path.expanduser(model_name)
18 | tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=False)
19 | model = AutoModelForCausalLM.from_pretrained(model_name,
20 | torch_dtype=torch.float16).cuda()
21 |
22 |
23 | ques_file = open(os.path.expanduser(questions_file), "r")
24 | ans_file = open(os.path.expanduser(answers_file), "w")
25 | for i, line in enumerate(tqdm(ques_file)):
26 | idx = json.loads(line)["question_id"]
27 | qs = json.loads(line)["text"]
28 | cat = json.loads(line)["category"]
29 | conv = default_conversation.copy()
30 | conv.append_message(conv.roles[0], qs)
31 | prompt = conv.get_prompt()
32 | inputs = tokenizer([prompt])
33 | input_ids = torch.as_tensor(inputs.input_ids).cuda()
34 | output_ids = model.generate(
35 | input_ids,
36 | do_sample=True,
37 | use_cache=True,
38 | temperature=0.7,
39 | max_new_tokens=1024,)
40 | outputs = tokenizer.batch_decode(output_ids, skip_special_tokens=True)[0]
41 | try:
42 | index = outputs.index(conv.sep, len(prompt))
43 | except ValueError:
44 | outputs += conv.sep
45 | index = outputs.index(conv.sep, len(prompt))
46 |
47 | outputs = outputs[len(prompt) + len(conv.roles[1]) + 2:index].strip()
48 | ans_id = shortuuid.uuid()
49 | ans_file.write(json.dumps({"question_id": idx,
50 | "text": outputs,
51 | "answer_id": ans_id,
52 | "model_id": model_name,
53 | "metadata": {}}) + "\n")
54 | ans_file.flush()
55 | ans_file.close()
56 |
57 | if __name__ == "__main__":
58 | parser = argparse.ArgumentParser()
59 | parser.add_argument("--model-name", type=str, default="facebook/opt-350m")
60 | parser.add_argument("--question-file", type=str, default="tables/question.jsonl")
61 | parser.add_argument("--answers-file", type=str, default="answer.jsonl")
62 | args = parser.parse_args()
63 |
64 | eval_model(args.model_name, args.question_file, args.answers_file)
65 |
--------------------------------------------------------------------------------
/llava/eval/model_vqa.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import torch
3 | import os
4 | import json
5 | from tqdm import tqdm
6 | import shortuuid
7 |
8 | from llava.constants import IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN
9 | from llava.conversation import conv_templates, SeparatorStyle
10 | from llava.model.builder import load_pretrained_model
11 | from llava.utils import disable_torch_init
12 | from llava.mm_utils import tokenizer_image_token, process_images, get_model_name_from_path
13 |
14 | from PIL import Image
15 | import math
16 |
17 |
18 | def split_list(lst, n):
19 | """Split a list into n (roughly) equal-sized chunks"""
20 | chunk_size = math.ceil(len(lst) / n) # ceiling division
21 | return [lst[i:i+chunk_size] for i in range(0, len(lst), chunk_size)]
22 |
23 |
24 | def get_chunk(lst, n, k):
25 | chunks = split_list(lst, n)
26 | return chunks[k]
27 |
28 |
29 | def eval_model(args):
30 | # Model
31 | disable_torch_init()
32 | model_path = os.path.expanduser(args.model_path)
33 | model_name = get_model_name_from_path(model_path)
34 | tokenizer, model, image_processor, context_len = load_pretrained_model(model_path, args.model_base, model_name)
35 |
36 | questions = [json.loads(q) for q in open(os.path.expanduser(args.question_file), "r")]
37 | questions = get_chunk(questions, args.num_chunks, args.chunk_idx)
38 | answers_file = os.path.expanduser(args.answers_file)
39 | os.makedirs(os.path.dirname(answers_file), exist_ok=True)
40 | ans_file = open(answers_file, "w")
41 | for line in tqdm(questions):
42 | idx = line["question_id"]
43 | image_file = line["image"]
44 | qs = line["text"]
45 | cur_prompt = qs
46 | if model.config.mm_use_im_start_end:
47 | qs = DEFAULT_IM_START_TOKEN + DEFAULT_IMAGE_TOKEN + DEFAULT_IM_END_TOKEN + '\n' + qs
48 | else:
49 | qs = DEFAULT_IMAGE_TOKEN + '\n' + qs
50 |
51 | conv = conv_templates[args.conv_mode].copy()
52 | conv.append_message(conv.roles[0], qs)
53 | conv.append_message(conv.roles[1], None)
54 | prompt = conv.get_prompt()
55 |
56 | input_ids = tokenizer_image_token(prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors='pt').unsqueeze(0).cuda()
57 |
58 | image = Image.open(os.path.join(args.image_folder, image_file)).convert('RGB')
59 | image_tensor = process_images([image], image_processor, model.config)[0]
60 |
61 | with torch.inference_mode():
62 | output_ids = model.generate(
63 | input_ids,
64 | images=image_tensor.unsqueeze(0).half().cuda(),
65 | image_sizes=[image.size],
66 | do_sample=True if args.temperature > 0 else False,
67 | temperature=args.temperature,
68 | top_p=args.top_p,
69 | num_beams=args.num_beams,
70 | # no_repeat_ngram_size=3,
71 | max_new_tokens=1024,
72 | use_cache=True)
73 |
74 | outputs = tokenizer.batch_decode(output_ids, skip_special_tokens=True)[0].strip()
75 |
76 | ans_id = shortuuid.uuid()
77 | ans_file.write(json.dumps({"question_id": idx,
78 | "prompt": cur_prompt,
79 | "text": outputs,
80 | "answer_id": ans_id,
81 | "model_id": model_name,
82 | "metadata": {}}) + "\n")
83 | ans_file.flush()
84 | ans_file.close()
85 |
86 | if __name__ == "__main__":
87 | parser = argparse.ArgumentParser()
88 | parser.add_argument("--model-path", type=str, default="facebook/opt-350m")
89 | parser.add_argument("--model-base", type=str, default=None)
90 | parser.add_argument("--image-folder", type=str, default="")
91 | parser.add_argument("--question-file", type=str, default="tables/question.jsonl")
92 | parser.add_argument("--answers-file", type=str, default="answer.jsonl")
93 | parser.add_argument("--conv-mode", type=str, default="llava_v1")
94 | parser.add_argument("--num-chunks", type=int, default=1)
95 | parser.add_argument("--chunk-idx", type=int, default=0)
96 | parser.add_argument("--temperature", type=float, default=0.2)
97 | parser.add_argument("--top_p", type=float, default=None)
98 | parser.add_argument("--num_beams", type=int, default=1)
99 | args = parser.parse_args()
100 |
101 | eval_model(args)
102 |
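A quick sanity check of the split_list/get_chunk helpers at the top of this file, which implement the --num-chunks/--chunk-idx sharding (editor's sketch, not part of the repository):

import math

def split_list(lst, n):
    chunk_size = math.ceil(len(lst) / n)  # ceiling division
    return [lst[i:i + chunk_size] for i in range(0, len(lst), chunk_size)]

questions = list(range(10))
print(split_list(questions, 3))     # [[0, 1, 2, 3], [4, 5, 6, 7], [8, 9]]
print(split_list(questions, 3)[2])  # [8, 9] -- what get_chunk(questions, 3, 2) returns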
--------------------------------------------------------------------------------
/llava/eval/qa_baseline_gpt35.py:
--------------------------------------------------------------------------------
1 | """Generate answers with GPT-3.5"""
2 | # Note: you need to be using OpenAI Python v0.27.0 for the code below to work
3 | import argparse
4 | import json
5 | import os
6 | import time
7 | import concurrent.futures
8 |
9 | import openai
10 | import tqdm
11 | import shortuuid
12 |
13 | MODEL = 'gpt-3.5-turbo'
14 | MODEL_ID = 'gpt-3.5-turbo:20230327'
15 |
16 | def get_answer(question_id: int, question: str, max_tokens: int):
17 | ans = {
18 | 'answer_id': shortuuid.uuid(),
19 | 'question_id': question_id,
20 | 'model_id': MODEL_ID,
21 | }
22 | for _ in range(3):
23 | try:
24 | response = openai.ChatCompletion.create(
25 | model=MODEL,
26 | messages=[{
27 | 'role': 'system',
28 | 'content': 'You are a helpful assistant.'
29 | }, {
30 | 'role': 'user',
31 | 'content': question,
32 | }],
33 | max_tokens=max_tokens,
34 | )
35 | ans['text'] = response['choices'][0]['message']['content']
36 | return ans
37 | except Exception as e:
38 | print('[ERROR]', e)
39 | ans['text'] = '#ERROR#'
40 | time.sleep(1)
41 | return ans
42 |
43 |
44 | if __name__ == '__main__':
45 | parser = argparse.ArgumentParser(description='ChatGPT answer generation.')
46 | parser.add_argument('-q', '--question')
47 | parser.add_argument('-o', '--output')
48 | parser.add_argument('--max-tokens', type=int, default=1024, help='maximum number of tokens produced in the output')
49 | args = parser.parse_args()
50 |
51 | questions_dict = {}
52 | with open(os.path.expanduser(args.question)) as f:
53 | for line in f:
54 | if not line.strip():
55 | continue
56 | q = json.loads(line)
57 | questions_dict[q['question_id']] = q['text']
58 |
59 | answers = []
60 |
61 | with concurrent.futures.ThreadPoolExecutor(max_workers=32) as executor:
62 | futures = []
63 | for qid, question in questions_dict.items():
64 | future = executor.submit(get_answer, qid, question, args.max_tokens)
65 | futures.append(future)
66 |
67 | for future in tqdm.tqdm(concurrent.futures.as_completed(futures), total=len(futures)):
68 | answers.append(future.result())
69 |
70 | answers.sort(key=lambda x: x['question_id'])
71 |
72 | with open(os.path.expanduser(args.output), 'w') as f:
73 | table = [json.dumps(ans) for ans in answers]
74 | f.write('\n'.join(table))
75 |
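The loop above reads only the "question_id" and "text" fields from each JSONL line. A minimal sketch of writing a compatible question file (the file name and question text are hypothetical):

import json

with open("question.jsonl", "w") as f:
    f.write(json.dumps({"question_id": 1, "text": "Explain what a model delta checkpoint is."}) + "\n")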
--------------------------------------------------------------------------------
/llava/eval/summarize_gpt_review.py:
--------------------------------------------------------------------------------
1 | import json
2 | import os
3 | from collections import defaultdict
4 |
5 | import numpy as np
6 |
7 | import argparse
8 |
9 | def parse_args():
10 | parser = argparse.ArgumentParser(description='ChatGPT-based QA evaluation.')
11 | parser.add_argument('-d', '--dir', default=None)
12 | parser.add_argument('-v', '--version', default=None)
13 | parser.add_argument('-s', '--select', nargs='*', default=None)
14 | parser.add_argument('-f', '--files', nargs='*', default=[])
15 | parser.add_argument('-i', '--ignore', nargs='*', default=[])
16 | return parser.parse_args()
17 |
18 |
19 | if __name__ == '__main__':
20 | args = parse_args()
21 |
22 | if args.ignore is not None:
23 | args.ignore = [int(x) for x in args.ignore]
24 |
25 | if len(args.files) > 0:
26 | review_files = args.files
27 | else:
28 | review_files = [x for x in os.listdir(args.dir) if x.endswith('.jsonl') and (x.startswith('gpt4_text') or x.startswith('reviews_') or x.startswith('review_') or 'review' in args.dir)]
29 |
30 | for review_file in sorted(review_files):
31 | config = os.path.basename(review_file).replace('gpt4_text_', '').replace('.jsonl', '')
32 | if args.select is not None and any(x not in config for x in args.select):
33 | continue
34 | if '0613' in config:
35 | version = '0613'
36 | else:
37 | version = '0314'
38 | if args.version is not None and args.version != version:
39 | continue
40 | scores = defaultdict(list)
41 | print(config)
42 | with open(os.path.join(args.dir, review_file) if args.dir is not None else review_file) as f:
43 | for review_str in f:
44 | review = json.loads(review_str)
45 | if review['question_id'] in args.ignore:
46 | continue
47 | if 'category' in review:
48 | scores[review['category']].append(review['tuple'])
49 | scores['all'].append(review['tuple'])
50 | else:
51 | if 'tuple' in review:
52 | scores['all'].append(review['tuple'])
53 | else:
54 | scores['all'].append(review['score'])
55 | for k, v in sorted(scores.items()):
56 | stats = np.asarray(v).mean(0).tolist()
57 | stats = [round(x, 3) for x in stats]
58 | # print(k, stats, round(stats[1]/stats[0]*100, 1))
59 | print(k, round(stats[1]/stats[0]*100, 1), round(stats[0] * 10, 1), round(stats[1] * 10, 1))
60 | print('=================================')
61 |
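Each review tuple carries two scores; the final print reports the second as a percentage of the first and then both rescaled by 10. A worked example, assuming the first element is the reference model's average score and the second the evaluated model's:

stats = [8.0, 7.2]                          # hypothetical per-category means
print(round(stats[1] / stats[0] * 100, 1),  # 90.0 -> relative score (%)
      round(stats[0] * 10, 1),              # 80.0 -> reference score x10
      round(stats[1] * 10, 1))              # 72.0 -> evaluated score x10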
--------------------------------------------------------------------------------
/llava/eval/table/model.jsonl:
--------------------------------------------------------------------------------
1 | {"model_id": "vicuna-13b:20230322-clean-lang", "model_name": "vicuna-13b", "model_version": "20230322-clean-lang", "model_metadata": "vicuna-13b-20230322-clean-lang"}
2 | {"model_id": "alpaca-13b:v1", "model_name": "alpaca-13b", "model_version": "v1", "model_metadata": "alpaca-13b"}
3 | {"model_id": "llama-13b:v1", "model_name": "llama-13b", "model_version": "v1", "model_metadata": "hf-llama-13b"}
4 | {"model_id": "bard:20230327", "model_name": "bard", "model_version": "20230327", "model_metadata": "Google Bard 20230327"}
5 | {"model_id": "gpt-3.5-turbo:20230327", "model_name": "gpt-3.5-turbo", "model_version": "20230327", "model_metadata": "OpenAI ChatGPT gpt-3.5-turbo Chat Completion"}
6 |
--------------------------------------------------------------------------------
/llava/eval/table/reviewer.jsonl:
--------------------------------------------------------------------------------
1 | {"reviewer_id": "gpt-4-0328-default", "prompt_id": 1, "metadata": {"temperature": 0.2, "max_tokens": 1024}, "description": "GPT-4 for general questions"}
2 | {"reviewer_id": "gpt-4-0328-coding", "prompt_id": 2, "metadata": {"temperature": 0.2, "max_tokens": 1024}, "description": "GPT-4 for coding questions"}
3 | {"reviewer_id": "gpt-4-0328-math", "prompt_id": 3, "metadata": {"temperature": 0.2, "max_tokens": 1024}, "description": "GPT-4 for math questions"}
4 | {"reviewer_id": "gpt-4-0417-visual", "prompt_id": 4, "metadata": {"temperature": 0.2, "max_tokens": 1024}, "description": "GPT-4 for math questions"}
5 |
--------------------------------------------------------------------------------
/llava/eval/webpage/figures/alpaca.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/DCDmllm/HealthGPT/c044a13254b76c5eec8c2e6c55e3324318c27940/llava/eval/webpage/figures/alpaca.png
--------------------------------------------------------------------------------
/llava/eval/webpage/figures/bard.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/DCDmllm/HealthGPT/c044a13254b76c5eec8c2e6c55e3324318c27940/llava/eval/webpage/figures/bard.jpg
--------------------------------------------------------------------------------
/llava/eval/webpage/figures/chatgpt.svg:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/llava/eval/webpage/figures/llama.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/DCDmllm/HealthGPT/c044a13254b76c5eec8c2e6c55e3324318c27940/llava/eval/webpage/figures/llama.jpg
--------------------------------------------------------------------------------
/llava/eval/webpage/figures/swords_FILL0_wght300_GRAD0_opsz48.svg:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/llava/eval/webpage/figures/vicuna.jpeg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/DCDmllm/HealthGPT/c044a13254b76c5eec8c2e6c55e3324318c27940/llava/eval/webpage/figures/vicuna.jpeg
--------------------------------------------------------------------------------
/llava/eval/webpage/styles.css:
--------------------------------------------------------------------------------
1 | body {
2 | font-family: 'Segoe UI', Tahoma, Geneva, Verdana, sans-serif;
3 | background-color: #f8f9fa;
4 | }
5 |
6 | .navbar-dark .navbar-nav .nav-link {
7 | color: #f1cf68;
8 | font-size: 1.1rem;
9 | padding: 0.5rem 0.6rem;
10 | }
11 |
12 | .card-header {
13 | font-weight: bold;
14 | }
15 |
16 | .card {
17 | box-shadow: 0 4px 8px rgba(0, 0, 0, 0.1);
18 | transition: 0.3s;
19 | }
20 |
21 | .card:hover {
22 | box-shadow: 0 8px 16px rgba(0, 0, 0, 0.2);
23 | }
24 |
25 | button {
26 | transition: background-color 0.3s;
27 | }
28 |
29 | button:hover {
30 | background-color: #007bff;
31 | }
32 |
33 | @media (max-width: 767px) {
34 | .form-row .form-group {
35 | margin-bottom: 10px;
36 | }
37 | }
38 |
39 | /* Extra styles */
40 |
41 | .expandable-card .card-text-container {
42 | max-height: 200px;
43 | overflow-y: hidden;
44 | position: relative;
45 | }
46 |
47 | .expandable-card.expanded .card-text-container {
48 | max-height: none;
49 | }
50 |
51 | .expand-btn {
52 | position: relative;
53 | display: none;
54 | background-color: rgba(255, 255, 255, 0.8);
55 | color: #510c75;
56 | border-color: transparent;
57 | }
58 |
59 | .expand-btn:hover {
60 | background-color: rgba(200, 200, 200, 0.8);
61 | text-decoration: none;
62 | border-color: transparent;
63 | color: #510c75;
64 | }
65 |
66 | .expand-btn:focus {
67 | outline: none;
68 | text-decoration: none;
69 | }
70 |
71 | .expandable-card:not(.expanded) .card-text-container:after {
72 | content: "";
73 | position: absolute;
74 | bottom: 0;
75 | left: 0;
76 | width: 100%;
77 | height: 90px;
78 | background: linear-gradient(rgba(255, 255, 255, 0.2), rgba(255, 255, 255, 1));
79 | }
80 |
81 | .expandable-card:not(.expanded) .expand-btn {
82 | margin-top: -40px;
83 | }
84 |
85 | .card-body {
86 | padding-bottom: 5px;
87 | }
88 |
89 | .vertical-flex-layout {
90 | justify-content: center;
91 | align-items: center;
92 | height: 100%;
93 | display: flex;
94 | flex-direction: column;
95 | gap: 5px;
96 | }
97 |
98 | .figure-img {
99 | max-width: 100%;
100 | height: auto;
101 | }
102 |
103 | .adjustable-font-size {
104 | font-size: calc(0.5rem + 2vw);
105 | }
106 |
--------------------------------------------------------------------------------
/llava/model/__init__.py:
--------------------------------------------------------------------------------
1 | try:
2 | from .language_model.llava_llama import LlavaLlamaForCausalLM, LlavaConfig
3 | from .language_model.llava_mpt import LlavaMptForCausalLM, LlavaMptConfig
4 | from .language_model.llava_mistral import LlavaMistralForCausalLM, LlavaMistralConfig
5 | from .language_model.llava_phi3 import LlavaPhiForCausalLM, LlavaPhiConfig
6 | from .language_model.llava_qwen import LlavaQwen2ForCausalLM
7 | except ImportError:  # tolerate missing optional language backends
8 | pass
9 |
--------------------------------------------------------------------------------
/llava/model/__pycache__/__init__.cpython-310.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/DCDmllm/HealthGPT/c044a13254b76c5eec8c2e6c55e3324318c27940/llava/model/__pycache__/__init__.cpython-310.pyc
--------------------------------------------------------------------------------
/llava/model/__pycache__/llava_arch.cpython-310.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/DCDmllm/HealthGPT/c044a13254b76c5eec8c2e6c55e3324318c27940/llava/model/__pycache__/llava_arch.cpython-310.pyc
--------------------------------------------------------------------------------
/llava/model/apply_delta.py:
--------------------------------------------------------------------------------
1 | """
2 | Usage:
3 | python3 -m llava.model.apply_delta --base ~/model_weights/llama-7b --target ~/model_weights/vicuna-7b --delta lmsys/vicuna-7b-delta
4 | """
5 | import argparse
6 |
7 | import torch
8 | from tqdm import tqdm
9 | from transformers import AutoTokenizer, AutoModelForCausalLM
10 | from llava import LlavaLlamaForCausalLM
11 |
12 |
13 | def apply_delta(base_model_path, target_model_path, delta_path):
14 | print("Loading base model")
15 | base = AutoModelForCausalLM.from_pretrained(
16 | base_model_path, torch_dtype=torch.float16, low_cpu_mem_usage=True)
17 |
18 | print("Loading delta")
19 | delta = LlavaLlamaForCausalLM.from_pretrained(delta_path, torch_dtype=torch.float16, low_cpu_mem_usage=True)
20 | delta_tokenizer = AutoTokenizer.from_pretrained(delta_path)
21 |
22 | print("Applying delta")
23 | for name, param in tqdm(delta.state_dict().items(), desc="Applying delta"):
24 | if name not in base.state_dict():
25 | assert name in ['model.mm_projector.weight', 'model.mm_projector.bias'], f'{name} not in base model'
26 | continue
27 | if param.data.shape == base.state_dict()[name].shape:
28 | param.data += base.state_dict()[name]
29 | else:
30 | assert name in ['model.embed_tokens.weight', 'lm_head.weight'], \
31 | f'{name} dimension mismatch: {param.data.shape} vs {base.state_dict()[name].shape}'
32 | bparam = base.state_dict()[name]
33 | param.data[:bparam.shape[0], :bparam.shape[1]] += bparam
34 |
35 | print("Saving target model")
36 | delta.save_pretrained(target_model_path)
37 | delta_tokenizer.save_pretrained(target_model_path)
38 |
39 |
40 | if __name__ == "__main__":
41 | parser = argparse.ArgumentParser()
42 | parser.add_argument("--base-model-path", type=str, required=True)
43 | parser.add_argument("--target-model-path", type=str, required=True)
44 | parser.add_argument("--delta-path", type=str, required=True)
45 |
46 | args = parser.parse_args()
47 |
48 | apply_delta(args.base_model_path, args.target_model_path, args.delta_path)
49 |
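When a weight's shape differs between delta and base (the token embedding and LM head, if the delta model extends the vocabulary), only the overlapping slice receives the base weights. A small illustration of that slice-add with hypothetical tensor sizes:

import torch

base = torch.ones(4, 8)    # base embed_tokens: 4 tokens x 8 dims (hypothetical)
delta = torch.zeros(6, 8)  # delta model added 2 new tokens
delta[:base.shape[0], :base.shape[1]] += base
# rows 0-3 now hold base + delta; rows 4-5 keep the delta model's new-token weights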
--------------------------------------------------------------------------------
/llava/model/consolidate.py:
--------------------------------------------------------------------------------
1 | """
2 | Usage:
3 | python3 -m llava.model.consolidate --src ~/model_weights/llava-7b --dst ~/model_weights/llava-7b_consolidate
4 | """
5 | import argparse
6 |
7 | import torch
8 | from transformers import AutoTokenizer, AutoModelForCausalLM
9 | from llava.model import *
10 | from llava.model.utils import auto_upgrade
11 |
12 |
13 | def consolidate_ckpt(src_path, dst_path):
14 | print("Loading model")
15 | auto_upgrade(src_path)
16 | src_model = AutoModelForCausalLM.from_pretrained(src_path, torch_dtype=torch.float16, low_cpu_mem_usage=True)
17 | src_tokenizer = AutoTokenizer.from_pretrained(src_path, use_fast=False)
18 | src_model.save_pretrained(dst_path)
19 | src_tokenizer.save_pretrained(dst_path)
20 |
21 |
22 | if __name__ == "__main__":
23 | parser = argparse.ArgumentParser()
24 | parser.add_argument("--src", type=str, required=True)
25 | parser.add_argument("--dst", type=str, required=True)
26 |
27 | args = parser.parse_args()
28 |
29 | consolidate_ckpt(args.src, args.dst)
30 |
--------------------------------------------------------------------------------
/llava/model/language_model/__pycache__/llava_llama.cpython-310.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/DCDmllm/HealthGPT/c044a13254b76c5eec8c2e6c55e3324318c27940/llava/model/language_model/__pycache__/llava_llama.cpython-310.pyc
--------------------------------------------------------------------------------
/llava/model/language_model/__pycache__/llava_mistral.cpython-310.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/DCDmllm/HealthGPT/c044a13254b76c5eec8c2e6c55e3324318c27940/llava/model/language_model/__pycache__/llava_mistral.cpython-310.pyc
--------------------------------------------------------------------------------
/llava/model/language_model/__pycache__/llava_mpt.cpython-310.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/DCDmllm/HealthGPT/c044a13254b76c5eec8c2e6c55e3324318c27940/llava/model/language_model/__pycache__/llava_mpt.cpython-310.pyc
--------------------------------------------------------------------------------
/llava/model/language_model/__pycache__/llava_phi3.cpython-310.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/DCDmllm/HealthGPT/c044a13254b76c5eec8c2e6c55e3324318c27940/llava/model/language_model/__pycache__/llava_phi3.cpython-310.pyc
--------------------------------------------------------------------------------
/llava/model/language_model/llava_mpt.py:
--------------------------------------------------------------------------------
1 | # Copyright 2023 Haotian Liu
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 |
16 | from typing import Optional, Tuple
17 |
18 | import torch
19 |
20 | from transformers import AutoConfig, AutoModelForCausalLM, \
21 | MptConfig, MptForCausalLM, MptModel
22 | from llava.model.llava_arch import LlavaMetaModel, LlavaMetaForCausalLM
23 |
24 |
25 | class LlavaMptConfig(MptConfig):
26 | model_type = "llava_mpt"
27 |
28 |
29 | class LlavaMptModel(LlavaMetaModel, MptModel):
30 | config_class = LlavaMptConfig
31 |
32 | def __init__(self, config: MptConfig):
33 | config.hidden_size = config.d_model
34 | super(LlavaMptModel, self).__init__(config)
35 |
36 | def embed_tokens(self, x):
37 | return self.wte(x)
38 |
39 |
40 | class LlavaMptForCausalLM(MptForCausalLM, LlavaMetaForCausalLM):
41 | config_class = LlavaMptConfig
42 | supports_gradient_checkpointing = True
43 |
44 | def __init__(self, config):
45 | super(MptForCausalLM, self).__init__(config)
46 |
47 | self.transformer = LlavaMptModel(config)
48 | self.lm_head = torch.nn.Linear(config.hidden_size, config.vocab_size, bias=False)
49 |
50 | # Initialize weights and apply final processing
51 | self.post_init()
52 |
53 | def get_model(self):
54 | return self.transformer
55 |
56 | def _set_gradient_checkpointing(self, module, value=False):
57 | if isinstance(module, LlavaMptModel):
58 | module.gradient_checkpointing = value
59 |
60 | def forward(
61 | self,
62 | input_ids: Optional[torch.LongTensor] = None,
63 | past_key_values: Optional[Tuple[Tuple[torch.Tensor, torch.Tensor], ...]] = None,
64 | attention_mask: Optional[torch.Tensor] = None,
65 | inputs_embeds: Optional[torch.Tensor] = None,
66 | labels: Optional[torch.Tensor] = None,
67 | use_cache: Optional[bool] = None,
68 | output_attentions: Optional[bool] = None,
69 | output_hidden_states: Optional[bool] = None,
70 | return_dict: Optional[bool] = None,
71 | images=None):
72 |
73 | input_ids, attention_mask, past_key_values, inputs_embeds, labels = self.prepare_inputs_labels_for_multimodal(input_ids, attention_mask, past_key_values, labels, images)
74 |
75 | return super().forward(
76 | input_ids,
77 | past_key_values=past_key_values,
78 | attention_mask=attention_mask,
79 | inputs_embeds=inputs_embeds,
80 | labels=labels,
81 | use_cache=use_cache,
82 | output_attentions=output_attentions,
83 | output_hidden_states=output_hidden_states,
84 | return_dict=return_dict,
85 | )
86 |
87 | def prepare_inputs_for_generation(self, input_ids, past_key_values=None, inputs_embeds=None, **kwargs):
88 | images = kwargs.pop("images", None)
89 | _inputs = super().prepare_inputs_for_generation(
90 | input_ids, past_key_values=past_key_values, inputs_embeds=inputs_embeds, **kwargs
91 | )
92 | _inputs['images'] = images
93 | return _inputs
94 |
95 |
96 | AutoConfig.register("llava_mpt", LlavaMptConfig)
97 | AutoModelForCausalLM.register(LlavaMptConfig, LlavaMptForCausalLM)
98 |
--------------------------------------------------------------------------------
/llava/model/make_delta.py:
--------------------------------------------------------------------------------
1 | """
2 | Usage:
3 | python3 -m llava.model.make_delta --base ~/model_weights/llama-7b --target ~/model_weights/llava-7b --delta ~/model_weights/llava-7b-delta --hub-repo-id liuhaotian/llava-7b-delta
4 | """
5 | import argparse
6 |
7 | import torch
8 | from tqdm import tqdm
9 | from transformers import AutoTokenizer, AutoModelForCausalLM
10 | from llava.model.utils import auto_upgrade
11 |
12 |
13 | def make_delta(base_model_path, target_model_path, delta_path, hub_repo_id):
14 | print("Loading base model")
15 | base = AutoModelForCausalLM.from_pretrained(
16 | base_model_path, torch_dtype=torch.float16, low_cpu_mem_usage=True)
17 |
18 | print("Loading target model")
19 | auto_upgrade(target_model_path)
20 | target = AutoModelForCausalLM.from_pretrained(target_model_path, torch_dtype=torch.float16, low_cpu_mem_usage=True)
21 |
22 | print("Calculating delta")
23 | for name, param in tqdm(target.state_dict().items(), desc="Calculating delta"):
24 | if name not in base.state_dict():
25 | assert name in ['model.mm_projector.weight', 'model.mm_projector.bias'], f'{name} not in base model'
26 | continue
27 | if param.data.shape == base.state_dict()[name].shape:
28 | param.data -= base.state_dict()[name]
29 | else:
30 | assert name in ['model.embed_tokens.weight', 'lm_head.weight'], f'{name} dimension mismatch: {param.data.shape} vs {base.state_dict()[name].shape}'
31 | bparam = base.state_dict()[name]
32 | param.data[:bparam.shape[0], :bparam.shape[1]] -= bparam
33 |
34 | print("Saving delta")
35 | if hub_repo_id:
36 | kwargs = {"push_to_hub": True, "repo_id": hub_repo_id}
37 | else:
38 | kwargs = {}
39 | target.save_pretrained(delta_path, **kwargs)
40 | target_tokenizer = AutoTokenizer.from_pretrained(target_model_path)
41 | target_tokenizer.save_pretrained(delta_path, **kwargs)
42 |
43 |
44 | if __name__ == "__main__":
45 | parser = argparse.ArgumentParser()
46 | parser.add_argument("--base-model-path", type=str, required=True)
47 | parser.add_argument("--target-model-path", type=str, required=True)
48 | parser.add_argument("--delta-path", type=str, required=True)
49 | parser.add_argument("--hub-repo-id", type=str, default=None)
50 | args = parser.parse_args()
51 |
52 | make_delta(args.base_model_path, args.target_model_path, args.delta_path, args.hub_repo_id)
53 |
--------------------------------------------------------------------------------
/llava/model/multimodal_encoder/__pycache__/builder.cpython-310.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/DCDmllm/HealthGPT/c044a13254b76c5eec8c2e6c55e3324318c27940/llava/model/multimodal_encoder/__pycache__/builder.cpython-310.pyc
--------------------------------------------------------------------------------
/llava/model/multimodal_encoder/__pycache__/clip_encoder.cpython-310.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/DCDmllm/HealthGPT/c044a13254b76c5eec8c2e6c55e3324318c27940/llava/model/multimodal_encoder/__pycache__/clip_encoder.cpython-310.pyc
--------------------------------------------------------------------------------
/llava/model/multimodal_encoder/builder.py:
--------------------------------------------------------------------------------
1 | import os
2 | from .clip_encoder import CLIPVisionTower, CLIPVisionTowerS2
3 |
4 |
5 | def build_vision_tower(vision_tower_cfg, **kwargs):
6 | vision_tower = getattr(vision_tower_cfg, 'mm_vision_tower', getattr(vision_tower_cfg, 'vision_tower', None))
7 | is_absolute_path_exists = os.path.exists(vision_tower)
8 | use_s2 = getattr(vision_tower_cfg, 's2', False)
9 | if is_absolute_path_exists or vision_tower.startswith("openai") or vision_tower.startswith("laion") or "ShareGPT4V" in vision_tower:
10 | if use_s2:
11 | return CLIPVisionTowerS2(vision_tower, args=vision_tower_cfg, **kwargs)
12 | else:
13 | return CLIPVisionTower(vision_tower, args=vision_tower_cfg, **kwargs)
14 |
15 | raise ValueError(f'Unknown vision tower: {vision_tower}')
16 |
--------------------------------------------------------------------------------
/llava/model/multimodal_projector/__pycache__/builder.cpython-310.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/DCDmllm/HealthGPT/c044a13254b76c5eec8c2e6c55e3324318c27940/llava/model/multimodal_projector/__pycache__/builder.cpython-310.pyc
--------------------------------------------------------------------------------
/llava/model/multimodal_projector/builder.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import re
4 |
5 |
6 | class IdentityMap(nn.Module):
7 | def __init__(self):
8 | super().__init__()
9 |
10 | def forward(self, x, *args, **kwargs):
11 | return x
12 |
13 | @property
14 | def config(self):
15 | return {"mm_projector_type": 'identity'}
16 |
17 |
18 | class SimpleResBlock(nn.Module):
19 | def __init__(self, channels):
20 | super().__init__()
21 | self.pre_norm = nn.LayerNorm(channels)
22 |
23 | self.proj = nn.Sequential(
24 | nn.Linear(channels, channels),
25 | nn.GELU(),
26 | nn.Linear(channels, channels)
27 | )
28 | def forward(self, x):
29 | x = self.pre_norm(x)
30 | return x + self.proj(x)
31 |
32 |
33 | def build_vision_projector(config, delay_load=False, **kwargs):
34 | projector_type = getattr(config, 'mm_projector_type', 'linear')
35 |
36 | if projector_type == 'linear':
37 | return nn.Linear(config.mm_hidden_size, config.hidden_size)
38 |
39 | mlp_gelu_match = re.match(r'^mlp(\d+)x_gelu$', projector_type)
40 | if mlp_gelu_match:
41 | mlp_depth = int(mlp_gelu_match.group(1))
42 | modules = [nn.Linear(config.mm_hidden_size, config.hidden_size)]
43 | for _ in range(1, mlp_depth):
44 | modules.append(nn.GELU())
45 | modules.append(nn.Linear(config.hidden_size, config.hidden_size))
46 | return nn.Sequential(*modules)
47 |
48 | if projector_type == 'identity':
49 | return IdentityMap()
50 |
51 | raise ValueError(f'Unknown projector type: {projector_type}')
52 |
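For projector_type "mlp2x_gelu", the regex captures a depth of 2 and the loop builds Linear -> GELU -> Linear. A sketch of the resulting module, where the 1024/4096 sizes are hypothetical stand-ins for mm_hidden_size/hidden_size:

import torch.nn as nn

mm_hidden_size, hidden_size = 1024, 4096  # hypothetical config values
projector = nn.Sequential(
    nn.Linear(mm_hidden_size, hidden_size),
    nn.GELU(),
    nn.Linear(hidden_size, hidden_size),
)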
--------------------------------------------------------------------------------
/llava/model/utils.py:
--------------------------------------------------------------------------------
1 | from transformers import AutoConfig
2 |
3 |
4 | def auto_upgrade(config):
5 | cfg = AutoConfig.from_pretrained(config)
6 | if 'llava' in config and 'llava' not in cfg.model_type:
7 | assert cfg.model_type == 'llama'
8 | print("You are using newer LLaVA code base, while the checkpoint of v0 is from older code base.")
9 | print("You must upgrade the checkpoint to the new code base (this can be done automatically).")
10 | confirm = input("Please confirm that you want to upgrade the checkpoint. [Y/N]")
11 | if confirm.lower() in ["y", "yes"]:
12 | print("Upgrading checkpoint...")
13 | assert len(cfg.architectures) == 1
14 | setattr(cfg.__class__, "model_type", "llava")
15 | cfg.architectures[0] = 'LlavaLlamaForCausalLM'
16 | cfg.save_pretrained(config)
17 | print("Checkpoint upgraded.")
18 | else:
19 | print("Checkpoint upgrade aborted.")
20 | exit(1)
21 |
--------------------------------------------------------------------------------
/llava/peft/__init__.py:
--------------------------------------------------------------------------------
1 | # flake8: noqa
2 | # There's no way to ignore "F401 '...' imported but unused" warnings in this
3 | # module while preserving other warnings, so don't check this module at all.
4 |
5 | # coding=utf-8
6 | # Copyright 2023-present the HuggingFace Inc. team.
7 | #
8 | # Licensed under the Apache License, Version 2.0 (the "License");
9 | # you may not use this file except in compliance with the License.
10 | # You may obtain a copy of the License at
11 | #
12 | # http://www.apache.org/licenses/LICENSE-2.0
13 | #
14 | # Unless required by applicable law or agreed to in writing, software
15 | # distributed under the License is distributed on an "AS IS" BASIS,
16 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
17 | # See the License for the specific language governing permissions and
18 | # limitations under the License.
19 |
20 | __version__ = "0.3.0.dev0"
21 |
22 | from .mapping import MODEL_TYPE_TO_PEFT_MODEL_MAPPING, PEFT_TYPE_TO_CONFIG_MAPPING, get_peft_config, get_peft_model
23 | from .peft_model import (
24 | PeftModel,
25 | PeftModelForCausalLM,
26 | PeftModelForSeq2SeqLM,
27 | PeftModelForSequenceClassification,
28 | PeftModelForTokenClassification,
29 | )
30 | from .tuners import (
31 | LoraConfig,
32 | LoraModel,
33 | PrefixEncoder,
34 | PrefixTuningConfig,
35 | PromptEmbedding,
36 | PromptEncoder,
37 | PromptEncoderConfig,
38 | PromptEncoderReparameterizationType,
39 | PromptTuningConfig,
40 | PromptTuningInit,
41 | )
42 | from .utils import (
43 | TRANSFORMERS_MODELS_TO_PREFIX_TUNING_POSTPROCESS_MAPPING,
44 | PeftConfig,
45 | PeftType,
46 | PromptLearningConfig,
47 | TaskType,
48 | bloom_model_postprocess_past_key_value,
49 | get_peft_model_state_dict,
50 | # prepare_model_for_int8_training,
51 | set_peft_model_state_dict,
52 | shift_tokens_right,
53 | )
54 |
--------------------------------------------------------------------------------
/llava/peft/__pycache__/__init__.cpython-310.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/DCDmllm/HealthGPT/c044a13254b76c5eec8c2e6c55e3324318c27940/llava/peft/__pycache__/__init__.cpython-310.pyc
--------------------------------------------------------------------------------
/llava/peft/__pycache__/__init__.cpython-39.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/DCDmllm/HealthGPT/c044a13254b76c5eec8c2e6c55e3324318c27940/llava/peft/__pycache__/__init__.cpython-39.pyc
--------------------------------------------------------------------------------
/llava/peft/__pycache__/mapping.cpython-310.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/DCDmllm/HealthGPT/c044a13254b76c5eec8c2e6c55e3324318c27940/llava/peft/__pycache__/mapping.cpython-310.pyc
--------------------------------------------------------------------------------
/llava/peft/__pycache__/mapping.cpython-39.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/DCDmllm/HealthGPT/c044a13254b76c5eec8c2e6c55e3324318c27940/llava/peft/__pycache__/mapping.cpython-39.pyc
--------------------------------------------------------------------------------
/llava/peft/__pycache__/peft_model.cpython-310.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/DCDmllm/HealthGPT/c044a13254b76c5eec8c2e6c55e3324318c27940/llava/peft/__pycache__/peft_model.cpython-310.pyc
--------------------------------------------------------------------------------
/llava/peft/__pycache__/peft_model.cpython-39.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/DCDmllm/HealthGPT/c044a13254b76c5eec8c2e6c55e3324318c27940/llava/peft/__pycache__/peft_model.cpython-39.pyc
--------------------------------------------------------------------------------
/llava/peft/tuners/__init__.py:
--------------------------------------------------------------------------------
1 | # flake8: noqa
2 | # There's no way to ignore "F401 '...' imported but unused" warnings in this
3 | # module while preserving other warnings, so don't check this module at all.
4 |
5 | # coding=utf-8
6 | # Copyright 2023-present the HuggingFace Inc. team.
7 | #
8 | # Licensed under the Apache License, Version 2.0 (the "License");
9 | # you may not use this file except in compliance with the License.
10 | # You may obtain a copy of the License at
11 | #
12 | # http://www.apache.org/licenses/LICENSE-2.0
13 | #
14 | # Unless required by applicable law or agreed to in writing, software
15 | # distributed under the License is distributed on an "AS IS" BASIS,
16 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
17 | # See the License for the specific language governing permissions and
18 | # limitations under the License.
19 |
20 | from .lora import LoraConfig, LoraModel
21 | from .p_tuning import PromptEncoder, PromptEncoderConfig, PromptEncoderReparameterizationType
22 | from .prefix_tuning import PrefixEncoder, PrefixTuningConfig
23 | from .prompt_tuning import PromptEmbedding, PromptTuningConfig, PromptTuningInit
24 |
--------------------------------------------------------------------------------
/llava/peft/tuners/__pycache__/__init__.cpython-310.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/DCDmllm/HealthGPT/c044a13254b76c5eec8c2e6c55e3324318c27940/llava/peft/tuners/__pycache__/__init__.cpython-310.pyc
--------------------------------------------------------------------------------
/llava/peft/tuners/__pycache__/__init__.cpython-39.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/DCDmllm/HealthGPT/c044a13254b76c5eec8c2e6c55e3324318c27940/llava/peft/tuners/__pycache__/__init__.cpython-39.pyc
--------------------------------------------------------------------------------
/llava/peft/tuners/__pycache__/lora.cpython-310.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/DCDmllm/HealthGPT/c044a13254b76c5eec8c2e6c55e3324318c27940/llava/peft/tuners/__pycache__/lora.cpython-310.pyc
--------------------------------------------------------------------------------
/llava/peft/tuners/__pycache__/lora.cpython-310.pyc.139831930933664:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/DCDmllm/HealthGPT/c044a13254b76c5eec8c2e6c55e3324318c27940/llava/peft/tuners/__pycache__/lora.cpython-310.pyc.139831930933664
--------------------------------------------------------------------------------
/llava/peft/tuners/__pycache__/lora.cpython-310.pyc.140023258402944:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/DCDmllm/HealthGPT/c044a13254b76c5eec8c2e6c55e3324318c27940/llava/peft/tuners/__pycache__/lora.cpython-310.pyc.140023258402944
--------------------------------------------------------------------------------
/llava/peft/tuners/__pycache__/lora.cpython-310.pyc.140104265109632:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/DCDmllm/HealthGPT/c044a13254b76c5eec8c2e6c55e3324318c27940/llava/peft/tuners/__pycache__/lora.cpython-310.pyc.140104265109632
--------------------------------------------------------------------------------
/llava/peft/tuners/__pycache__/lora.cpython-310.pyc.140108281172096:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/DCDmllm/HealthGPT/c044a13254b76c5eec8c2e6c55e3324318c27940/llava/peft/tuners/__pycache__/lora.cpython-310.pyc.140108281172096
--------------------------------------------------------------------------------
/llava/peft/tuners/__pycache__/lora.cpython-310.pyc.140160632344704:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/DCDmllm/HealthGPT/c044a13254b76c5eec8c2e6c55e3324318c27940/llava/peft/tuners/__pycache__/lora.cpython-310.pyc.140160632344704
--------------------------------------------------------------------------------
/llava/peft/tuners/__pycache__/lora.cpython-310.pyc.140480966316160:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/DCDmllm/HealthGPT/c044a13254b76c5eec8c2e6c55e3324318c27940/llava/peft/tuners/__pycache__/lora.cpython-310.pyc.140480966316160
--------------------------------------------------------------------------------
/llava/peft/tuners/__pycache__/lora.cpython-310.pyc.140577961184688:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/DCDmllm/HealthGPT/c044a13254b76c5eec8c2e6c55e3324318c27940/llava/peft/tuners/__pycache__/lora.cpython-310.pyc.140577961184688
--------------------------------------------------------------------------------
/llava/peft/tuners/__pycache__/lora.cpython-39.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/DCDmllm/HealthGPT/c044a13254b76c5eec8c2e6c55e3324318c27940/llava/peft/tuners/__pycache__/lora.cpython-39.pyc
--------------------------------------------------------------------------------
/llava/peft/tuners/__pycache__/lora_moe.cpython-310.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/DCDmllm/HealthGPT/c044a13254b76c5eec8c2e6c55e3324318c27940/llava/peft/tuners/__pycache__/lora_moe.cpython-310.pyc
--------------------------------------------------------------------------------
/llava/peft/tuners/__pycache__/p_tuning.cpython-310.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/DCDmllm/HealthGPT/c044a13254b76c5eec8c2e6c55e3324318c27940/llava/peft/tuners/__pycache__/p_tuning.cpython-310.pyc
--------------------------------------------------------------------------------
/llava/peft/tuners/__pycache__/p_tuning.cpython-39.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/DCDmllm/HealthGPT/c044a13254b76c5eec8c2e6c55e3324318c27940/llava/peft/tuners/__pycache__/p_tuning.cpython-39.pyc
--------------------------------------------------------------------------------
/llava/peft/tuners/__pycache__/prefix_tuning.cpython-310.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/DCDmllm/HealthGPT/c044a13254b76c5eec8c2e6c55e3324318c27940/llava/peft/tuners/__pycache__/prefix_tuning.cpython-310.pyc
--------------------------------------------------------------------------------
/llava/peft/tuners/__pycache__/prefix_tuning.cpython-39.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/DCDmllm/HealthGPT/c044a13254b76c5eec8c2e6c55e3324318c27940/llava/peft/tuners/__pycache__/prefix_tuning.cpython-39.pyc
--------------------------------------------------------------------------------
/llava/peft/tuners/__pycache__/prompt_tuning.cpython-310.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/DCDmllm/HealthGPT/c044a13254b76c5eec8c2e6c55e3324318c27940/llava/peft/tuners/__pycache__/prompt_tuning.cpython-310.pyc
--------------------------------------------------------------------------------
/llava/peft/tuners/__pycache__/prompt_tuning.cpython-39.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/DCDmllm/HealthGPT/c044a13254b76c5eec8c2e6c55e3324318c27940/llava/peft/tuners/__pycache__/prompt_tuning.cpython-39.pyc
--------------------------------------------------------------------------------
/llava/peft/tuners/prefix_tuning.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2023-present the HuggingFace Inc. team.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 |
17 | from dataclasses import dataclass, field
18 |
19 | import torch
20 |
21 | from ..utils import PeftType, PromptLearningConfig
22 |
23 |
24 | @dataclass
25 | class PrefixTuningConfig(PromptLearningConfig):
26 | """
27 | This is the configuration class to store the configuration of a [`~peft.PrefixEncoder`].
28 |
29 | Args:
30 | encoder_hidden_size (`int`): The hidden size of the prompt encoder.
31 | prefix_projection (`bool`): Whether to project the prefix embeddings.
32 | """
33 |
34 | encoder_hidden_size: int = field(
35 | default=None,
36 | metadata={"help": "The hidden size of the encoder"},
37 | )
38 | prefix_projection: bool = field(
39 | default=False,
40 | metadata={"help": "Whether to project the prefix tokens"},
41 | )
42 |
43 | def __post_init__(self):
44 | self.peft_type = PeftType.PREFIX_TUNING
45 |
46 |
47 | # Based on https://github.com/THUDM/P-tuning-v2/blob/main/model/prefix_encoder.py
48 | # with some refactor
49 | class PrefixEncoder(torch.nn.Module):
50 | r"""
51 | The torch.nn model to encode the prefix
52 |
53 | Args:
54 | config ([`PrefixTuningConfig`]): The configuration of the prefix encoder.
55 |
56 | Example::
57 |
58 | >>> from peft import PrefixEncoder, PrefixTuningConfig
59 | >>> config = PrefixTuningConfig(
60 | ...     peft_type="PREFIX_TUNING", task_type="SEQ_2_SEQ_LM", num_virtual_tokens=20, token_dim=768,
61 | ...     num_transformer_submodules=1, num_attention_heads=12, num_layers=12, encoder_hidden_size=768)
62 | >>> prefix_encoder = PrefixEncoder(config)
63 |
64 |
65 | **Attributes**:
66 | - **embedding** (`torch.nn.Embedding`) --
67 | The embedding layer of the prefix encoder.
68 | - **transform** (`torch.nn.Sequential`) -- The
69 | two-layer MLP to transform the prefix embeddings if `prefix_projection` is `True`.
70 | - **prefix_projection** (`bool`) -- Whether to project the prefix embeddings.
71 |
72 | Input shape: (batch_size, num_virtual_tokens)
73 |
74 | Output shape: (batch_size, num_virtual_tokens, 2*layers*hidden)
75 | """
76 |
77 | def __init__(self, config):
78 | super().__init__()
79 | self.prefix_projection = config.prefix_projection
80 | token_dim = config.token_dim
81 | num_layers = config.num_layers
82 | encoder_hidden_size = config.encoder_hidden_size
83 | num_virtual_tokens = config.num_virtual_tokens
84 | if self.prefix_projection and not config.inference_mode:
85 | # Use a two-layer MLP to encode the prefix
86 | self.embedding = torch.nn.Embedding(num_virtual_tokens, token_dim)
87 | self.transform = torch.nn.Sequential(
88 | torch.nn.Linear(token_dim, encoder_hidden_size),
89 | torch.nn.Tanh(),
90 | torch.nn.Linear(encoder_hidden_size, num_layers * 2 * token_dim),
91 | )
92 | else:
93 | self.embedding = torch.nn.Embedding(num_virtual_tokens, num_layers * 2 * token_dim)
94 |
95 | def forward(self, prefix: torch.Tensor):
96 | if self.prefix_projection:
97 | prefix_tokens = self.embedding(prefix)
98 | past_key_values = self.transform(prefix_tokens)
99 | else:
100 | past_key_values = self.embedding(prefix)
101 | return past_key_values
102 |
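With the docstring's example configuration (12 layers, token_dim 768, 20 virtual tokens), the encoder maps prefix indices of shape (batch_size, 20) to past-key-value tensors of shape (batch_size, 20, 2 * 12 * 768). A rough check under those assumptions:

num_layers, token_dim, num_virtual_tokens = 12, 768, 20
print(2 * num_layers * token_dim)  # 18432 -- last dimension of the encoder output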
--------------------------------------------------------------------------------
/llava/peft/utils/__init__.py:
--------------------------------------------------------------------------------
1 | # flake8: noqa
2 | # There's no way to ignore "F401 '...' imported but unused" warnings in this
3 | # module while preserving other warnings, so don't check this module at all.
4 |
5 | # coding=utf-8
6 | # Copyright 2023-present the HuggingFace Inc. team.
7 | #
8 | # Licensed under the Apache License, Version 2.0 (the "License");
9 | # you may not use this file except in compliance with the License.
10 | # You may obtain a copy of the License at
11 | #
12 | # http://www.apache.org/licenses/LICENSE-2.0
13 | #
14 | # Unless required by applicable law or agreed to in writing, software
15 | # distributed under the License is distributed on an "AS IS" BASIS,
16 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
17 | # See the License for the specific language governing permissions and
18 | # limitations under the License.
19 |
20 | from .adapters_utils import CONFIG_NAME, WEIGHTS_NAME
21 | from .config import PeftConfig, PeftType, PromptLearningConfig, TaskType
22 | from .other import (
23 | TRANSFORMERS_MODELS_TO_PREFIX_TUNING_POSTPROCESS_MAPPING,
24 | _set_trainable,
25 | bloom_model_postprocess_past_key_value,
26 | # prepare_model_for_int8_training,
27 | shift_tokens_right,
28 | transpose,
29 | )
30 | from .save_and_load import get_peft_model_state_dict, set_peft_model_state_dict
31 |
--------------------------------------------------------------------------------
/llava/peft/utils/__pycache__/__init__.cpython-310.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/DCDmllm/HealthGPT/c044a13254b76c5eec8c2e6c55e3324318c27940/llava/peft/utils/__pycache__/__init__.cpython-310.pyc
--------------------------------------------------------------------------------
/llava/peft/utils/__pycache__/__init__.cpython-39.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/DCDmllm/HealthGPT/c044a13254b76c5eec8c2e6c55e3324318c27940/llava/peft/utils/__pycache__/__init__.cpython-39.pyc
--------------------------------------------------------------------------------
/llava/peft/utils/__pycache__/adapters_utils.cpython-310.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/DCDmllm/HealthGPT/c044a13254b76c5eec8c2e6c55e3324318c27940/llava/peft/utils/__pycache__/adapters_utils.cpython-310.pyc
--------------------------------------------------------------------------------
/llava/peft/utils/__pycache__/adapters_utils.cpython-39.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/DCDmllm/HealthGPT/c044a13254b76c5eec8c2e6c55e3324318c27940/llava/peft/utils/__pycache__/adapters_utils.cpython-39.pyc
--------------------------------------------------------------------------------
/llava/peft/utils/__pycache__/config.cpython-310.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/DCDmllm/HealthGPT/c044a13254b76c5eec8c2e6c55e3324318c27940/llava/peft/utils/__pycache__/config.cpython-310.pyc
--------------------------------------------------------------------------------
/llava/peft/utils/__pycache__/config.cpython-39.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/DCDmllm/HealthGPT/c044a13254b76c5eec8c2e6c55e3324318c27940/llava/peft/utils/__pycache__/config.cpython-39.pyc
--------------------------------------------------------------------------------
/llava/peft/utils/__pycache__/other.cpython-310.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/DCDmllm/HealthGPT/c044a13254b76c5eec8c2e6c55e3324318c27940/llava/peft/utils/__pycache__/other.cpython-310.pyc
--------------------------------------------------------------------------------
/llava/peft/utils/__pycache__/other.cpython-39.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/DCDmllm/HealthGPT/c044a13254b76c5eec8c2e6c55e3324318c27940/llava/peft/utils/__pycache__/other.cpython-39.pyc
--------------------------------------------------------------------------------
/llava/peft/utils/__pycache__/save_and_load.cpython-310.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/DCDmllm/HealthGPT/c044a13254b76c5eec8c2e6c55e3324318c27940/llava/peft/utils/__pycache__/save_and_load.cpython-310.pyc
--------------------------------------------------------------------------------
/llava/peft/utils/__pycache__/save_and_load.cpython-39.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/DCDmllm/HealthGPT/c044a13254b76c5eec8c2e6c55e3324318c27940/llava/peft/utils/__pycache__/save_and_load.cpython-39.pyc
--------------------------------------------------------------------------------
/llava/peft/utils/adapters_utils.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2023-present the HuggingFace Inc. team.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | WEIGHTS_NAME = "adapter_model.bin"
16 | CONFIG_NAME = "adapter_config.json"
17 |
18 | # TODO: add automapping and superclass here?
19 |
--------------------------------------------------------------------------------
/llava/peft/utils/other.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2023-present the HuggingFace Inc. team.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | import torch
17 |
18 |
19 | # needed for prefix-tuning of bloom model
20 | def bloom_model_postprocess_past_key_value(past_key_values):
21 | past_key_values = torch.cat(past_key_values)
22 | total_layers, batch_size, num_attention_heads, num_virtual_tokens, head_dim = past_key_values.shape
23 | keys = past_key_values[: total_layers // 2]
24 | keys = keys.transpose(2, 3).reshape(
25 | total_layers // 2, batch_size * num_attention_heads, head_dim, num_virtual_tokens
26 | )
27 | values = past_key_values[total_layers // 2 :]
28 | values = values.reshape(total_layers // 2, batch_size * num_attention_heads, num_virtual_tokens, head_dim)
29 |
30 | return tuple(zip(keys, values))
31 |
32 |
33 | TRANSFORMERS_MODELS_TO_PREFIX_TUNING_POSTPROCESS_MAPPING = {
34 | "bloom": bloom_model_postprocess_past_key_value,
35 | }
36 |
37 |
38 | # copied from transformers.models.bart.modeling_bart
39 | def shift_tokens_right(input_ids: torch.Tensor, pad_token_id: int, decoder_start_token_id: int):
40 | """
41 | Shift input ids one token to the right.
42 |
43 | Args:
44 | input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): input ids
45 | pad_token_id (`int`): The id of the `padding` token.
46 | decoder_start_token_id (`int`): The id of the `start` token.
47 | """
48 | shifted_input_ids = input_ids.new_zeros(input_ids.shape)
49 | shifted_input_ids[:, 1:] = input_ids[:, :-1].clone()
50 | shifted_input_ids[:, 0] = decoder_start_token_id
51 |
52 | if pad_token_id is None:
53 | raise ValueError("self.model.config.pad_token_id has to be defined.")
54 | # replace possible -100 values in labels by `pad_token_id`
55 | shifted_input_ids.masked_fill_(shifted_input_ids == -100, pad_token_id)
56 |
57 | return shifted_input_ids
58 |
59 |
60 | def _set_trainable(model):
61 | if model.modules_to_save is not None:
62 | for name, param in model.named_parameters():
63 | if any(module_name in name for module_name in model.modules_to_save):
64 | param.requires_grad = True
65 |
66 |
67 | def fsdp_auto_wrap_policy(model):
68 | import functools
69 | import os
70 |
71 | from accelerate import FullyShardedDataParallelPlugin
72 | from torch.distributed.fsdp.wrap import _or_policy, lambda_auto_wrap_policy, transformer_auto_wrap_policy
73 |
74 | from ..tuners import PrefixEncoder, PromptEmbedding, PromptEncoder
75 |
76 | def lambda_policy_fn(module):
77 | if (
78 | len(list(module.named_children())) == 0
79 | and getattr(module, "weight", None) is not None
80 | and module.weight.requires_grad
81 | ):
82 | return True
83 | return False
84 |
85 | lambda_policy = functools.partial(lambda_auto_wrap_policy, lambda_fn=lambda_policy_fn)
86 | transformer_wrap_policy = functools.partial(
87 | transformer_auto_wrap_policy,
88 | transformer_layer_cls=(
89 | PrefixEncoder,
90 | PromptEncoder,
91 | PromptEmbedding,
92 | FullyShardedDataParallelPlugin.get_module_class_from_name(
93 | model, os.environ.get("FSDP_TRANSFORMER_CLS_TO_WRAP", "")
94 | ),
95 | ),
96 | )
97 |
98 | auto_wrap_policy = functools.partial(_or_policy, policies=[lambda_policy, transformer_wrap_policy])
99 | return auto_wrap_policy
100 |
101 |
102 | def transpose(weight, fan_in_fan_out):
103 | return weight.T if fan_in_fan_out else weight
104 |
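A worked example of shift_tokens_right with hypothetical token ids (decoder start id 2, pad id 0): the sequence is shifted right by one position, the start token is prepended, and any -100 label padding becomes the pad id.

import torch

input_ids = torch.tensor([[5, 6, -100, -100]])
shifted = input_ids.new_zeros(input_ids.shape)
shifted[:, 1:] = input_ids[:, :-1].clone()
shifted[:, 0] = 2                         # hypothetical decoder_start_token_id
shifted.masked_fill_(shifted == -100, 0)  # hypothetical pad_token_id
print(shifted)                            # tensor([[2, 5, 6, 0]])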
--------------------------------------------------------------------------------
/llava/peft/utils/save_and_load.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2023-present the HuggingFace Inc. team.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | from .config import PeftType
17 |
18 |
19 | def get_peft_model_state_dict(model, state_dict=None):
20 | """
21 | Get the state dict of the Peft model.
22 |
23 | Args:
24 | model ([`PeftModel`]): The Peft model. When using torch.nn.DistributedDataParallel, DeepSpeed or FSDP,
25 | the model should be the underlying model/unwrapped model (i.e. model.module).
26 | state_dict (`dict`, *optional*, defaults to `None`):
27 |             The state dict of the model. If not provided, it is obtained by
28 |             calling `model.state_dict()`.
29 | """
30 | if state_dict is None:
31 | state_dict = model.state_dict()
32 | if model.peft_config.peft_type == PeftType.LORA:
33 | # to_return = lora_state_dict(model, bias=model.peft_config.bias)
34 | # adapted from `https://github.com/microsoft/LoRA/blob/main/loralib/utils.py`
35 |         # to work directly with the state dict, which is necessary when using DeepSpeed or FSDP
36 | bias = model.peft_config.bias
37 | if bias == "none":
38 | to_return = {k: state_dict[k] for k in state_dict if "lora_" in k}
39 | elif bias == "all":
40 | to_return = {k: state_dict[k] for k in state_dict if "lora_" in k or "bias" in k}
41 | elif bias == "lora_only":
42 | to_return = {}
43 | for k in state_dict:
44 | if "lora_" in k:
45 | to_return[k] = state_dict[k]
46 | bias_name = k.split("lora_")[0] + "bias"
47 | if bias_name in state_dict:
48 | to_return[bias_name] = state_dict[bias_name]
49 | else:
50 | raise NotImplementedError
51 | else:
52 | to_return = {}
53 | if model.peft_config.inference_mode:
54 | prompt_embeddings = model.prompt_encoder.embedding.weight
55 | else:
56 | prompt_embeddings = model.get_prompt_embedding_to_save()
57 | to_return["prompt_embeddings"] = prompt_embeddings
58 | if model.modules_to_save is not None:
59 | for key, value in state_dict.items():
60 | if any(module_name in key for module_name in model.modules_to_save):
61 | to_return[key] = value
62 | return to_return
63 |
64 |
65 | def set_peft_model_state_dict(model, peft_model_state_dict):
66 | """
67 | Set the state dict of the Peft model.
68 |
69 | Args:
70 | model ([`PeftModel`]): The Peft model.
71 | peft_model_state_dict (`dict`): The state dict of the Peft model.
72 | """
73 |
74 | for name, param in model.named_parameters():
75 | if name in peft_model_state_dict.keys():
76 |             print(f"Loading LoRA weight from lora_path: {name} ...")
77 |
78 | model.load_state_dict(peft_model_state_dict, strict=False)
79 | return model
80 |
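A minimal save/restore sketch for the two helpers above (not part of the repository), assuming `model` is a LoRA-configured `PeftModel` built with this package; the output path is a placeholder.

import torch

# Save only the adapter weights (plus bias/`modules_to_save` entries per the config),
# not the full backbone state dict.
adapter_state = get_peft_model_state_dict(model)
torch.save(adapter_state, "adapter_weights.bin")  # placeholder path

# Later: rebuild an identically configured PeftModel and restore the adapter weights.
adapter_state = torch.load("adapter_weights.bin", map_location="cpu")
set_peft_model_state_dict(model, adapter_state)  # loads with strict=False internally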
--------------------------------------------------------------------------------
/llava/serve/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/DCDmllm/HealthGPT/c044a13254b76c5eec8c2e6c55e3324318c27940/llava/serve/__init__.py
--------------------------------------------------------------------------------
/llava/serve/examples/extreme_ironing.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/DCDmllm/HealthGPT/c044a13254b76c5eec8c2e6c55e3324318c27940/llava/serve/examples/extreme_ironing.jpg
--------------------------------------------------------------------------------
/llava/serve/examples/waterview.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/DCDmllm/HealthGPT/c044a13254b76c5eec8c2e6c55e3324318c27940/llava/serve/examples/waterview.jpg
--------------------------------------------------------------------------------
/llava/serve/register_worker.py:
--------------------------------------------------------------------------------
1 | """
2 | Manually register workers.
3 |
4 | Usage:
5 | python3 -m llava.serve.register_worker --controller-address http://localhost:21001 --worker-name http://localhost:21002
6 | """
7 |
8 | import argparse
9 |
10 | import requests
11 |
12 | if __name__ == "__main__":
13 | parser = argparse.ArgumentParser()
14 | parser.add_argument("--controller-address", type=str)
15 | parser.add_argument("--worker-name", type=str)
16 | parser.add_argument("--check-heart-beat", action="store_true")
17 | args = parser.parse_args()
18 |
19 | url = args.controller_address + "/register_worker"
20 | data = {
21 | "worker_name": args.worker_name,
22 | "check_heart_beat": args.check_heart_beat,
23 | "worker_status": None,
24 | }
25 | r = requests.post(url, json=data)
26 | assert r.status_code == 200
27 |
--------------------------------------------------------------------------------
/llava/serve/test_message.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import json
3 |
4 | import requests
5 |
6 | from llava.conversation import default_conversation
7 |
8 |
9 | def main():
10 | if args.worker_address:
11 | worker_addr = args.worker_address
12 | else:
13 | controller_addr = args.controller_address
14 | ret = requests.post(controller_addr + "/refresh_all_workers")
15 | ret = requests.post(controller_addr + "/list_models")
16 | models = ret.json()["models"]
17 | models.sort()
18 | print(f"Models: {models}")
19 |
20 | ret = requests.post(controller_addr + "/get_worker_address",
21 | json={"model": args.model_name})
22 | worker_addr = ret.json()["address"]
23 | print(f"worker_addr: {worker_addr}")
24 |
25 | if worker_addr == "":
26 | return
27 |
28 | conv = default_conversation.copy()
29 | conv.append_message(conv.roles[0], args.message)
30 | prompt = conv.get_prompt()
31 |
32 | headers = {"User-Agent": "LLaVA Client"}
33 | pload = {
34 | "model": args.model_name,
35 | "prompt": prompt,
36 | "max_new_tokens": args.max_new_tokens,
37 | "temperature": 0.7,
38 | "stop": conv.sep,
39 | }
40 | response = requests.post(worker_addr + "/worker_generate_stream", headers=headers,
41 | json=pload, stream=True)
42 |
43 | print(prompt.replace(conv.sep, "\n"), end="")
44 | for chunk in response.iter_lines(chunk_size=8192, decode_unicode=False, delimiter=b"\0"):
45 | if chunk:
46 | data = json.loads(chunk.decode("utf-8"))
47 | output = data["text"].split(conv.sep)[-1]
48 | print(output, end="\r")
49 | print("")
50 |
51 |
52 | if __name__ == "__main__":
53 | parser = argparse.ArgumentParser()
54 | parser.add_argument("--controller-address", type=str, default="http://localhost:21001")
55 | parser.add_argument("--worker-address", type=str)
56 | parser.add_argument("--model-name", type=str, default="facebook/opt-350m")
57 | parser.add_argument("--max-new-tokens", type=int, default=32)
58 | parser.add_argument("--message", type=str, default=
59 | "Tell me a story with more than 1000 words.")
60 | args = parser.parse_args()
61 |
62 | main()
63 |
--------------------------------------------------------------------------------
/llava/utils.py:
--------------------------------------------------------------------------------
1 | import datetime
2 | import logging
3 | import logging.handlers
4 | import os
5 | import sys
6 |
7 | import requests
8 |
9 | from llava.constants import LOGDIR
10 |
11 | server_error_msg = "**NETWORK ERROR DUE TO HIGH TRAFFIC. PLEASE REGENERATE OR REFRESH THIS PAGE.**"
12 | moderation_msg = "YOUR INPUT VIOLATES OUR CONTENT MODERATION GUIDELINES. PLEASE TRY AGAIN."
13 |
14 | handler = None
15 |
16 |
17 | def build_logger(logger_name, logger_filename):
18 | global handler
19 |
20 | formatter = logging.Formatter(
21 | fmt="%(asctime)s | %(levelname)s | %(name)s | %(message)s",
22 | datefmt="%Y-%m-%d %H:%M:%S",
23 | )
24 |
25 | # Set the format of root handlers
26 | if not logging.getLogger().handlers:
27 | logging.basicConfig(level=logging.INFO)
28 | logging.getLogger().handlers[0].setFormatter(formatter)
29 |
30 | # Redirect stdout and stderr to loggers
31 | stdout_logger = logging.getLogger("stdout")
32 | stdout_logger.setLevel(logging.INFO)
33 | sl = StreamToLogger(stdout_logger, logging.INFO)
34 | sys.stdout = sl
35 |
36 | stderr_logger = logging.getLogger("stderr")
37 | stderr_logger.setLevel(logging.ERROR)
38 | sl = StreamToLogger(stderr_logger, logging.ERROR)
39 | sys.stderr = sl
40 |
41 | # Get logger
42 | logger = logging.getLogger(logger_name)
43 | logger.setLevel(logging.INFO)
44 |
45 | # Add a file handler for all loggers
46 | if handler is None:
47 | os.makedirs(LOGDIR, exist_ok=True)
48 | filename = os.path.join(LOGDIR, logger_filename)
49 | handler = logging.handlers.TimedRotatingFileHandler(
50 | filename, when='D', utc=True, encoding='UTF-8')
51 | handler.setFormatter(formatter)
52 |
53 | for name, item in logging.root.manager.loggerDict.items():
54 | if isinstance(item, logging.Logger):
55 | item.addHandler(handler)
56 |
57 | return logger
58 |
59 |
60 | class StreamToLogger(object):
61 | """
62 | Fake file-like stream object that redirects writes to a logger instance.
63 | """
64 | def __init__(self, logger, log_level=logging.INFO):
65 | self.terminal = sys.stdout
66 | self.logger = logger
67 | self.log_level = log_level
68 | self.linebuf = ''
69 |
70 | def __getattr__(self, attr):
71 | return getattr(self.terminal, attr)
72 |
73 | def write(self, buf):
74 | temp_linebuf = self.linebuf + buf
75 | self.linebuf = ''
76 | for line in temp_linebuf.splitlines(True):
77 | # From the io.TextIOWrapper docs:
78 | # On output, if newline is None, any '\n' characters written
79 | # are translated to the system default line separator.
80 | # By default sys.stdout.write() expects '\n' newlines and then
81 | # translates them so this is still cross platform.
82 | if line[-1] == '\n':
83 | self.logger.log(self.log_level, line.rstrip())
84 | else:
85 | self.linebuf += line
86 |
87 | def flush(self):
88 | if self.linebuf != '':
89 | self.logger.log(self.log_level, self.linebuf.rstrip())
90 | self.linebuf = ''
91 |
92 |
93 | def disable_torch_init():
94 | """
95 | Disable the redundant torch default initialization to accelerate model creation.
96 | """
97 | import torch
98 | setattr(torch.nn.Linear, "reset_parameters", lambda self: None)
99 | setattr(torch.nn.LayerNorm, "reset_parameters", lambda self: None)
100 |
101 |
102 | def violates_moderation(text):
103 | """
104 | Check whether the text violates OpenAI moderation API.
105 | """
106 | url = "https://api.openai.com/v1/moderations"
107 | headers = {"Content-Type": "application/json",
108 | "Authorization": "Bearer " + os.environ["OPENAI_API_KEY"]}
109 | text = text.replace("\n", "")
110 |     # Let `requests` serialize the payload so quotes/backslashes in `text` are escaped correctly.
111 |     payload = {"input": text}
112 |     try:
113 |         ret = requests.post(url, headers=headers, json=payload, timeout=5)
114 | flagged = ret.json()["results"][0]["flagged"]
115 | except requests.exceptions.RequestException as e:
116 | flagged = False
117 | except KeyError as e:
118 | flagged = False
119 |
120 | return flagged
121 |
122 |
123 | def pretty_print_semaphore(semaphore):
124 | if semaphore is None:
125 | return "None"
126 | return f"Semaphore(value={semaphore._value}, locked={semaphore.locked()})"
127 |
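A small usage sketch for the helpers above (not part of the repository); the logger and file names are placeholders, and it assumes `LOGDIR` points at a writable directory.

from llava.utils import build_logger, disable_torch_init

# Route stdout/stderr through the logging setup and write to a rotating file under LOGDIR.
logger = build_logger("model_worker", "model_worker.log")  # placeholder names
logger.info("worker starting up")

# Skip torch's default Linear/LayerNorm re-initialization before loading pretrained weights.
disable_torch_init()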
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | torch==2.0.1
2 | accelerate==0.27.0
3 | bitsandbytes==0.41.0
4 | deepspeed==0.9.5
5 | einops-exts==0.0.4
6 | einops==0.6.1
7 | gradio==3.35.2
8 | gradio_client==0.2.9
9 | httpx==0.24.0
10 | markdown2==2.4.10
11 | numpy==1.26.0
12 | peft==0.4.0
13 | scikit-learn==1.2.2
14 | sentencepiece==0.1.99
15 | shortuuid==1.0.11
16 | timm==0.6.13
17 | torchvision==0.15.2
18 | transformers==4.41
19 | wandb==0.15.12
20 | wavedrom==2.0.3.post3
21 | Pygments==2.16.1
22 | omegaconf
23 | pytorch_lightning==2.1
24 | scikit-image
25 | opencv-python
26 | lpips
--------------------------------------------------------------------------------
/requirements_qwen_2_5.txt:
--------------------------------------------------------------------------------
1 | torch==2.4.0
2 | accelerate==0.27.0
3 | bitsandbytes==0.41.0
4 | deepspeed==0.14.4
5 | einops-exts==0.0.4
6 | einops==0.6.1
7 | gradio==3.35.2
8 | gradio_client==0.2.9
9 | httpx==0.24.0
10 | markdown2==2.4.10
11 | numpy==1.26.0
12 | peft==0.4.0
13 | scikit-learn==1.2.2
14 | sentencepiece==0.1.99
15 | shortuuid==1.0.11
16 | timm==0.6.13
17 | torchvision==0.19.0
18 | transformers==4.49.0
19 | wandb==0.15.12
20 | wavedrom==2.0.3.post3
21 | Pygments==2.16.1
--------------------------------------------------------------------------------
/scripts/convert_gqa_for_eval.py:
--------------------------------------------------------------------------------
1 | import os
2 | import json
3 | import argparse
4 |
5 | parser = argparse.ArgumentParser()
6 | parser.add_argument("--src", type=str)
7 | parser.add_argument("--dst", type=str)
8 | args = parser.parse_args()
9 |
10 | all_answers = []
11 | for line_idx, line in enumerate(open(args.src)):
12 | res = json.loads(line)
13 | question_id = res['question_id']
14 | text = res['text'].rstrip('.').lower()
15 | all_answers.append({"questionId": question_id, "prediction": text})
16 |
17 | with open(args.dst, 'w') as f:
18 | json.dump(all_answers, f)
19 |
--------------------------------------------------------------------------------
/scripts/convert_mmbench_for_submission.py:
--------------------------------------------------------------------------------
1 | import os
2 | import json
3 | import argparse
4 | import pandas as pd
5 |
6 | def get_args():
7 | parser = argparse.ArgumentParser()
8 | parser.add_argument("--annotation-file", type=str, required=True)
9 | parser.add_argument("--result-dir", type=str, required=True)
10 | parser.add_argument("--upload-dir", type=str, required=True)
11 | parser.add_argument("--experiment", type=str, required=True)
12 |
13 | return parser.parse_args()
14 |
15 | if __name__ == "__main__":
16 | args = get_args()
17 |
18 | df = pd.read_table(args.annotation_file)
19 |
20 | cur_df = df.copy()
21 | cur_df = cur_df.drop(columns=['hint', 'category', 'source', 'image', 'comment', 'l2-category'])
22 | cur_df.insert(6, 'prediction', None)
23 | for pred in open(os.path.join(args.result_dir, f"{args.experiment}.jsonl")):
24 | pred = json.loads(pred)
25 | cur_df.loc[df['index'] == pred['question_id'], 'prediction'] = pred['text']
26 |
27 | cur_df.to_excel(os.path.join(args.upload_dir, f"{args.experiment}.xlsx"), index=False, engine='openpyxl')
28 |
--------------------------------------------------------------------------------
/scripts/convert_mmvet_for_eval.py:
--------------------------------------------------------------------------------
1 | import os
2 | import json
3 | import argparse
4 |
5 | parser = argparse.ArgumentParser()
6 | parser.add_argument("--src", type=str)
7 | parser.add_argument("--dst", type=str)
8 | args = parser.parse_args()
9 |
10 | cur_result = {}
11 |
12 | for line in open(args.src):
13 | data = json.loads(line)
14 | qid = data['question_id']
15 | cur_result[f'v1_{qid}'] = data['text']
16 |
17 | with open(args.dst, 'w') as f:
18 | json.dump(cur_result, f, indent=2)
19 |
--------------------------------------------------------------------------------
/scripts/convert_seed_for_submission.py:
--------------------------------------------------------------------------------
1 | import os
2 | import json
3 | import argparse
4 |
5 |
6 | def get_args():
7 | parser = argparse.ArgumentParser()
8 | parser.add_argument("--annotation-file", type=str)
9 | parser.add_argument("--result-file", type=str)
10 | parser.add_argument("--result-upload-file", type=str)
11 | return parser.parse_args()
12 |
13 |
14 | def eval_single(result_file, eval_only_type=None):
15 | results = {}
16 | for line in open(result_file):
17 | row = json.loads(line)
18 | results[row['question_id']] = row
19 |
20 | type_counts = {}
21 | correct_counts = {}
22 | for question_data in data['questions']:
23 | if eval_only_type is not None and question_data['data_type'] != eval_only_type: continue
24 | data_type = question_data['question_type_id']
25 | type_counts[data_type] = type_counts.get(data_type, 0) + 1
26 | try:
27 | question_id = int(question_data['question_id'])
28 |         except (TypeError, ValueError):
29 | question_id = question_data['question_id']
30 | if question_id not in results:
31 | correct_counts[data_type] = correct_counts.get(data_type, 0)
32 | continue
33 | row = results[question_id]
34 | if row['text'] == question_data['answer']:
35 | correct_counts[data_type] = correct_counts.get(data_type, 0) + 1
36 |
37 | total_count = 0
38 | total_correct = 0
39 | for data_type in sorted(type_counts.keys()):
40 | accuracy = correct_counts[data_type] / type_counts[data_type] * 100
41 | if eval_only_type is None:
42 | print(f"{ques_type_id_to_name[data_type]}: {accuracy:.2f}%")
43 |
44 | total_count += type_counts[data_type]
45 | total_correct += correct_counts[data_type]
46 |
47 | total_accuracy = total_correct / total_count * 100
48 | if eval_only_type is None:
49 | print(f"Total accuracy: {total_accuracy:.2f}%")
50 | else:
51 | print(f"{eval_only_type} accuracy: {total_accuracy:.2f}%")
52 |
53 | return results
54 |
55 | if __name__ == "__main__":
56 | args = get_args()
57 | data = json.load(open(args.annotation_file))
58 | ques_type_id_to_name = {id:n for n,id in data['question_type'].items()}
59 |
60 | results = eval_single(args.result_file)
61 | eval_single(args.result_file, eval_only_type='image')
62 | eval_single(args.result_file, eval_only_type='video')
63 |
64 | with open(args.result_upload_file, 'w') as fp:
65 | for question in data['questions']:
66 | qid = question['question_id']
67 | if qid in results:
68 | result = results[qid]
69 | else:
70 | result = results[int(qid)]
71 | fp.write(json.dumps({
72 | 'question_id': qid,
73 | 'prediction': result['text']
74 | }) + '\n')
75 |
--------------------------------------------------------------------------------
/scripts/convert_sqa_to_llava.py:
--------------------------------------------------------------------------------
1 | import json
2 | import os
3 | import fire
4 | import re
5 | from convert_sqa_to_llava_base_prompt import build_prompt_chatbot
6 |
7 |
8 | def convert_to_llava(base_dir, split, prompt_format="QCM-LEA"):
9 | split_indices = json.load(open(os.path.join(base_dir, "pid_splits.json")))[split]
10 | problems = json.load(open(os.path.join(base_dir, "problems.json")))
11 |
12 | split_problems = build_prompt_chatbot(
13 | problems, split_indices, prompt_format,
14 | use_caption=False, is_test=False)
15 |
16 | target_format = []
17 | for prob_id, (input, output) in split_problems.items():
18 | if input.startswith('Question: '):
19 | input = input.replace('Question: ', '')
20 | if output.startswith('Answer: '):
21 | output = output.replace('Answer: ', '')
22 |
23 | raw_prob_data = problems[prob_id]
24 | if raw_prob_data['image'] is None:
25 | target_format.append({
26 | "id": prob_id,
27 | "conversations": [
28 | {'from': 'human', 'value': f"{input}"},
29 | {'from': 'gpt', 'value': f"{output}"},
30 | ],
31 | })
32 |
33 | else:
34 | target_format.append({
35 | "id": prob_id,
36 | "image": os.path.join(prob_id, raw_prob_data['image']),
37 | "conversations": [
38 | {'from': 'human', 'value': f"{input}\n"},
39 | {'from': 'gpt', 'value': f"{output}"},
40 | ],
41 | })
42 |
43 | print(f'Number of samples: {len(target_format)}')
44 |
45 | with open(os.path.join(base_dir, f"llava_{split}_{prompt_format}.json"), "w") as f:
46 | json.dump(target_format, f, indent=2)
47 |
48 |
49 | def convert_to_jsonl(base_dir, split, prompt_format="QCM-LEPA"):
50 | split_indices = json.load(open(os.path.join(base_dir, "pid_splits.json")))[split]
51 | problems = json.load(open(os.path.join(base_dir, "problems.json")))
52 |
53 | split_problems = build_prompt_chatbot(
54 | problems, split_indices, prompt_format,
55 | use_caption=False, is_test=False)
56 |
57 | writer = open(os.path.join(base_dir, f"scienceqa_{split}_{prompt_format}.jsonl"), "w")
58 | for prob_id, (input, output) in split_problems.items():
59 | if input.startswith('Question: '):
60 | input = input.replace('Question: ', '')
61 | if output.startswith('Answer: '):
62 | output = output.replace('Answer: ', '')
63 |
64 | raw_prob_data = problems[prob_id]
65 | if raw_prob_data['image'] is None:
66 | data = {
67 | "id": prob_id,
68 | "instruction": f"{input}",
69 | "output": f"{output}",
70 | }
71 |
72 | else:
73 | data = {
74 | "id": prob_id,
75 | "image": os.path.join(prob_id, raw_prob_data['image']),
76 | "instruction": f"{input}\n",
77 | "output": f"{output}",
78 | }
79 | writer.write(json.dumps(data) + '\n')
80 | writer.close()
81 |
82 |
83 | def main(task, **kwargs):
84 | globals()[task](**kwargs)
85 |
86 |
87 | if __name__ == "__main__":
88 | fire.Fire(main)
89 |
--------------------------------------------------------------------------------
/scripts/convert_vizwiz_for_submission.py:
--------------------------------------------------------------------------------
1 | import os
2 | import argparse
3 | import json
4 |
5 | from llava.eval.m4c_evaluator import EvalAIAnswerProcessor
6 |
7 |
8 | def parse_args():
9 | parser = argparse.ArgumentParser()
10 | parser.add_argument('--annotation-file', type=str, required=True)
11 | parser.add_argument('--result-file', type=str, required=True)
12 | parser.add_argument('--result-upload-file', type=str, required=True)
13 | return parser.parse_args()
14 |
15 |
16 | if __name__ == '__main__':
17 |
18 | args = parse_args()
19 |
20 | os.makedirs(os.path.dirname(args.result_upload_file), exist_ok=True)
21 |
22 | results = []
23 | error_line = 0
24 | for line_idx, line in enumerate(open(args.result_file)):
25 | try:
26 | results.append(json.loads(line))
27 |         except json.JSONDecodeError:
28 | error_line += 1
29 | results = {x['question_id']: x['text'] for x in results}
30 | test_split = [json.loads(line) for line in open(args.annotation_file)]
31 | split_ids = set([x['question_id'] for x in test_split])
32 |
33 | print(f'total results: {len(results)}, total split: {len(test_split)}, error_line: {error_line}')
34 |
35 | all_answers = []
36 |
37 | answer_processor = EvalAIAnswerProcessor()
38 |
39 | for x in test_split:
40 | assert x['question_id'] in results
41 | all_answers.append({
42 | 'image': x['image'],
43 | 'answer': answer_processor(results[x['question_id']])
44 | })
45 |
46 | with open(args.result_upload_file, 'w') as f:
47 | json.dump(all_answers, f)
48 |
--------------------------------------------------------------------------------
/scripts/convert_vqav2_for_submission.py:
--------------------------------------------------------------------------------
1 | import os
2 | import argparse
3 | import json
4 |
5 | from llava.eval.m4c_evaluator import EvalAIAnswerProcessor
6 |
7 |
8 | def parse_args():
9 | parser = argparse.ArgumentParser()
10 | parser.add_argument('--dir', type=str, default="./playground/data/eval/vqav2")
11 | parser.add_argument('--ckpt', type=str, required=True)
12 | parser.add_argument('--split', type=str, required=True)
13 | return parser.parse_args()
14 |
15 |
16 | if __name__ == '__main__':
17 |
18 | args = parse_args()
19 |
20 | src = os.path.join(args.dir, 'answers', args.split, args.ckpt, 'merge.jsonl')
21 | test_split = os.path.join(args.dir, 'llava_vqav2_mscoco_test2015.jsonl')
22 | dst = os.path.join(args.dir, 'answers_upload', args.split, f'{args.ckpt}.json')
23 | os.makedirs(os.path.dirname(dst), exist_ok=True)
24 |
25 | results = []
26 | error_line = 0
27 | for line_idx, line in enumerate(open(src)):
28 | try:
29 | results.append(json.loads(line))
30 |         except json.JSONDecodeError:
31 | error_line += 1
32 |
33 | results = {x['question_id']: x['text'] for x in results}
34 | test_split = [json.loads(line) for line in open(test_split)]
35 | split_ids = set([x['question_id'] for x in test_split])
36 |
37 | print(f'total results: {len(results)}, total split: {len(test_split)}, error_line: {error_line}')
38 |
39 | all_answers = []
40 |
41 | answer_processor = EvalAIAnswerProcessor()
42 |
43 | for x in test_split:
44 | if x['question_id'] not in results:
45 | all_answers.append({
46 | 'question_id': x['question_id'],
47 | 'answer': ''
48 | })
49 | else:
50 | all_answers.append({
51 | 'question_id': x['question_id'],
52 | 'answer': answer_processor(results[x['question_id']])
53 | })
54 |
55 | with open(dst, 'w') as f:
56 |         json.dump(all_answers, f)
57 |
--------------------------------------------------------------------------------
/scripts/extract_mm_projector.py:
--------------------------------------------------------------------------------
1 | """
2 | This is just a utility that I use to extract the projector for quantized models.
3 | It is NOT necessary at all to train, or run inference/serve demos.
4 | Use this script ONLY if you fully understand its implications.
5 | """
6 |
7 |
8 | import os
9 | import argparse
10 | import torch
11 | import json
12 | from collections import defaultdict
13 |
14 |
15 | def parse_args():
16 | parser = argparse.ArgumentParser(description='Extract MMProjector weights')
17 | parser.add_argument('--model-path', type=str, help='model folder')
18 | parser.add_argument('--output', type=str, help='output file')
19 | args = parser.parse_args()
20 | return args
21 |
22 |
23 | if __name__ == '__main__':
24 | args = parse_args()
25 |
26 | keys_to_match = ['mm_projector']
27 | ckpt_to_key = defaultdict(list)
28 | try:
29 | model_indices = json.load(open(os.path.join(args.model_path, 'pytorch_model.bin.index.json')))
30 | for k, v in model_indices['weight_map'].items():
31 | if any(key_match in k for key_match in keys_to_match):
32 | ckpt_to_key[v].append(k)
33 | except FileNotFoundError:
34 | # Smaller models or model checkpoints saved by DeepSpeed.
35 | v = 'pytorch_model.bin'
36 | for k in torch.load(os.path.join(args.model_path, v), map_location='cpu').keys():
37 | if any(key_match in k for key_match in keys_to_match):
38 | ckpt_to_key[v].append(k)
39 |
40 | loaded_weights = {}
41 |
42 | for ckpt_name, weight_keys in ckpt_to_key.items():
43 | ckpt = torch.load(os.path.join(args.model_path, ckpt_name), map_location='cpu')
44 | for k in weight_keys:
45 | loaded_weights[k] = ckpt[k]
46 |
47 | torch.save(loaded_weights, args.output)
48 |
--------------------------------------------------------------------------------
/scripts/finetune.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | # IMPORTANT: this is the training script for the original LLaVA, NOT FOR LLaVA V1.5!
4 |
5 | # Uncomment and set the following variables correspondingly to run this script:
6 |
7 | ################## VICUNA ##################
8 | # PROMPT_VERSION=v1
9 | # MODEL_VERSION="vicuna-v1-3-7b"
10 | ################## VICUNA ##################
11 |
12 | ################## LLaMA-2 ##################
13 | # PROMPT_VERSION="llava_llama_2"
14 | # MODEL_VERSION="llama-2-7b-chat"
15 | ################## LLaMA-2 ##################
16 |
17 | deepspeed llava/train/train_mem.py \
18 | --deepspeed ./scripts/zero2.json \
19 | --model_name_or_path ./checkpoints/$MODEL_VERSION \
20 | --version $PROMPT_VERSION \
21 | --data_path ./playground/data/llava_instruct_80k.json \
22 | --image_folder /path/to/coco/train2017 \
23 | --vision_tower openai/clip-vit-large-patch14 \
24 | --pretrain_mm_mlp_adapter ./checkpoints/llava-$MODEL_VERSION-pretrain/mm_projector.bin \
25 | --mm_vision_select_layer -2 \
26 | --mm_use_im_start_end False \
27 | --mm_use_im_patch_token False \
28 | --bf16 True \
29 | --output_dir ./checkpoints/llava-$MODEL_VERSION-finetune \
30 | --num_train_epochs 1 \
31 | --per_device_train_batch_size 16 \
32 | --per_device_eval_batch_size 4 \
33 | --gradient_accumulation_steps 1 \
34 | --evaluation_strategy "no" \
35 | --save_strategy "steps" \
36 | --save_steps 50000 \
37 | --save_total_limit 1 \
38 | --learning_rate 2e-5 \
39 | --weight_decay 0. \
40 | --warmup_ratio 0.03 \
41 | --lr_scheduler_type "cosine" \
42 | --logging_steps 1 \
43 | --tf32 True \
44 | --model_max_length 2048 \
45 | --gradient_checkpointing True \
46 | --dataloader_num_workers 4 \
47 | --lazy_preprocess True \
48 | --report_to wandb
49 |
--------------------------------------------------------------------------------
/scripts/finetune_full_schedule.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | # IMPORTANT: this is the training script for the original LLaVA, NOT FOR LLaVA V1.5!
4 |
5 | # Uncomment and set the following variables correspondingly to run this script:
6 |
7 | ################## VICUNA ##################
8 | # PROMPT_VERSION=v1
9 | # MODEL_VERSION="vicuna-v1-3-7b"
10 | ################## VICUNA ##################
11 |
12 | ################## LLaMA-2 ##################
13 | # PROMPT_VERSION="llava_llama_2"
14 | # MODEL_VERSION="llama-2-7b-chat"
15 | ################## LLaMA-2 ##################
16 |
17 | deepspeed llava/train/train_mem.py \
18 | --deepspeed ./scripts/zero2.json \
19 | --model_name_or_path ./checkpoints/$MODEL_VERSION \
20 | --version $PROMPT_VERSION \
21 | --data_path ./playground/data/llava_instruct_158k.json \
22 | --image_folder /path/to/coco/train2017 \
23 | --vision_tower openai/clip-vit-large-patch14 \
24 | --pretrain_mm_mlp_adapter ./checkpoints/llava-$MODEL_VERSION-pretrain/mm_projector.bin \
25 | --mm_vision_select_layer -2 \
26 | --mm_use_im_start_end False \
27 | --mm_use_im_patch_token False \
28 | --bf16 True \
29 | --output_dir ./checkpoints/llava-$MODEL_VERSION-finetune \
30 | --num_train_epochs 3 \
31 | --per_device_train_batch_size 16 \
32 | --per_device_eval_batch_size 4 \
33 | --gradient_accumulation_steps 1 \
34 | --evaluation_strategy "no" \
35 | --save_strategy "steps" \
36 | --save_steps 50000 \
37 | --save_total_limit 1 \
38 | --learning_rate 2e-5 \
39 | --weight_decay 0. \
40 | --warmup_ratio 0.03 \
41 | --lr_scheduler_type "cosine" \
42 | --logging_steps 1 \
43 | --tf32 True \
44 | --model_max_length 2048 \
45 | --gradient_checkpointing True \
46 | --dataloader_num_workers 4 \
47 | --lazy_preprocess True \
48 | --report_to wandb
49 |
--------------------------------------------------------------------------------
/scripts/finetune_lora.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | # IMPORTANT: this is the training script for the original LLaVA, NOT FOR LLaVA V1.5!
4 |
5 | # Uncomment and set the following variables correspondingly to run this script:
6 |
7 | ################## VICUNA ##################
8 | # PROMPT_VERSION=v1
9 | # MODEL_VERSION="vicuna-v1-3-7b"
10 | ################## VICUNA ##################
11 |
12 | ################## LLaMA-2 ##################
13 | # PROMPT_VERSION="llava_llama_2"
14 | # MODEL_VERSION="llama-2-7b-chat"
15 | ################## LLaMA-2 ##################
16 |
17 | deepspeed llava/train/train_mem.py \
18 | --deepspeed ./scripts/zero2.json \
19 | --lora_enable True \
20 | --model_name_or_path ./checkpoints/$MODEL_VERSION \
21 | --version $PROMPT_VERSION \
22 | --data_path ./playground/data/llava_instruct_80k.json \
23 | --image_folder /path/to/coco/train2017 \
24 | --vision_tower openai/clip-vit-large-patch14 \
25 | --pretrain_mm_mlp_adapter ./checkpoints/llava-$MODEL_VERSION-pretrain/mm_projector.bin \
26 | --mm_vision_select_layer -2 \
27 | --mm_use_im_start_end False \
28 | --mm_use_im_patch_token False \
29 | --bf16 True \
30 | --output_dir ./checkpoints/llava-$MODEL_VERSION-finetune_lora \
31 | --num_train_epochs 1 \
32 | --per_device_train_batch_size 16 \
33 | --per_device_eval_batch_size 4 \
34 | --gradient_accumulation_steps 1 \
35 | --evaluation_strategy "no" \
36 | --save_strategy "steps" \
37 | --save_steps 50000 \
38 | --save_total_limit 1 \
39 | --learning_rate 2e-5 \
40 | --weight_decay 0. \
41 | --warmup_ratio 0.03 \
42 | --lr_scheduler_type "cosine" \
43 | --logging_steps 1 \
44 | --tf32 True \
45 | --model_max_length 2048 \
46 | --gradient_checkpointing True \
47 | --lazy_preprocess True \
48 | --dataloader_num_workers 4 \
49 | --report_to wandb
50 |
--------------------------------------------------------------------------------
/scripts/finetune_qlora.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | # IMPORTANT: this is the training script for the original LLaVA, NOT FOR LLaVA V1.5!
4 |
5 | # Uncomment and set the following variables correspondingly to run this script:
6 |
7 | ################## VICUNA ##################
8 | # PROMPT_VERSION=v1
9 | # MODEL_VERSION="vicuna-v1-3-7b"
10 | ################## VICUNA ##################
11 |
12 | ################## LLaMA-2 ##################
13 | # PROMPT_VERSION="llava_llama_2"
14 | # MODEL_VERSION="llama-2-7b-chat"
15 | ################## LLaMA-2 ##################
16 |
17 | deepspeed llava/train/train_mem.py \
18 | --deepspeed ./scripts/zero2.json \
19 | --lora_enable True \
20 | --bits 4 \
21 | --model_name_or_path ./checkpoints/$MODEL_VERSION \
22 | --version $PROMPT_VERSION \
23 | --data_path ./playground/data/llava_instruct_80k.json \
24 | --image_folder /path/to/coco/train2017 \
25 | --vision_tower openai/clip-vit-large-patch14 \
26 | --pretrain_mm_mlp_adapter ./checkpoints/llava-$MODEL_VERSION-pretrain/mm_projector.bin \
27 | --mm_vision_select_layer -2 \
28 | --mm_use_im_start_end False \
29 | --mm_use_im_patch_token False \
30 | --bf16 True \
31 | --output_dir ./checkpoints/llava-$MODEL_VERSION-finetune_lora \
32 | --num_train_epochs 1 \
33 | --per_device_train_batch_size 16 \
34 | --per_device_eval_batch_size 4 \
35 | --gradient_accumulation_steps 1 \
36 | --evaluation_strategy "no" \
37 | --save_strategy "steps" \
38 | --save_steps 50000 \
39 | --save_total_limit 1 \
40 | --learning_rate 2e-5 \
41 | --weight_decay 0. \
42 | --warmup_ratio 0.03 \
43 | --lr_scheduler_type "cosine" \
44 | --logging_steps 1 \
45 | --tf32 True \
46 | --model_max_length 2048 \
47 | --gradient_checkpointing True \
48 | --lazy_preprocess True \
49 | --dataloader_num_workers 4 \
50 | --report_to wandb
51 |
--------------------------------------------------------------------------------
/scripts/finetune_sqa.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | # IMPORTANT: this is the training script for the original LLaVA, NOT FOR LLaVA V1.5!
4 |
5 | deepspeed llava/train/train_mem.py \
6 | --deepspeed ./scripts/zero2.json \
7 | --model_name_or_path lmsys/vicuna-13b-v1.3 \
8 | --version $PROMPT_VERSION \
9 | --data_path /Data/ScienceQA/data/scienceqa/llava_train_QCM-LEA.json \
10 | --image_folder /Data/ScienceQA/data/scienceqa/images/train \
11 | --vision_tower openai/clip-vit-large-patch14 \
12 | --pretrain_mm_mlp_adapter ./checkpoints/huggingface/liuhaotian/llava-pretrain-vicuna-13b-v1.3/mm_projector.bin \
13 | --mm_vision_select_layer -2 \
14 | --mm_use_im_start_end False \
15 | --mm_use_im_patch_token False \
16 | --bf16 True \
17 | --output_dir ./checkpoints/llava-vicuna-13b-v1.3-pretrain_lcs558k_plain-ScienceQA_QCM_LEA-12e \
18 | --num_train_epochs 12 \
19 | --per_device_train_batch_size 16 \
20 | --per_device_eval_batch_size 4 \
21 | --gradient_accumulation_steps 1 \
22 | --evaluation_strategy "no" \
23 | --save_strategy "steps" \
24 | --save_steps 50000 \
25 | --save_total_limit 1 \
26 | --learning_rate 2e-5 \
27 | --weight_decay 0. \
28 | --warmup_ratio 0.03 \
29 | --lr_scheduler_type "cosine" \
30 | --logging_steps 1 \
31 | --tf32 True \
32 | --model_max_length 2048 \
33 | --gradient_checkpointing True \
34 | --dataloader_num_workers 4 \
35 | --lazy_preprocess True \
36 | --report_to wandb
37 |
--------------------------------------------------------------------------------
/scripts/merge_lora_weights.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | from llava.model.builder import load_pretrained_model
3 | from llava.mm_utils import get_model_name_from_path
4 |
5 |
6 | def merge_lora(args):
7 | model_name = get_model_name_from_path(args.model_path)
8 | tokenizer, model, image_processor, context_len = load_pretrained_model(args.model_path, args.model_base, model_name, device_map='cpu')
9 |
10 | model.save_pretrained(args.save_model_path)
11 | tokenizer.save_pretrained(args.save_model_path)
12 |
13 |
14 | if __name__ == "__main__":
15 | parser = argparse.ArgumentParser()
16 | parser.add_argument("--model-path", type=str, required=True)
17 | parser.add_argument("--model-base", type=str, required=True)
18 | parser.add_argument("--save-model-path", type=str, required=True)
19 |
20 | args = parser.parse_args()
21 |
22 | merge_lora(args)
23 |
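A hedged invocation sketch for `merge_lora` above (not part of the repository); the paths and base-model id are placeholders, and it assumes the function is importable or otherwise in scope.

from argparse import Namespace

args = Namespace(
    model_path="./checkpoints/my-llava-lora",         # placeholder LoRA checkpoint directory
    model_base="lmsys/vicuna-13b-v1.5",               # placeholder base model
    save_model_path="./checkpoints/my-llava-merged",  # placeholder output directory
)
# Loads the model via load_pretrained_model and writes the result to save_model_path.
merge_lora(args)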
--------------------------------------------------------------------------------
/scripts/pretrain.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | # IMPORTANT: this is the training script for the original LLaVA, NOT FOR LLaVA V1.5!
4 |
5 | # Uncomment and set the following variables correspondingly to run this script:
6 |
7 | # MODEL_VERSION=vicuna-v1-3-7b
8 | # MODEL_VERSION=llama-2-7b-chat
9 |
10 | ########### DO NOT CHANGE ###########
11 | ########### USE THIS FOR BOTH ###########
12 | PROMPT_VERSION=plain
13 | ########### DO NOT CHANGE ###########
14 |
15 | deepspeed llava/train/train_mem.py \
16 | --deepspeed ./scripts/zero2.json \
17 | --model_name_or_path ./checkpoints/$MODEL_VERSION \
18 | --version $PROMPT_VERSION \
19 | --data_path /path/to/pretrain_data.json \
20 | --image_folder /path/to/images \
21 | --vision_tower openai/clip-vit-large-patch14 \
22 | --tune_mm_mlp_adapter True \
23 | --mm_vision_select_layer -2 \
24 | --mm_use_im_start_end False \
25 | --mm_use_im_patch_token False \
26 | --bf16 True \
27 | --output_dir ./checkpoints/llava-$MODEL_VERSION-pretrain \
28 | --num_train_epochs 1 \
29 | --per_device_train_batch_size 16 \
30 | --per_device_eval_batch_size 4 \
31 | --gradient_accumulation_steps 1 \
32 | --evaluation_strategy "no" \
33 | --save_strategy "steps" \
34 | --save_steps 24000 \
35 | --save_total_limit 1 \
36 | --learning_rate 2e-3 \
37 | --weight_decay 0. \
38 | --warmup_ratio 0.03 \
39 | --lr_scheduler_type "cosine" \
40 | --logging_steps 1 \
41 | --tf32 True \
42 | --model_max_length 2048 \
43 | --gradient_checkpointing True \
44 | --dataloader_num_workers 4 \
45 | --lazy_preprocess True \
46 | --report_to wandb
47 |
--------------------------------------------------------------------------------
/scripts/pretrain_xformers.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | # Uncomment and set the following variables correspondingly to run this script:
4 |
5 | # MODEL_VERSION=vicuna-v1-3-7b
6 | # MODEL_VERSION=llama-2-7b-chat
7 |
8 | ########### DO NOT CHANGE ###########
9 | ########### USE THIS FOR BOTH ###########
10 | PROMPT_VERSION=plain
11 | ########### DO NOT CHANGE ###########
12 |
13 | deepspeed llava/train/train_xformers.py \
14 | --deepspeed ./scripts/zero2.json \
15 | --model_name_or_path ./checkpoints/$MODEL_VERSION \
16 | --version $PROMPT_VERSION \
17 | --data_path /path/to/pretrain_data.json \
18 | --image_folder /path/to/images \
19 | --vision_tower openai/clip-vit-large-patch14 \
20 | --tune_mm_mlp_adapter True \
21 | --mm_vision_select_layer -2 \
22 | --mm_use_im_start_end False \
23 | --mm_use_im_patch_token False \
24 | --bf16 False \
25 | --output_dir ./checkpoints/llava-$MODEL_VERSION-pretrain \
26 | --num_train_epochs 1 \
27 | --per_device_train_batch_size 4 \
28 | --per_device_eval_batch_size 4 \
29 | --gradient_accumulation_steps 4 \
30 | --evaluation_strategy "no" \
31 | --save_strategy "steps" \
32 | --save_steps 24000 \
33 | --save_total_limit 1 \
34 | --learning_rate 2e-3 \
35 | --weight_decay 0. \
36 | --warmup_ratio 0.03 \
37 | --lr_scheduler_type "cosine" \
38 | --logging_steps 1 \
39 | --tf32 False \
40 | --model_max_length 2048 \
41 | --gradient_checkpointing True \
42 | --dataloader_num_workers 4 \
43 | --lazy_preprocess True \
44 | --report_to wandb
45 |
--------------------------------------------------------------------------------
/scripts/sqa_eval_batch.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | CHUNKS=8
4 | for IDX in {0..7}; do
5 | CUDA_VISIBLE_DEVICES=$IDX python -m llava.eval.model_vqa_science \
6 | --model-path liuhaotian/llava-lcs558k-scienceqa-vicuna-13b-v1.3 \
7 | --question-file ~/haotian/datasets/ScienceQA/data/scienceqa/llava_test_QCM-LEA.json \
8 | --image-folder ~/haotian/datasets/ScienceQA/data/scienceqa/images/test \
9 |         --answers-file ./test_llava-13b-chunk${IDX}.jsonl \
10 | --num-chunks $CHUNKS \
11 | --chunk-idx $IDX \
12 | --conv-mode llava_v1 &
13 | done
14 |
--------------------------------------------------------------------------------
/scripts/sqa_eval_gather.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | CHUNKS=8
4 | output_file="test_llava-13b.jsonl"
5 |
6 | # Clear out the output file if it exists.
7 | > "$output_file"
8 |
9 | # Loop through the indices and concatenate each file.
10 | for idx in $(seq 0 $((CHUNKS-1))); do
11 | cat "./test_llava-13b-chunk${idx}.jsonl" >> "$output_file"
12 | done
13 |
14 | python llava/eval/eval_science_qa.py \
15 | --base-dir ~/haotian/datasets/ScienceQA/data/scienceqa \
16 | --result-file ./test_llava-13b.jsonl \
17 | --output-file ./test_llava-13b_output.json \
18 | --output-result ./test_llava-13b_result.json
19 |
--------------------------------------------------------------------------------
/scripts/upload_pypi.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | # Step 0: Clean up
4 | rm -rf dist
5 |
6 | # Step 1: Change the package name to "llava-torch"
7 | sed -i 's/name = "llava"/name = "llava-torch"/' pyproject.toml
8 |
9 | # Step 2: Build the package
10 | python -m build
11 |
12 | # Step 3: Revert the changes in pyproject.toml to the original
13 | sed -i 's/name = "llava-torch"/name = "llava"/' pyproject.toml
14 |
15 | # Step 4: Upload to PyPI
16 | python -m twine upload dist/*
17 |
--------------------------------------------------------------------------------
/scripts/v1_5/eval/gqa.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | gpu_list="${CUDA_VISIBLE_DEVICES:-0}"
4 | IFS=',' read -ra GPULIST <<< "$gpu_list"
5 |
6 | CHUNKS=${#GPULIST[@]}
7 |
8 | CKPT="llava-v1.5-13b"
9 | SPLIT="llava_gqa_testdev_balanced"
10 | GQADIR="./playground/data/eval/gqa/data"
11 |
12 | for IDX in $(seq 0 $((CHUNKS-1))); do
13 | CUDA_VISIBLE_DEVICES=${GPULIST[$IDX]} python -m llava.eval.model_vqa_loader \
14 | --model-path liuhaotian/llava-v1.5-13b \
15 | --question-file ./playground/data/eval/gqa/$SPLIT.jsonl \
16 | --image-folder ./playground/data/eval/gqa/data/images \
17 | --answers-file ./playground/data/eval/gqa/answers/$SPLIT/$CKPT/${CHUNKS}_${IDX}.jsonl \
18 | --num-chunks $CHUNKS \
19 | --chunk-idx $IDX \
20 | --temperature 0 \
21 | --conv-mode vicuna_v1 &
22 | done
23 |
24 | wait
25 |
26 | output_file=./playground/data/eval/gqa/answers/$SPLIT/$CKPT/merge.jsonl
27 |
28 | # Clear out the output file if it exists.
29 | > "$output_file"
30 |
31 | # Loop through the indices and concatenate each file.
32 | for IDX in $(seq 0 $((CHUNKS-1))); do
33 | cat ./playground/data/eval/gqa/answers/$SPLIT/$CKPT/${CHUNKS}_${IDX}.jsonl >> "$output_file"
34 | done
35 |
36 | python scripts/convert_gqa_for_eval.py --src $output_file --dst $GQADIR/testdev_balanced_predictions.json
37 |
38 | cd $GQADIR
39 | python eval/eval.py --tier testdev_balanced
40 |
--------------------------------------------------------------------------------
/scripts/v1_5/eval/llavabench.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | python -m llava.eval.model_vqa \
4 | --model-path liuhaotian/llava-v1.5-13b \
5 | --question-file ./playground/data/eval/llava-bench-in-the-wild/questions.jsonl \
6 | --image-folder ./playground/data/eval/llava-bench-in-the-wild/images \
7 | --answers-file ./playground/data/eval/llava-bench-in-the-wild/answers/llava-v1.5-13b.jsonl \
8 | --temperature 0 \
9 | --conv-mode vicuna_v1
10 |
11 | mkdir -p playground/data/eval/llava-bench-in-the-wild/reviews
12 |
13 | python llava/eval/eval_gpt_review_bench.py \
14 | --question playground/data/eval/llava-bench-in-the-wild/questions.jsonl \
15 | --context playground/data/eval/llava-bench-in-the-wild/context.jsonl \
16 | --rule llava/eval/table/rule.json \
17 | --answer-list \
18 | playground/data/eval/llava-bench-in-the-wild/answers_gpt4.jsonl \
19 | playground/data/eval/llava-bench-in-the-wild/answers/llava-v1.5-13b.jsonl \
20 | --output \
21 | playground/data/eval/llava-bench-in-the-wild/reviews/llava-v1.5-13b.jsonl
22 |
23 | python llava/eval/summarize_gpt_review.py -f playground/data/eval/llava-bench-in-the-wild/reviews/llava-v1.5-13b.jsonl
24 |
--------------------------------------------------------------------------------
/scripts/v1_5/eval/mmbench.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | SPLIT="mmbench_dev_20230712"
4 |
5 | python -m llava.eval.model_vqa_mmbench \
6 | --model-path liuhaotian/llava-v1.5-13b \
7 | --question-file ./playground/data/eval/mmbench/$SPLIT.tsv \
8 | --answers-file ./playground/data/eval/mmbench/answers/$SPLIT/llava-v1.5-13b.jsonl \
9 | --single-pred-prompt \
10 | --temperature 0 \
11 | --conv-mode vicuna_v1
12 |
13 | mkdir -p playground/data/eval/mmbench/answers_upload/$SPLIT
14 |
15 | python scripts/convert_mmbench_for_submission.py \
16 | --annotation-file ./playground/data/eval/mmbench/$SPLIT.tsv \
17 | --result-dir ./playground/data/eval/mmbench/answers/$SPLIT \
18 | --upload-dir ./playground/data/eval/mmbench/answers_upload/$SPLIT \
19 | --experiment llava-v1.5-13b
20 |
--------------------------------------------------------------------------------
/scripts/v1_5/eval/mmbench_cn.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | SPLIT="mmbench_dev_cn_20231003"
4 |
5 | python -m llava.eval.model_vqa_mmbench \
6 | --model-path liuhaotian/llava-v1.5-13b \
7 | --question-file ./playground/data/eval/mmbench_cn/$SPLIT.tsv \
8 | --answers-file ./playground/data/eval/mmbench_cn/answers/$SPLIT/llava-v1.5-13b.jsonl \
9 | --lang cn \
10 | --single-pred-prompt \
11 | --temperature 0 \
12 | --conv-mode vicuna_v1
13 |
14 | mkdir -p playground/data/eval/mmbench_cn/answers_upload/$SPLIT
15 |
16 | python scripts/convert_mmbench_for_submission.py \
17 | --annotation-file ./playground/data/eval/mmbench_cn/$SPLIT.tsv \
18 | --result-dir ./playground/data/eval/mmbench_cn/answers/$SPLIT \
19 | --upload-dir ./playground/data/eval/mmbench_cn/answers_upload/$SPLIT \
20 | --experiment llava-v1.5-13b
21 |
--------------------------------------------------------------------------------
/scripts/v1_5/eval/mme.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | python -m llava.eval.model_vqa_loader \
4 | --model-path liuhaotian/llava-v1.5-13b \
5 | --question-file ./playground/data/eval/MME/llava_mme.jsonl \
6 | --image-folder ./playground/data/eval/MME/MME_Benchmark_release_version \
7 | --answers-file ./playground/data/eval/MME/answers/llava-v1.5-13b.jsonl \
8 | --temperature 0 \
9 | --conv-mode vicuna_v1
10 |
11 | cd ./playground/data/eval/MME
12 |
13 | python convert_answer_to_mme.py --experiment llava-v1.5-13b
14 |
15 | cd eval_tool
16 |
17 | python calculation.py --results_dir answers/llava-v1.5-13b
18 |
--------------------------------------------------------------------------------
/scripts/v1_5/eval/mmvet.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | python -m llava.eval.model_vqa \
4 | --model-path liuhaotian/llava-v1.5-13b \
5 | --question-file ./playground/data/eval/mm-vet/llava-mm-vet.jsonl \
6 | --image-folder ./playground/data/eval/mm-vet/images \
7 | --answers-file ./playground/data/eval/mm-vet/answers/llava-v1.5-13b.jsonl \
8 | --temperature 0 \
9 | --conv-mode vicuna_v1
10 |
11 | mkdir -p ./playground/data/eval/mm-vet/results
12 |
13 | python scripts/convert_mmvet_for_eval.py \
14 | --src ./playground/data/eval/mm-vet/answers/llava-v1.5-13b.jsonl \
15 | --dst ./playground/data/eval/mm-vet/results/llava-v1.5-13b.json
16 |
17 |
--------------------------------------------------------------------------------
/scripts/v1_5/eval/pope.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | python -m llava.eval.model_vqa_loader \
4 | --model-path liuhaotian/llava-v1.5-13b \
5 | --question-file ./playground/data/eval/pope/llava_pope_test.jsonl \
6 | --image-folder ./playground/data/eval/pope/val2014 \
7 | --answers-file ./playground/data/eval/pope/answers/llava-v1.5-13b.jsonl \
8 | --temperature 0 \
9 | --conv-mode vicuna_v1
10 |
11 | python llava/eval/eval_pope.py \
12 | --annotation-dir ./playground/data/eval/pope/coco \
13 | --question-file ./playground/data/eval/pope/llava_pope_test.jsonl \
14 | --result-file ./playground/data/eval/pope/answers/llava-v1.5-13b.jsonl
15 |
--------------------------------------------------------------------------------
/scripts/v1_5/eval/qbench.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | if [ "$1" = "dev" ]; then
4 | echo "Evaluating in 'dev' split."
5 | elif [ "$1" = "test" ]; then
6 | echo "Evaluating in 'test' split."
7 | else
8 | echo "Unknown split, please choose between 'dev' and 'test'."
9 | exit 1
10 | fi
11 |
12 | python -m llava.eval.model_vqa_qbench \
13 | --model-path liuhaotian/llava-v1.5-13b \
14 | --image-folder ./playground/data/eval/qbench/images_llvisionqa/ \
15 | --questions-file ./playground/data/eval/qbench/llvisionqa_$1.json \
16 | --answers-file ./playground/data/eval/qbench/llvisionqa_$1_answers.jsonl \
17 | --conv-mode llava_v1 \
18 | --lang en
19 |
--------------------------------------------------------------------------------
/scripts/v1_5/eval/qbench_zh.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | if [ "$1" = "dev" ]; then
4 | ZH_SPLIT="验证集"
5 | echo "Evaluating in 'dev' split."
6 | elif [ "$1" = "test" ]; then
7 | ZH_SPLIT="测试集"
8 | echo "Evaluating in 'test' split."
9 | else
10 | echo "Unknown split, please choose between 'dev' and 'test'."
11 | exit 1
12 | fi
13 |
14 | python -m llava.eval.model_vqa_qbench \
15 | --model-path liuhaotian/llava-v1.5-13b \
16 | --image-folder ./playground/data/eval/qbench/images_llvisionqa/ \
17 | --questions-file ./playground/data/eval/qbench/质衡-问答-$ZH_SPLIT.json \
18 | --answers-file ./playground/data/eval/qbench/llvisionqa_zh_$1_answers.jsonl \
19 | --conv-mode llava_v1 \
20 | --lang zh
21 |
--------------------------------------------------------------------------------
/scripts/v1_5/eval/seed.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | gpu_list="${CUDA_VISIBLE_DEVICES:-0}"
4 | IFS=',' read -ra GPULIST <<< "$gpu_list"
5 |
6 | CHUNKS=${#GPULIST[@]}
7 |
8 | CKPT="llava-v1.5-13b"
9 |
10 | for IDX in $(seq 0 $((CHUNKS-1))); do
11 | CUDA_VISIBLE_DEVICES=${GPULIST[$IDX]} python -m llava.eval.model_vqa_loader \
12 | --model-path liuhaotian/llava-v1.5-13b \
13 | --question-file ./playground/data/eval/seed_bench/llava-seed-bench.jsonl \
14 | --image-folder ./playground/data/eval/seed_bench \
15 | --answers-file ./playground/data/eval/seed_bench/answers/$CKPT/${CHUNKS}_${IDX}.jsonl \
16 | --num-chunks $CHUNKS \
17 | --chunk-idx $IDX \
18 | --temperature 0 \
19 | --conv-mode vicuna_v1 &
20 | done
21 |
22 | wait
23 |
24 | output_file=./playground/data/eval/seed_bench/answers/$CKPT/merge.jsonl
25 |
26 | # Clear out the output file if it exists.
27 | > "$output_file"
28 |
29 | # Loop through the indices and concatenate each file.
30 | for IDX in $(seq 0 $((CHUNKS-1))); do
31 | cat ./playground/data/eval/seed_bench/answers/$CKPT/${CHUNKS}_${IDX}.jsonl >> "$output_file"
32 | done
33 |
34 | # Evaluate
35 | python scripts/convert_seed_for_submission.py \
36 | --annotation-file ./playground/data/eval/seed_bench/SEED-Bench.json \
37 | --result-file $output_file \
38 | --result-upload-file ./playground/data/eval/seed_bench/answers_upload/llava-v1.5-13b.jsonl
39 |
40 |
--------------------------------------------------------------------------------
/scripts/v1_5/eval/sqa.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | python -m llava.eval.model_vqa_science \
4 | --model-path liuhaotian/llava-v1.5-13b \
5 | --question-file ./playground/data/eval/scienceqa/llava_test_CQM-A.json \
6 | --image-folder ./playground/data/eval/scienceqa/images/test \
7 | --answers-file ./playground/data/eval/scienceqa/answers/llava-v1.5-13b.jsonl \
8 | --single-pred-prompt \
9 | --temperature 0 \
10 | --conv-mode vicuna_v1
11 |
12 | python llava/eval/eval_science_qa.py \
13 | --base-dir ./playground/data/eval/scienceqa \
14 | --result-file ./playground/data/eval/scienceqa/answers/llava-v1.5-13b.jsonl \
15 | --output-file ./playground/data/eval/scienceqa/answers/llava-v1.5-13b_output.jsonl \
16 | --output-result ./playground/data/eval/scienceqa/answers/llava-v1.5-13b_result.json
17 |
--------------------------------------------------------------------------------
/scripts/v1_5/eval/textvqa.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | python -m llava.eval.model_vqa_loader \
4 | --model-path liuhaotian/llava-v1.5-13b \
5 | --question-file ./playground/data/eval/textvqa/llava_textvqa_val_v051_ocr.jsonl \
6 | --image-folder ./playground/data/eval/textvqa/train_images \
7 | --answers-file ./playground/data/eval/textvqa/answers/llava-v1.5-13b.jsonl \
8 | --temperature 0 \
9 | --conv-mode vicuna_v1
10 |
11 | python -m llava.eval.eval_textvqa \
12 | --annotation-file ./playground/data/eval/textvqa/TextVQA_0.5.1_val.json \
13 | --result-file ./playground/data/eval/textvqa/answers/llava-v1.5-13b.jsonl
14 |
--------------------------------------------------------------------------------
/scripts/v1_5/eval/vizwiz.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | python -m llava.eval.model_vqa_loader \
4 | --model-path liuhaotian/llava-v1.5-13b \
5 | --question-file ./playground/data/eval/vizwiz/llava_test.jsonl \
6 | --image-folder ./playground/data/eval/vizwiz/test \
7 | --answers-file ./playground/data/eval/vizwiz/answers/llava-v1.5-13b.jsonl \
8 | --temperature 0 \
9 | --conv-mode vicuna_v1
10 |
11 | python scripts/convert_vizwiz_for_submission.py \
12 | --annotation-file ./playground/data/eval/vizwiz/llava_test.jsonl \
13 | --result-file ./playground/data/eval/vizwiz/answers/llava-v1.5-13b.jsonl \
14 | --result-upload-file ./playground/data/eval/vizwiz/answers_upload/llava-v1.5-13b.json
15 |
--------------------------------------------------------------------------------
/scripts/v1_5/eval/vqav2.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | gpu_list="${CUDA_VISIBLE_DEVICES:-0}"
4 | IFS=',' read -ra GPULIST <<< "$gpu_list"
5 |
6 | CHUNKS=${#GPULIST[@]}
7 |
8 | CKPT="llava-v1.5-13b"
9 | SPLIT="llava_vqav2_mscoco_test-dev2015"
10 |
11 | for IDX in $(seq 0 $((CHUNKS-1))); do
12 | CUDA_VISIBLE_DEVICES=${GPULIST[$IDX]} python -m llava.eval.model_vqa_loader \
13 | --model-path liuhaotian/llava-v1.5-13b \
14 | --question-file ./playground/data/eval/vqav2/$SPLIT.jsonl \
15 | --image-folder ./playground/data/eval/vqav2/test2015 \
16 | --answers-file ./playground/data/eval/vqav2/answers/$SPLIT/$CKPT/${CHUNKS}_${IDX}.jsonl \
17 | --num-chunks $CHUNKS \
18 | --chunk-idx $IDX \
19 | --temperature 0 \
20 | --conv-mode vicuna_v1 &
21 | done
22 |
23 | wait
24 |
25 | output_file=./playground/data/eval/vqav2/answers/$SPLIT/$CKPT/merge.jsonl
26 |
27 | # Clear out the output file if it exists.
28 | > "$output_file"
29 |
30 | # Loop through the indices and concatenate each file.
31 | for IDX in $(seq 0 $((CHUNKS-1))); do
32 | cat ./playground/data/eval/vqav2/answers/$SPLIT/$CKPT/${CHUNKS}_${IDX}.jsonl >> "$output_file"
33 | done
34 |
35 | python scripts/convert_vqav2_for_submission.py --split $SPLIT --ckpt $CKPT
36 |
37 |
--------------------------------------------------------------------------------
/scripts/v1_5/finetune.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | deepspeed llava/train/train_mem.py \
4 | --deepspeed ./scripts/zero3.json \
5 | --model_name_or_path lmsys/vicuna-13b-v1.5 \
6 | --version v1 \
7 | --data_path ./playground/data/llava_v1_5_mix665k.json \
8 | --image_folder ./playground/data \
9 | --vision_tower openai/clip-vit-large-patch14-336 \
10 | --pretrain_mm_mlp_adapter ./checkpoints/llava-v1.5-13b-pretrain/mm_projector.bin \
11 | --mm_projector_type mlp2x_gelu \
12 | --mm_vision_select_layer -2 \
13 | --mm_use_im_start_end False \
14 | --mm_use_im_patch_token False \
15 | --image_aspect_ratio pad \
16 | --group_by_modality_length True \
17 | --bf16 True \
18 | --output_dir ./checkpoints/llava-v1.5-13b \
19 | --num_train_epochs 1 \
20 | --per_device_train_batch_size 16 \
21 | --per_device_eval_batch_size 4 \
22 | --gradient_accumulation_steps 1 \
23 | --evaluation_strategy "no" \
24 | --save_strategy "steps" \
25 | --save_steps 50000 \
26 | --save_total_limit 1 \
27 | --learning_rate 2e-5 \
28 | --weight_decay 0. \
29 | --warmup_ratio 0.03 \
30 | --lr_scheduler_type "cosine" \
31 | --logging_steps 1 \
32 | --tf32 True \
33 | --model_max_length 2048 \
34 | --gradient_checkpointing True \
35 | --dataloader_num_workers 4 \
36 | --lazy_preprocess True \
37 | --report_to wandb
38 |
--------------------------------------------------------------------------------
/scripts/v1_5/finetune_lora.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | deepspeed llava/train/train_mem.py \
4 | --lora_enable True --lora_r 128 --lora_alpha 256 --mm_projector_lr 2e-5 \
5 | --deepspeed ./scripts/zero3.json \
6 | --model_name_or_path lmsys/vicuna-13b-v1.5 \
7 | --version v1 \
8 | --data_path ./playground/data/llava_v1_5_mix665k.json \
9 | --image_folder ./playground/data \
10 | --vision_tower openai/clip-vit-large-patch14-336 \
11 | --pretrain_mm_mlp_adapter ./checkpoints/llava-v1.5-13b-pretrain/mm_projector.bin \
12 | --mm_projector_type mlp2x_gelu \
13 | --mm_vision_select_layer -2 \
14 | --mm_use_im_start_end False \
15 | --mm_use_im_patch_token False \
16 | --image_aspect_ratio pad \
17 | --group_by_modality_length True \
18 | --bf16 True \
19 | --output_dir ./checkpoints/llava-v1.5-13b-lora \
20 | --num_train_epochs 1 \
21 | --per_device_train_batch_size 16 \
22 | --per_device_eval_batch_size 4 \
23 | --gradient_accumulation_steps 1 \
24 | --evaluation_strategy "no" \
25 | --save_strategy "steps" \
26 | --save_steps 50000 \
27 | --save_total_limit 1 \
28 | --learning_rate 2e-4 \
29 | --weight_decay 0. \
30 | --warmup_ratio 0.03 \
31 | --lr_scheduler_type "cosine" \
32 | --logging_steps 1 \
33 | --tf32 True \
34 | --model_max_length 2048 \
35 | --gradient_checkpointing True \
36 | --dataloader_num_workers 4 \
37 | --lazy_preprocess True \
38 | --report_to wandb
39 |
--------------------------------------------------------------------------------
/scripts/v1_5/finetune_task.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | deepspeed llava/train/train_mem.py \
4 | --deepspeed ./scripts/zero3.json \
5 | --model_name_or_path liuhaotian/llava-v1.5-13b \
6 | --version v1 \
7 | --data_path ./playground/data/llava_v1_5_mix665k.json \
8 | --image_folder ./playground/data \
9 | --vision_tower openai/clip-vit-large-patch14-336 \
10 | --mm_projector_type mlp2x_gelu \
11 | --mm_vision_select_layer -2 \
12 | --mm_use_im_start_end False \
13 | --mm_use_im_patch_token False \
14 | --image_aspect_ratio pad \
15 | --group_by_modality_length True \
16 | --bf16 True \
17 | --output_dir ./checkpoints/llava-v1.5-13b-task \
18 | --num_train_epochs 1 \
19 | --per_device_train_batch_size 16 \
20 | --per_device_eval_batch_size 4 \
21 | --gradient_accumulation_steps 1 \
22 | --evaluation_strategy "no" \
23 | --save_strategy "steps" \
24 | --save_steps 50000 \
25 | --save_total_limit 1 \
26 | --learning_rate 2e-5 \
27 | --weight_decay 0. \
28 | --warmup_ratio 0.03 \
29 | --lr_scheduler_type "cosine" \
30 | --logging_steps 1 \
31 | --tf32 True \
32 | --model_max_length 2048 \
33 | --gradient_checkpointing True \
34 | --dataloader_num_workers 4 \
35 | --lazy_preprocess True \
36 | --report_to wandb
37 |
--------------------------------------------------------------------------------
/scripts/v1_5/finetune_task_lora.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | deepspeed llava/train/train_mem.py \
4 | --lora_enable True --lora_r 128 --lora_alpha 256 --mm_projector_lr 2e-5 \
5 | --deepspeed ./scripts/zero3.json \
6 | --model_name_or_path liuhaotian/llava-v1.5-13b \
7 | --version v1 \
8 | --data_path ./playground/data/llava_v1_5_mix665k.json \
9 | --image_folder ./playground/data \
10 | --vision_tower openai/clip-vit-large-patch14-336 \
11 | --mm_projector_type mlp2x_gelu \
12 | --mm_vision_select_layer -2 \
13 | --mm_use_im_start_end False \
14 | --mm_use_im_patch_token False \
15 | --image_aspect_ratio pad \
16 | --group_by_modality_length True \
17 | --bf16 True \
18 | --output_dir ./checkpoints/llava-v1.5-13b-task-lora \
19 | --num_train_epochs 1 \
20 | --per_device_train_batch_size 16 \
21 | --per_device_eval_batch_size 4 \
22 | --gradient_accumulation_steps 1 \
23 | --evaluation_strategy "no" \
24 | --save_strategy "steps" \
25 | --save_steps 50000 \
26 | --save_total_limit 1 \
27 | --learning_rate 2e-4 \
28 | --weight_decay 0. \
29 | --warmup_ratio 0.03 \
30 | --lr_scheduler_type "cosine" \
31 | --logging_steps 1 \
32 | --tf32 True \
33 | --model_max_length 2048 \
34 | --gradient_checkpointing True \
35 | --dataloader_num_workers 4 \
36 | --lazy_preprocess True \
37 | --report_to wandb
38 |
--------------------------------------------------------------------------------
/scripts/v1_5/pretrain.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | deepspeed llava/train/train_mem.py \
4 | --deepspeed ./scripts/zero2.json \
5 | --model_name_or_path lmsys/vicuna-13b-v1.5 \
6 | --version plain \
7 | --data_path ./playground/data/LLaVA-Pretrain/blip_laion_cc_sbu_558k.json \
8 | --image_folder ./playground/data/LLaVA-Pretrain/images \
9 | --vision_tower openai/clip-vit-large-patch14-336 \
10 | --mm_projector_type mlp2x_gelu \
11 | --tune_mm_mlp_adapter True \
12 | --mm_vision_select_layer -2 \
13 | --mm_use_im_start_end False \
14 | --mm_use_im_patch_token False \
15 | --bf16 True \
16 | --output_dir ./checkpoints/llava-v1.5-13b-pretrain \
17 | --num_train_epochs 1 \
18 | --per_device_train_batch_size 32 \
19 | --per_device_eval_batch_size 4 \
20 | --gradient_accumulation_steps 1 \
21 | --evaluation_strategy "no" \
22 | --save_strategy "steps" \
23 | --save_steps 24000 \
24 | --save_total_limit 1 \
25 | --learning_rate 1e-3 \
26 | --weight_decay 0. \
27 | --warmup_ratio 0.03 \
28 | --lr_scheduler_type "cosine" \
29 | --logging_steps 1 \
30 | --tf32 True \
31 | --model_max_length 2048 \
32 | --gradient_checkpointing True \
33 | --dataloader_num_workers 4 \
34 | --lazy_preprocess True \
35 | --report_to wandb
36 |
--------------------------------------------------------------------------------
/scripts/zero2.json:
--------------------------------------------------------------------------------
1 | {
2 | "fp16": {
3 | "enabled": "auto",
4 | "loss_scale": 0,
5 | "loss_scale_window": 1000,
6 | "initial_scale_power": 16,
7 | "hysteresis": 2,
8 | "min_loss_scale": 1
9 | },
10 | "bf16": {
11 | "enabled": "auto"
12 | },
13 | "train_micro_batch_size_per_gpu": "auto",
14 | "train_batch_size": "auto",
15 | "gradient_accumulation_steps": "auto",
16 | "zero_optimization": {
17 | "stage": 2,
18 | "overlap_comm": true,
19 | "contiguous_gradients": true,
20 | "sub_group_size": 1e9,
21 | "reduce_bucket_size": "auto"
22 | }
23 | }
--------------------------------------------------------------------------------
/scripts/zero3.json:
--------------------------------------------------------------------------------
1 | {
2 | "fp16": {
3 | "enabled": "auto",
4 | "loss_scale": 0,
5 | "loss_scale_window": 1000,
6 | "initial_scale_power": 16,
7 | "hysteresis": 2,
8 | "min_loss_scale": 1
9 | },
10 | "bf16": {
11 | "enabled": "auto"
12 | },
13 | "train_micro_batch_size_per_gpu": "auto",
14 | "train_batch_size": "auto",
15 | "gradient_accumulation_steps": "auto",
16 | "zero_optimization": {
17 | "stage": 3,
18 | "overlap_comm": true,
19 | "contiguous_gradients": true,
20 | "sub_group_size": 1e9,
21 | "reduce_bucket_size": "auto",
22 | "stage3_prefetch_bucket_size": "auto",
23 | "stage3_param_persistence_threshold": "auto",
24 | "stage3_max_live_parameters": 1e9,
25 | "stage3_max_reuse_distance": 1e9,
26 | "stage3_gather_16bit_weights_on_model_save": true
27 | }
28 | }
--------------------------------------------------------------------------------
/scripts/zero3_offload.json:
--------------------------------------------------------------------------------
1 | {
2 | "fp16": {
3 | "enabled": "auto",
4 | "loss_scale": 0,
5 | "loss_scale_window": 1000,
6 | "initial_scale_power": 16,
7 | "hysteresis": 2,
8 | "min_loss_scale": 1
9 | },
10 | "bf16": {
11 | "enabled": "auto"
12 | },
13 | "optimizer": {
14 | "type": "AdamW",
15 | "params": {
16 | "lr": "auto",
17 | "betas": "auto",
18 | "eps": "auto",
19 | "weight_decay": "auto"
20 | }
21 | },
22 | "scheduler": {
23 | "type": "WarmupLR",
24 | "params": {
25 | "warmup_min_lr": "auto",
26 | "warmup_max_lr": "auto",
27 | "warmup_num_steps": "auto"
28 | }
29 | },
30 | "zero_optimization": {
31 | "stage": 3,
32 | "offload_optimizer": {
33 | "device": "cpu",
34 | "pin_memory": true
35 | },
36 | "offload_param": {
37 | "device": "cpu",
38 | "pin_memory": true
39 | },
40 | "overlap_comm": true,
41 | "contiguous_gradients": true,
42 | "sub_group_size": 1e9,
43 | "reduce_bucket_size": "auto",
44 | "stage3_prefetch_bucket_size": "auto",
45 | "stage3_param_persistence_threshold": "auto",
46 | "stage3_max_live_parameters": 1e9,
47 | "stage3_max_reuse_distance": 1e9,
48 | "gather_16bit_weights_on_model_save": true
49 | },
50 | "gradient_accumulation_steps": "auto",
51 | "gradient_clipping": "auto",
52 | "train_batch_size": "auto",
53 | "train_micro_batch_size_per_gpu": "auto",
54 | "steps_per_print": 1e5,
55 | "wall_clock_breakdown": false
56 | }
--------------------------------------------------------------------------------
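Note: the "auto" values in the three DeepSpeed configs above (zero2.json, zero3.json, zero3_offload.json) are not placeholders to edit by hand; when the training scripts pass --deepspeed, Hugging Face's DeepSpeed integration fills them in from the Trainer arguments at launch time. A minimal sketch of that wiring, assuming transformers and deepspeed are installed (the output path, model and dataset below are placeholders, not part of this repository):

    # Sketch only: how the "auto" fields in scripts/zero3_offload.json get resolved
    # from TrainingArguments by the Hugging Face DeepSpeed integration.
    from transformers import TrainingArguments

    args = TrainingArguments(
        output_dir="./checkpoints/debug-run",        # hypothetical output path
        deepspeed="./scripts/zero3_offload.json",    # config shown above
        per_device_train_batch_size=16,              # -> train_micro_batch_size_per_gpu
        gradient_accumulation_steps=1,               # -> gradient_accumulation_steps
        learning_rate=2e-5,                          # -> optimizer.params.lr
        weight_decay=0.0,                            # -> optimizer.params.weight_decay
        warmup_ratio=0.03,                           # -> scheduler warmup_num_steps
        bf16=True,                                   # -> bf16.enabled
    )
    # Trainer(model=model, args=args, train_dataset=train_dataset).train() would then
    # launch DeepSpeed with the resolved config.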
/taming_transformers/License.txt:
--------------------------------------------------------------------------------
1 | Copyright (c) 2020 Patrick Esser and Robin Rombach and Björn Ommer
2 |
3 | Permission is hereby granted, free of charge, to any person obtaining a copy
4 | of this software and associated documentation files (the "Software"), to deal
5 | in the Software without restriction, including without limitation the rights
6 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
7 | copies of the Software, and to permit persons to whom the Software is
8 | furnished to do so, subject to the following conditions:
9 |
10 | The above copyright notice and this permission notice shall be included in all
11 | copies or substantial portions of the Software.
12 |
13 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
14 | EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
15 | MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
16 | IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM,
17 | DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR
18 | OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE
19 | OR OTHER DEALINGS IN THE SOFTWARE.
20 |
--------------------------------------------------------------------------------
/taming_transformers/ckpt/model.yaml:
--------------------------------------------------------------------------------
1 | model:
2 | base_learning_rate: 4.5e-06
3 | target: taming.models.vqgan.GumbelVQ
4 | params:
5 | kl_weight: 1.0e-08
6 | embed_dim: 256
7 | n_embed: 8192
8 | monitor: val/rec_loss
9 | temperature_scheduler_config:
10 | target: taming.lr_scheduler.LambdaWarmUpCosineScheduler
11 | params:
12 | warm_up_steps: 0
13 | max_decay_steps: 1000001
14 | lr_start: 0.9
15 | lr_max: 0.9
16 | lr_min: 1.0e-06
17 | ddconfig:
18 | double_z: false
19 | z_channels: 256
20 | resolution: 256
21 | in_channels: 3
22 | out_ch: 3
23 | ch: 128
24 | ch_mult:
25 | - 1
26 | - 1
27 | - 2
28 | - 4
29 | num_res_blocks: 2
30 | attn_resolutions:
31 | - 32
32 | dropout: 0.0
33 | lossconfig:
34 | target: taming.modules.losses.vqperceptual.DummyLoss
35 |
--------------------------------------------------------------------------------
/taming_transformers/environment.yaml:
--------------------------------------------------------------------------------
1 | name: taming
2 | channels:
3 | - pytorch
4 | - defaults
5 | dependencies:
6 | - python=3.8.5
7 | - pip=20.3
8 | - cudatoolkit=10.2
9 | - pytorch=1.7.0
10 | - torchvision=0.8.1
11 | - numpy=1.19.2
12 | - pip:
13 | - albumentations==0.4.3
14 | - opencv-python==4.1.2.30
15 | - pudb==2019.2
16 | - imageio==2.9.0
17 | - imageio-ffmpeg==0.4.2
18 | - pytorch-lightning==1.0.8
19 | - omegaconf==2.0.0
20 | - test-tube>=0.7.5
21 | - streamlit>=0.73.1
22 | - einops==0.3.0
23 | - more-itertools>=8.0.0
24 | - transformers==4.3.1
25 | - -e .
26 |
--------------------------------------------------------------------------------
/taming_transformers/idx2img.py:
--------------------------------------------------------------------------------
1 | from PIL import Image
2 | import numpy as np
3 | import torch
4 |
5 | import yaml
6 | import torch
7 | from omegaconf import OmegaConf
8 | from .taming.models.vqgan import VQModel, GumbelVQ
9 | import requests
10 | import PIL
11 | from PIL import Image
12 | from PIL import ImageDraw, ImageFont
13 | import numpy as np
14 |
15 | import torch
16 | import torch.nn.functional as F
17 | import torchvision.transforms as T
18 | import torchvision.transforms.functional as TF
19 | import pickle
20 | import os, sys
21 |
22 | torch.set_grad_enabled(False)
23 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
24 |
25 | def preprocess_vqgan(x):
26 | x = 2.*x - 1.
27 | return x
28 |
29 | def custom_to_pil(x, save_path):
30 | x = x.detach().cpu()
31 | x = torch.clamp(x, -1., 1.)
32 | x = (x + 1.)/2.
33 | x = x[0]
34 | x = x.permute(1,2,0).numpy()
35 | x = (255*x).astype(np.uint8)
36 | x = Image.fromarray(x)
37 | if not x.mode == "RGB":
38 | x = x.convert("RGB")
39 | x.save(save_path)
40 | return x
41 |
42 | def sample2img(x):
43 | x = x.detach().cpu()
44 | x = torch.clamp(x, -1., 1.)
45 | x = (x + 1.)/2.
46 | x = x[0]
47 | x = x.permute(1,2,0).numpy()
48 | x = (255*x).astype(np.uint8)
49 | x = Image.fromarray(x)
50 | if not x.mode == "RGB":
51 | x = x.convert("RGB")
52 | return x
53 |
54 | def load_config(config_path, display=False):
55 | config = OmegaConf.load(config_path)
56 | if display:
57 | print(yaml.dump(OmegaConf.to_container(config)))
58 | return config
59 |
60 | def load_vqgan(config, ckpt_path=None, is_gumbel=False):
61 | if is_gumbel:
62 | model = GumbelVQ(**config.model.params)
63 | else:
64 | model = VQModel(**config.model.params)
65 | if ckpt_path is not None:
66 | sd = torch.load(ckpt_path, map_location="cpu")["state_dict"]
67 | missing, unexpected = model.load_state_dict(sd, strict=False)
68 | return model.eval()
69 |
70 | dir_path = os.path.dirname(__file__)
71 |
72 | config = load_config(os.path.join(dir_path, 'ckpt/model.yaml'), display=False)
73 | model = load_vqgan(config, ckpt_path=os.path.join(dir_path, 'ckpt/last.ckpt'), is_gumbel=True).to(device)
74 |
75 | @torch.no_grad()
76 | def decode_to_img(index, zshape=(1, 256, 32, 32)):
77 | global model
78 | bhwc = (zshape[0],zshape[2],zshape[3],zshape[1])
79 | quant_z = model.quantize.get_codebook_entry(
80 | index.reshape(-1), shape=bhwc)
81 | x = model.decode(quant_z)
82 | return x
83 |
84 | def preprocess(img, target_image_size=512):
85 | img = TF.resize(img, (target_image_size, target_image_size), interpolation=PIL.Image.LANCZOS)
86 | img = torch.unsqueeze(T.ToTensor()(img), 0)
87 | return img
88 |
89 | @torch.no_grad()
90 | def img2idx(image_path):
91 | global model
92 | image = Image.open((image_path)).convert('RGB')
93 | img = preprocess_vqgan(preprocess(image, target_image_size=256).to(model.device))
94 |
95 | z, _, [_, _, indices] = model.encode(img)
96 | return z, indices
97 |
98 | @torch.no_grad()
99 | def idx2img(idx_tensor, save_path):
100 | x = decode_to_img(idx_tensor)
101 | custom_to_pil(x, save_path)
102 |
103 |
104 |
--------------------------------------------------------------------------------
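idx2img.py loads ckpt/model.yaml and ckpt/last.ckpt at import time and exposes an encode/decode pair around the GumbelVQ codebook. A round-trip sketch, assuming the package is importable as taming_transformers (the module uses a relative import), the checkpoint files are in place, and example.jpg is a placeholder input image:

    # Sketch only: encode an image to VQGAN codebook indices and decode it back.
    # Importing the module loads the GumbelVQ checkpoint (GPU if available).
    from taming_transformers.idx2img import img2idx, idx2img

    z, indices = img2idx("example.jpg")     # z: (1, 256, 32, 32) latent for a 256x256 input;
                                            # indices: codebook ids for each latent position
    idx2img(indices, "reconstruction.png")  # look up the codebook entries, decode, save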
/taming_transformers/scripts/extract_depth.py:
--------------------------------------------------------------------------------
1 | import os
2 | import torch
3 | import numpy as np
4 | from tqdm import trange
5 | from PIL import Image
6 |
7 |
8 | def get_state(gpu):
9 | import torch
10 | midas = torch.hub.load("intel-isl/MiDaS", "MiDaS")
11 | if gpu:
12 | midas.cuda()
13 | midas.eval()
14 |
15 | midas_transforms = torch.hub.load("intel-isl/MiDaS", "transforms")
16 | transform = midas_transforms.default_transform
17 |
18 | state = {"model": midas,
19 | "transform": transform}
20 | return state
21 |
22 |
23 | def depth_to_rgba(x):
24 | assert x.dtype == np.float32
25 | assert len(x.shape) == 2
26 | y = x.copy()
27 | y.dtype = np.uint8
28 | y = y.reshape(x.shape+(4,))
29 | return np.ascontiguousarray(y)
30 |
31 |
32 | def rgba_to_depth(x):
33 | assert x.dtype == np.uint8
34 | assert len(x.shape) == 3 and x.shape[2] == 4
35 | y = x.copy()
36 | y.dtype = np.float32
37 | y = y.reshape(x.shape[:2])
38 | return np.ascontiguousarray(y)
39 |
40 |
41 | def run(x, state):
42 | model = state["model"]
43 | transform = state["transform"]
44 | hw = x.shape[:2]
45 | with torch.no_grad():
46 | prediction = model(transform((x + 1.0) * 127.5).cuda())
47 | prediction = torch.nn.functional.interpolate(
48 | prediction.unsqueeze(1),
49 | size=hw,
50 | mode="bicubic",
51 | align_corners=False,
52 | ).squeeze()
53 | output = prediction.cpu().numpy()
54 | return output
55 |
56 |
57 | def get_filename(relpath, level=-2):
58 | # save class folder structure and filename:
59 | fn = relpath.split(os.sep)[level:]
60 | folder = fn[-2]
61 | file = fn[-1].split('.')[0]
62 | return folder, file
63 |
64 |
65 | def save_depth(dataset, path, debug=False):
66 | os.makedirs(path)
67 | N = len(dataset)
68 | if debug:
69 | N = 10
70 | state = get_state(gpu=True)
71 | for idx in trange(N, desc="Data"):
72 | ex = dataset[idx]
73 | image, relpath = ex["image"], ex["relpath"]
74 | folder, filename = get_filename(relpath)
75 | # prepare
76 | folderabspath = os.path.join(path, folder)
77 | os.makedirs(folderabspath, exist_ok=True)
78 | savepath = os.path.join(folderabspath, filename)
79 | # run model
80 | xout = run(image, state)
81 | I = depth_to_rgba(xout)
82 | Image.fromarray(I).save("{}.png".format(savepath))
83 |
84 |
85 | if __name__ == "__main__":
86 | from taming.data.imagenet import ImageNetTrain, ImageNetValidation
87 | out = "data/imagenet_depth"
88 | if not os.path.exists(out):
89 | print("Please create a folder or symlink '{}' to extract depth data ".format(out) +
90 | "(be prepared that the output size will be larger than ImageNet itself).")
91 | exit(1)
92 |
93 | # go
94 | dset = ImageNetValidation()
95 | abspath = os.path.join(out, "val")
96 | if os.path.exists(abspath):
97 | print("{} exists - not doing anything.".format(abspath))
98 | else:
99 | print("preparing {}".format(abspath))
100 | save_depth(dset, abspath)
101 | print("done with validation split")
102 |
103 | dset = ImageNetTrain()
104 | abspath = os.path.join(out, "train")
105 | if os.path.exists(abspath):
106 | print("{} exists - not doing anything.".format(abspath))
107 | else:
108 | print("preparing {}".format(abspath))
109 | save_depth(dset, abspath)
110 | print("done with train split")
111 |
112 | print("done done.")
113 |
--------------------------------------------------------------------------------
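The depth_to_rgba / rgba_to_depth helpers above losslessly reinterpret a float32 depth map as 4-channel uint8 so the MiDaS output can be stored as PNG. A small round-trip check, assuming the two functions are importable from this script (adjust the import path, or copy them into the session):

    # Sketch only: bit-exact round trip through the uint8 RGBA encoding used for PNG storage.
    import numpy as np
    from extract_depth import depth_to_rgba, rgba_to_depth  # hypothetical import path

    depth = np.random.rand(4, 4).astype(np.float32)   # dummy depth map
    rgba = depth_to_rgba(depth)                       # (4, 4, 4) uint8 view of the float bits
    restored = rgba_to_depth(rgba)                    # back to (4, 4) float32
    assert restored.dtype == np.float32
    assert np.array_equal(depth, restored)            # lossless round trip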
/taming_transformers/scripts/extract_segmentation.py:
--------------------------------------------------------------------------------
1 | import sys, os
2 | import numpy as np
3 | import scipy
4 | import torch
5 | import torch.nn as nn
6 | from scipy import ndimage
7 | from tqdm import tqdm, trange
8 | from PIL import Image
9 | import torch.hub
10 | import torchvision
11 | import torch.nn.functional as F
12 |
13 | # download deeplabv2_resnet101_msc-cocostuff164k-100000.pth from
14 | # https://github.com/kazuto1011/deeplab-pytorch/releases/download/v1.0/deeplabv2_resnet101_msc-cocostuff164k-100000.pth
15 | # and put the path here
16 | CKPT_PATH = "TODO"
17 |
18 | rescale = lambda x: (x + 1.) / 2.
19 |
20 | def rescale_bgr(x):
21 | x = (x+1)*127.5
22 | x = torch.flip(x, dims=[0])
23 | return x
24 |
25 |
26 | class COCOStuffSegmenter(nn.Module):
27 | def __init__(self, config):
28 | super().__init__()
29 | self.config = config
30 | self.n_labels = 182
31 | model = torch.hub.load("kazuto1011/deeplab-pytorch", "deeplabv2_resnet101", n_classes=self.n_labels)
32 | ckpt_path = CKPT_PATH
33 | model.load_state_dict(torch.load(ckpt_path))
34 | self.model = model
35 |
36 | normalize = torchvision.transforms.Normalize(mean=self.mean, std=self.std)
37 | self.image_transform = torchvision.transforms.Compose([
38 | torchvision.transforms.Lambda(lambda image: torch.stack(
39 | [normalize(rescale_bgr(x)) for x in image]))
40 | ])
41 |
42 | def forward(self, x, upsample=None):
43 | x = self._pre_process(x)
44 | x = self.model(x)
45 | if upsample is not None:
46 | x = torch.nn.functional.upsample_bilinear(x, size=upsample)
47 | return x
48 |
49 | def _pre_process(self, x):
50 | x = self.image_transform(x)
51 | return x
52 |
53 | @property
54 | def mean(self):
55 | # bgr
56 | return [104.008, 116.669, 122.675]
57 |
58 | @property
59 | def std(self):
60 | return [1.0, 1.0, 1.0]
61 |
62 | @property
63 | def input_size(self):
64 | return [3, 224, 224]
65 |
66 |
67 | def run_model(img, model):
68 | model = model.eval()
69 | with torch.no_grad():
70 | segmentation = model(img, upsample=(img.shape[2], img.shape[3]))
71 | segmentation = torch.argmax(segmentation, dim=1, keepdim=True)
72 | return segmentation.detach().cpu()
73 |
74 |
75 | def get_input(batch, k):
76 | x = batch[k]
77 | if len(x.shape) == 3:
78 | x = x[..., None]
79 | x = x.permute(0, 3, 1, 2).to(memory_format=torch.contiguous_format)
80 | return x.float()
81 |
82 |
83 | def save_segmentation(segmentation, path):
84 | # --> class label to uint8, save as png
85 | os.makedirs(os.path.dirname(path), exist_ok=True)
86 | assert len(segmentation.shape)==4
87 | assert segmentation.shape[0]==1
88 | for seg in segmentation:
89 | seg = seg.permute(1,2,0).numpy().squeeze().astype(np.uint8)
90 | seg = Image.fromarray(seg)
91 | seg.save(path)
92 |
93 |
94 | def iterate_dataset(dataloader, destpath, model):
95 | os.makedirs(destpath, exist_ok=True)
96 | num_processed = 0
97 | for i, batch in tqdm(enumerate(dataloader), desc="Data"):
98 | try:
99 | img = get_input(batch, "image")
100 | img = img.cuda()
101 | seg = run_model(img, model)
102 |
103 | path = batch["relative_file_path_"][0]
104 | path = os.path.splitext(path)[0]
105 |
106 | path = os.path.join(destpath, path + ".png")
107 | save_segmentation(seg, path)
108 | num_processed += 1
109 | except Exception as e:
110 | print(e)
111 | print("but anyhow..")
112 |
113 | print("Processed {} files. Bye.".format(num_processed))
114 |
115 |
116 | from taming.data.sflckr import Examples
117 | from torch.utils.data import DataLoader
118 |
119 | if __name__ == "__main__":
120 | dest = sys.argv[1]
121 | batchsize = 1
122 | print("Running with batch-size {}, saving to {}...".format(batchsize, dest))
123 |
124 | model = COCOStuffSegmenter({}).cuda()
125 | print("Instantiated model.")
126 |
127 | dataset = Examples()
128 | dloader = DataLoader(dataset, batch_size=batchsize)
129 | iterate_dataset(dataloader=dloader, destpath=dest, model=model)
130 | print("done.")
131 |
--------------------------------------------------------------------------------
/taming_transformers/scripts/extract_submodel.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import sys
3 |
4 | if __name__ == "__main__":
5 | inpath = sys.argv[1]
6 | outpath = sys.argv[2]
7 | submodel = "cond_stage_model"
8 | if len(sys.argv) > 3:
9 | submodel = sys.argv[3]
10 |
11 | print("Extracting {} from {} to {}.".format(submodel, inpath, outpath))
12 |
13 | sd = torch.load(inpath, map_location="cpu")
14 | new_sd = {"state_dict": dict((k.split(".", 1)[-1],v)
15 | for k,v in sd["state_dict"].items()
16 | if k.startswith(submodel))}
17 | torch.save(new_sd, outpath)
18 |
--------------------------------------------------------------------------------
/taming_transformers/setup.py:
--------------------------------------------------------------------------------
1 | from setuptools import setup, find_packages
2 |
3 | setup(
4 | name='taming-transformers',
5 | version='0.0.1',
6 | description='Taming Transformers for High-Resolution Image Synthesis',
7 | packages=find_packages(),
8 | install_requires=[
9 | 'torch',
10 | 'numpy',
11 | 'tqdm',
12 | ],
13 | )
14 |
--------------------------------------------------------------------------------
/taming_transformers/taming/data/base.py:
--------------------------------------------------------------------------------
1 | import bisect
2 | import numpy as np
3 | import albumentations
4 | from PIL import Image
5 | from torch.utils.data import Dataset, ConcatDataset
6 |
7 |
8 | class ConcatDatasetWithIndex(ConcatDataset):
9 | """Modified from original pytorch code to return dataset idx"""
10 | def __getitem__(self, idx):
11 | if idx < 0:
12 | if -idx > len(self):
13 | raise ValueError("absolute value of index should not exceed dataset length")
14 | idx = len(self) + idx
15 | dataset_idx = bisect.bisect_right(self.cumulative_sizes, idx)
16 | if dataset_idx == 0:
17 | sample_idx = idx
18 | else:
19 | sample_idx = idx - self.cumulative_sizes[dataset_idx - 1]
20 | return self.datasets[dataset_idx][sample_idx], dataset_idx
21 |
22 |
23 | class ImagePaths(Dataset):
24 | def __init__(self, paths, size=None, random_crop=False, labels=None):
25 | self.size = size
26 | self.random_crop = random_crop
27 |
28 | self.labels = dict() if labels is None else labels
29 | self.labels["file_path_"] = paths
30 | self._length = len(paths)
31 |
32 | if self.size is not None and self.size > 0:
33 | self.rescaler = albumentations.SmallestMaxSize(max_size = self.size)
34 | if not self.random_crop:
35 | self.cropper = albumentations.CenterCrop(height=self.size,width=self.size)
36 | else:
37 | self.cropper = albumentations.RandomCrop(height=self.size,width=self.size)
38 | self.preprocessor = albumentations.Compose([self.rescaler, self.cropper])
39 | else:
40 | self.preprocessor = lambda **kwargs: kwargs
41 |
42 | def __len__(self):
43 | return self._length
44 |
45 | def preprocess_image(self, image_path):
46 | image = Image.open(image_path)
47 | if not image.mode == "RGB":
48 | image = image.convert("RGB")
49 | image = np.array(image).astype(np.uint8)
50 | image = self.preprocessor(image=image)["image"]
51 | image = (image/127.5 - 1.0).astype(np.float32)
52 | return image
53 |
54 | def __getitem__(self, i):
55 | example = dict()
56 | example["image"] = self.preprocess_image(self.labels["file_path_"][i])
57 | for k in self.labels:
58 | example[k] = self.labels[k][i]
59 | return example
60 |
61 |
62 | class NumpyPaths(ImagePaths):
63 | def preprocess_image(self, image_path):
64 | image = np.load(image_path).squeeze(0) # 3 x 1024 x 1024
65 | image = np.transpose(image, (1,2,0))
66 | image = Image.fromarray(image, mode="RGB")
67 | image = np.array(image).astype(np.uint8)
68 | image = self.preprocessor(image=image)["image"]
69 | image = (image/127.5 - 1.0).astype(np.float32)
70 | return image
71 |
--------------------------------------------------------------------------------
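ImagePaths is the basic dataset wrapper used throughout taming: given a list of image files it rescales the shorter side to size, center- or random-crops, and returns HWC float32 images in [-1, 1]. A usage sketch with placeholder file names (not files in this repository):

    # Sketch only: "cat.jpg", "dog.jpg", "bird.jpg" are hypothetical image paths.
    from torch.utils.data import DataLoader
    from taming.data.base import ImagePaths, ConcatDatasetWithIndex

    train_a = ImagePaths(paths=["cat.jpg", "dog.jpg"], size=256, random_crop=True)
    train_b = ImagePaths(paths=["bird.jpg"], size=256, random_crop=True)

    # ConcatDatasetWithIndex additionally reports which sub-dataset a sample came from.
    combined = ConcatDatasetWithIndex([train_a, train_b])
    example, dataset_idx = combined[0]
    print(example["image"].shape, example["file_path_"], dataset_idx)  # (256, 256, 3) cat.jpg 0

    loader = DataLoader(train_a, batch_size=2, shuffle=True)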
/taming_transformers/taming/data/conditional_builder/objects_bbox.py:
--------------------------------------------------------------------------------
1 | from itertools import cycle
2 | from typing import List, Tuple, Callable, Optional
3 |
4 | from PIL import Image as pil_image, ImageDraw as pil_img_draw, ImageFont
5 | from more_itertools.recipes import grouper
6 | from taming.data.image_transforms import convert_pil_to_tensor
7 | from torch import LongTensor, Tensor
8 |
9 | from taming.data.helper_types import BoundingBox, Annotation
10 | from taming.data.conditional_builder.objects_center_points import ObjectsCenterPointsConditionalBuilder
11 | from taming.data.conditional_builder.utils import COLOR_PALETTE, WHITE, GRAY_75, BLACK, additional_parameters_string, \
12 | pad_list, get_plot_font_size, absolute_bbox
13 |
14 |
15 | class ObjectsBoundingBoxConditionalBuilder(ObjectsCenterPointsConditionalBuilder):
16 | @property
17 | def object_descriptor_length(self) -> int:
18 | return 3
19 |
20 | def _make_object_descriptors(self, annotations: List[Annotation]) -> List[Tuple[int, ...]]:
21 | object_triples = [
22 | (self.object_representation(ann), *self.token_pair_from_bbox(ann.bbox))
23 | for ann in annotations
24 | ]
25 | empty_triple = (self.none, self.none, self.none)
26 | object_triples = pad_list(object_triples, empty_triple, self.no_max_objects)
27 | return object_triples
28 |
29 | def inverse_build(self, conditional: LongTensor) -> Tuple[List[Tuple[int, BoundingBox]], Optional[BoundingBox]]:
30 | conditional_list = conditional.tolist()
31 | crop_coordinates = None
32 | if self.encode_crop:
33 | crop_coordinates = self.bbox_from_token_pair(conditional_list[-2], conditional_list[-1])
34 | conditional_list = conditional_list[:-2]
35 | object_triples = grouper(conditional_list, 3)
36 | assert conditional.shape[0] == self.embedding_dim
37 | return [
38 | (object_triple[0], self.bbox_from_token_pair(object_triple[1], object_triple[2]))
39 | for object_triple in object_triples if object_triple[0] != self.none
40 | ], crop_coordinates
41 |
42 | def plot(self, conditional: LongTensor, label_for_category_no: Callable[[int], str], figure_size: Tuple[int, int],
43 | line_width: int = 3, font_size: Optional[int] = None) -> Tensor:
44 | plot = pil_image.new('RGB', figure_size, WHITE)
45 | draw = pil_img_draw.Draw(plot)
46 | font = ImageFont.truetype(
47 | "/usr/share/fonts/truetype/lato/Lato-Regular.ttf",
48 | size=get_plot_font_size(font_size, figure_size)
49 | )
50 | width, height = plot.size
51 | description, crop_coordinates = self.inverse_build(conditional)
52 | for (representation, bbox), color in zip(description, cycle(COLOR_PALETTE)):
53 | annotation = self.representation_to_annotation(representation)
54 | class_label = label_for_category_no(annotation.category_no) + ' ' + additional_parameters_string(annotation)
55 | bbox = absolute_bbox(bbox, width, height)
56 | draw.rectangle(bbox, outline=color, width=line_width)
57 | draw.text((bbox[0] + line_width, bbox[1] + line_width), class_label, anchor='la', fill=BLACK, font=font)
58 | if crop_coordinates is not None:
59 | draw.rectangle(absolute_bbox(crop_coordinates, width, height), outline=GRAY_75, width=line_width)
60 | return convert_pil_to_tensor(plot) / 127.5 - 1.
61 |
--------------------------------------------------------------------------------
/taming_transformers/taming/data/conditional_builder/utils.py:
--------------------------------------------------------------------------------
1 | import importlib
2 | from typing import List, Any, Tuple, Optional
3 |
4 | from taming.data.helper_types import BoundingBox, Annotation
5 |
6 | # source: seaborn, color palette tab10
7 | COLOR_PALETTE = [(30, 118, 179), (255, 126, 13), (43, 159, 43), (213, 38, 39), (147, 102, 188),
8 | (139, 85, 74), (226, 118, 193), (126, 126, 126), (187, 188, 33), (22, 189, 206)]
9 | BLACK = (0, 0, 0)
10 | GRAY_75 = (63, 63, 63)
11 | GRAY_50 = (127, 127, 127)
12 | GRAY_25 = (191, 191, 191)
13 | WHITE = (255, 255, 255)
14 | FULL_CROP = (0., 0., 1., 1.)
15 |
16 |
17 | def intersection_area(rectangle1: BoundingBox, rectangle2: BoundingBox) -> float:
18 | """
19 | Give intersection area of two rectangles.
20 | @param rectangle1: (x0, y0, w, h) of first rectangle
21 | @param rectangle2: (x0, y0, w, h) of second rectangle
22 | """
23 | rectangle1 = rectangle1[0], rectangle1[1], rectangle1[0] + rectangle1[2], rectangle1[1] + rectangle1[3]
24 | rectangle2 = rectangle2[0], rectangle2[1], rectangle2[0] + rectangle2[2], rectangle2[1] + rectangle2[3]
25 | x_overlap = max(0., min(rectangle1[2], rectangle2[2]) - max(rectangle1[0], rectangle2[0]))
26 | y_overlap = max(0., min(rectangle1[3], rectangle2[3]) - max(rectangle1[1], rectangle2[1]))
27 | return x_overlap * y_overlap
28 |
29 |
30 | def horizontally_flip_bbox(bbox: BoundingBox) -> BoundingBox:
31 | return 1 - (bbox[0] + bbox[2]), bbox[1], bbox[2], bbox[3]
32 |
33 |
34 | def absolute_bbox(relative_bbox: BoundingBox, width: int, height: int) -> Tuple[int, int, int, int]:
35 | bbox = relative_bbox
36 | bbox = bbox[0] * width, bbox[1] * height, (bbox[0] + bbox[2]) * width, (bbox[1] + bbox[3]) * height
37 | return int(bbox[0]), int(bbox[1]), int(bbox[2]), int(bbox[3])
38 |
39 |
40 | def pad_list(list_: List, pad_element: Any, pad_to_length: int) -> List:
41 | return list_ + [pad_element for _ in range(pad_to_length - len(list_))]
42 |
43 |
44 | def rescale_annotations(annotations: List[Annotation], crop_coordinates: BoundingBox, flip: bool) -> \
45 | List[Annotation]:
46 | def clamp(x: float):
47 | return max(min(x, 1.), 0.)
48 |
49 | def rescale_bbox(bbox: BoundingBox) -> BoundingBox:
50 | x0 = clamp((bbox[0] - crop_coordinates[0]) / crop_coordinates[2])
51 | y0 = clamp((bbox[1] - crop_coordinates[1]) / crop_coordinates[3])
52 | w = min(bbox[2] / crop_coordinates[2], 1 - x0)
53 | h = min(bbox[3] / crop_coordinates[3], 1 - y0)
54 | if flip:
55 | x0 = 1 - (x0 + w)
56 | return x0, y0, w, h
57 |
58 | return [a._replace(bbox=rescale_bbox(a.bbox)) for a in annotations]
59 |
60 |
61 | def filter_annotations(annotations: List[Annotation], crop_coordinates: BoundingBox) -> List:
62 | return [a for a in annotations if intersection_area(a.bbox, crop_coordinates) > 0.0]
63 |
64 |
65 | def additional_parameters_string(annotation: Annotation, short: bool = True) -> str:
66 | sl = slice(1) if short else slice(None)
67 | string = ''
68 | if not (annotation.is_group_of or annotation.is_occluded or annotation.is_depiction or annotation.is_inside):
69 | return string
70 | if annotation.is_group_of:
71 | string += 'group'[sl] + ','
72 | if annotation.is_occluded:
73 | string += 'occluded'[sl] + ','
74 | if annotation.is_depiction:
75 | string += 'depiction'[sl] + ','
76 | if annotation.is_inside:
77 | string += 'inside'[sl]
78 | return '(' + string.strip(",") + ')'
79 |
80 |
81 | def get_plot_font_size(font_size: Optional[int], figure_size: Tuple[int, int]) -> int:
82 | if font_size is None:
83 | font_size = 10
84 | if max(figure_size) >= 256:
85 | font_size = 12
86 | if max(figure_size) >= 512:
87 | font_size = 15
88 | return font_size
89 |
90 |
91 | def get_circle_size(figure_size: Tuple[int, int]) -> int:
92 | circle_size = 2
93 | if max(figure_size) >= 256:
94 | circle_size = 3
95 | if max(figure_size) >= 512:
96 | circle_size = 4
97 | return circle_size
98 |
99 |
100 | def load_object_from_string(object_string: str) -> Any:
101 | """
102 | Source: https://stackoverflow.com/a/10773699
103 | """
104 | module_name, class_name = object_string.rsplit(".", 1)
105 | return getattr(importlib.import_module(module_name), class_name)
106 |
--------------------------------------------------------------------------------
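These helpers all operate on relative (x0, y0, w, h) bounding boxes in [0, 1]. A small worked example of the geometry functions:

    # Sketch only: worked example of the relative-bbox helpers defined above.
    from taming.data.conditional_builder.utils import intersection_area, absolute_bbox, pad_list

    box_a = (0.1, 0.1, 0.5, 0.5)            # x0, y0, w, h in relative coordinates
    box_b = (0.3, 0.3, 0.5, 0.5)
    print(intersection_area(box_a, box_b))  # 0.3 * 0.3 overlap = ~0.09

    # Relative bbox -> pixel corners (x0, y0, x1, y1) on a 256x256 canvas.
    print(absolute_bbox((0.25, 0.25, 0.5, 0.5), width=256, height=256))  # (64, 64, 192, 192)

    # Pad a descriptor list up to a fixed length, as done for object tokens.
    print(pad_list([1, 2, 3], pad_element=0, pad_to_length=5))           # [1, 2, 3, 0, 0]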
/taming_transformers/taming/data/custom.py:
--------------------------------------------------------------------------------
1 | import os
2 | import numpy as np
3 | import albumentations
4 | from torch.utils.data import Dataset
5 |
6 | from taming.data.base import ImagePaths, NumpyPaths, ConcatDatasetWithIndex
7 |
8 |
9 | class CustomBase(Dataset):
10 | def __init__(self, *args, **kwargs):
11 | super().__init__()
12 | self.data = None
13 |
14 | def __len__(self):
15 | return len(self.data)
16 |
17 | def __getitem__(self, i):
18 | example = self.data[i]
19 | return example
20 |
21 |
22 |
23 | class CustomTrain(CustomBase):
24 | def __init__(self, size, training_images_list_file):
25 | super().__init__()
26 | with open(training_images_list_file, "r") as f:
27 | paths = f.read().splitlines()
28 | self.data = ImagePaths(paths=paths, size=size, random_crop=False)
29 |
30 |
31 | class CustomTest(CustomBase):
32 | def __init__(self, size, test_images_list_file):
33 | super().__init__()
34 | with open(test_images_list_file, "r") as f:
35 | paths = f.read().splitlines()
36 | self.data = ImagePaths(paths=paths, size=size, random_crop=False)
37 |
38 |
39 |
--------------------------------------------------------------------------------
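CustomTrain / CustomTest wrap ImagePaths with a plain text file listing one image path per line, which is how custom VQGAN training data is normally declared. A sketch with a hypothetical list file:

    # Sketch only: "train_images.txt" is a hypothetical file with one image path per line.
    from torch.utils.data import DataLoader
    from taming.data.custom import CustomTrain

    dataset = CustomTrain(size=256, training_images_list_file="train_images.txt")
    loader = DataLoader(dataset, batch_size=4, shuffle=True, num_workers=2)

    batch = next(iter(loader))
    print(batch["image"].shape)  # (4, 256, 256, 3), values in [-1, 1]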
/taming_transformers/taming/data/helper_types.py:
--------------------------------------------------------------------------------
1 | from typing import Dict, Tuple, Optional, NamedTuple, Union
2 | from PIL.Image import Image as pil_image
3 | from torch import Tensor
4 |
5 | try:
6 | from typing import Literal
7 | except ImportError:
8 | from typing_extensions import Literal
9 |
10 | Image = Union[Tensor, pil_image]
11 | BoundingBox = Tuple[float, float, float, float] # x0, y0, w, h
12 | CropMethodType = Literal['none', 'random', 'center', 'random-2d']
13 | SplitType = Literal['train', 'validation', 'test']
14 |
15 |
16 | class ImageDescription(NamedTuple):
17 | id: int
18 | file_name: str
19 | original_size: Tuple[int, int] # w, h
20 | url: Optional[str] = None
21 | license: Optional[int] = None
22 | coco_url: Optional[str] = None
23 | date_captured: Optional[str] = None
24 | flickr_url: Optional[str] = None
25 | flickr_id: Optional[str] = None
26 | coco_id: Optional[str] = None
27 |
28 |
29 | class Category(NamedTuple):
30 | id: str
31 | super_category: Optional[str]
32 | name: str
33 |
34 |
35 | class Annotation(NamedTuple):
36 | area: float
37 | image_id: str
38 | bbox: BoundingBox
39 | category_no: int
40 | category_id: str
41 | id: Optional[int] = None
42 | source: Optional[str] = None
43 | confidence: Optional[float] = None
44 | is_group_of: Optional[bool] = None
45 | is_truncated: Optional[bool] = None
46 | is_occluded: Optional[bool] = None
47 | is_depiction: Optional[bool] = None
48 | is_inside: Optional[bool] = None
49 | segmentation: Optional[Dict] = None
50 |
--------------------------------------------------------------------------------
/taming_transformers/taming/data/sflckr.py:
--------------------------------------------------------------------------------
1 | import os
2 | import numpy as np
3 | import cv2
4 | import albumentations
5 | from PIL import Image
6 | from torch.utils.data import Dataset
7 |
8 |
9 | class SegmentationBase(Dataset):
10 | def __init__(self,
11 | data_csv, data_root, segmentation_root,
12 | size=None, random_crop=False, interpolation="bicubic",
13 | n_labels=182, shift_segmentation=False,
14 | ):
15 | self.n_labels = n_labels
16 | self.shift_segmentation = shift_segmentation
17 | self.data_csv = data_csv
18 | self.data_root = data_root
19 | self.segmentation_root = segmentation_root
20 | with open(self.data_csv, "r") as f:
21 | self.image_paths = f.read().splitlines()
22 | self._length = len(self.image_paths)
23 | self.labels = {
24 | "relative_file_path_": [l for l in self.image_paths],
25 | "file_path_": [os.path.join(self.data_root, l)
26 | for l in self.image_paths],
27 | "segmentation_path_": [os.path.join(self.segmentation_root, l.replace(".jpg", ".png"))
28 | for l in self.image_paths]
29 | }
30 |
31 | size = None if size is not None and size<=0 else size
32 | self.size = size
33 | if self.size is not None:
34 | self.interpolation = interpolation
35 | self.interpolation = {
36 | "nearest": cv2.INTER_NEAREST,
37 | "bilinear": cv2.INTER_LINEAR,
38 | "bicubic": cv2.INTER_CUBIC,
39 | "area": cv2.INTER_AREA,
40 | "lanczos": cv2.INTER_LANCZOS4}[self.interpolation]
41 | self.image_rescaler = albumentations.SmallestMaxSize(max_size=self.size,
42 | interpolation=self.interpolation)
43 | self.segmentation_rescaler = albumentations.SmallestMaxSize(max_size=self.size,
44 | interpolation=cv2.INTER_NEAREST)
45 | self.center_crop = not random_crop
46 | if self.center_crop:
47 | self.cropper = albumentations.CenterCrop(height=self.size, width=self.size)
48 | else:
49 | self.cropper = albumentations.RandomCrop(height=self.size, width=self.size)
50 | self.preprocessor = self.cropper
51 |
52 | def __len__(self):
53 | return self._length
54 |
55 | def __getitem__(self, i):
56 | example = dict((k, self.labels[k][i]) for k in self.labels)
57 | image = Image.open(example["file_path_"])
58 | if not image.mode == "RGB":
59 | image = image.convert("RGB")
60 | image = np.array(image).astype(np.uint8)
61 | if self.size is not None:
62 | image = self.image_rescaler(image=image)["image"]
63 | segmentation = Image.open(example["segmentation_path_"])
64 | assert segmentation.mode == "L", segmentation.mode
65 | segmentation = np.array(segmentation).astype(np.uint8)
66 | if self.shift_segmentation:
67 | # used to support segmentations containing unlabeled==255 label
68 | segmentation = segmentation+1
69 | if self.size is not None:
70 | segmentation = self.segmentation_rescaler(image=segmentation)["image"]
71 | if self.size is not None:
72 | processed = self.preprocessor(image=image,
73 | mask=segmentation
74 | )
75 | else:
76 | processed = {"image": image,
77 | "mask": segmentation
78 | }
79 | example["image"] = (processed["image"]/127.5 - 1.0).astype(np.float32)
80 | segmentation = processed["mask"]
81 | onehot = np.eye(self.n_labels)[segmentation]
82 | example["segmentation"] = onehot
83 | return example
84 |
85 |
86 | class Examples(SegmentationBase):
87 | def __init__(self, size=None, random_crop=False, interpolation="bicubic"):
88 | super().__init__(data_csv="data/sflckr_examples.txt",
89 | data_root="data/sflckr_images",
90 | segmentation_root="data/sflckr_segmentations",
91 | size=size, random_crop=random_crop, interpolation=interpolation)
92 |
--------------------------------------------------------------------------------
/taming_transformers/taming/lr_scheduler.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 |
3 |
4 | class LambdaWarmUpCosineScheduler:
5 | """
6 | note: use with a base_lr of 1.0
7 | """
8 | def __init__(self, warm_up_steps, lr_min, lr_max, lr_start, max_decay_steps, verbosity_interval=0):
9 | self.lr_warm_up_steps = warm_up_steps
10 | self.lr_start = lr_start
11 | self.lr_min = lr_min
12 | self.lr_max = lr_max
13 | self.lr_max_decay_steps = max_decay_steps
14 | self.last_lr = 0.
15 | self.verbosity_interval = verbosity_interval
16 |
17 | def schedule(self, n):
18 | if self.verbosity_interval > 0:
19 | if n % self.verbosity_interval == 0: print(f"current step: {n}, recent lr-multiplier: {self.last_lr}")
20 | if n < self.lr_warm_up_steps:
21 | lr = (self.lr_max - self.lr_start) / self.lr_warm_up_steps * n + self.lr_start
22 | self.last_lr = lr
23 | return lr
24 | else:
25 | t = (n - self.lr_warm_up_steps) / (self.lr_max_decay_steps - self.lr_warm_up_steps)
26 | t = min(t, 1.0)
27 | lr = self.lr_min + 0.5 * (self.lr_max - self.lr_min) * (
28 | 1 + np.cos(t * np.pi))
29 | self.last_lr = lr
30 | return lr
31 |
32 | def __call__(self, n):
33 | return self.schedule(n)
34 |
35 |
--------------------------------------------------------------------------------
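LambdaWarmUpCosineScheduler returns a multiplier, so (per its docstring) it is meant to wrap a base value of 1.0, for example via torch.optim.lr_scheduler.LambdaLR; in ckpt/model.yaml the same class drives the Gumbel softmax temperature. A sketch using the parameter values from that config, with a placeholder model:

    # Sketch only: wire the scheduler into LambdaLR with the values from
    # ckpt/model.yaml's temperature_scheduler_config. torch.nn.Linear is a placeholder.
    import torch
    from taming.lr_scheduler import LambdaWarmUpCosineScheduler

    schedule = LambdaWarmUpCosineScheduler(
        warm_up_steps=0, lr_min=1.0e-06, lr_max=0.9, lr_start=0.9, max_decay_steps=1000001)

    model = torch.nn.Linear(8, 8)
    optimizer = torch.optim.Adam(model.parameters(), lr=1.0)  # base lr of 1.0, per the docstring
    lr_scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=schedule)

    for step in range(3):
        optimizer.step()
        lr_scheduler.step()
        print(step, optimizer.param_groups[0]["lr"])  # ~0.9 early on; decays to 1e-6 over ~1e6 steps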
/taming_transformers/taming/models/__pycache__/vqgan.cpython-310.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/DCDmllm/HealthGPT/c044a13254b76c5eec8c2e6c55e3324318c27940/taming_transformers/taming/models/__pycache__/vqgan.cpython-310.pyc
--------------------------------------------------------------------------------
/taming_transformers/taming/models/dummy_cond_stage.py:
--------------------------------------------------------------------------------
1 | from torch import Tensor
2 |
3 |
4 | class DummyCondStage:
5 | def __init__(self, conditional_key):
6 | self.conditional_key = conditional_key
7 | self.train = None
8 |
9 | def eval(self):
10 | return self
11 |
12 | @staticmethod
13 | def encode(c: Tensor):
14 | return c, None, (None, None, c)
15 |
16 | @staticmethod
17 | def decode(c: Tensor):
18 | return c
19 |
20 | @staticmethod
21 | def to_rgb(c: Tensor):
22 | return c
23 |
--------------------------------------------------------------------------------
/taming_transformers/taming/modules/__pycache__/util.cpython-310.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/DCDmllm/HealthGPT/c044a13254b76c5eec8c2e6c55e3324318c27940/taming_transformers/taming/modules/__pycache__/util.cpython-310.pyc
--------------------------------------------------------------------------------
/taming_transformers/taming/modules/diffusionmodules/__pycache__/model.cpython-310.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/DCDmllm/HealthGPT/c044a13254b76c5eec8c2e6c55e3324318c27940/taming_transformers/taming/modules/diffusionmodules/__pycache__/model.cpython-310.pyc
--------------------------------------------------------------------------------
/taming_transformers/taming/modules/discriminator/__pycache__/model.cpython-310.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/DCDmllm/HealthGPT/c044a13254b76c5eec8c2e6c55e3324318c27940/taming_transformers/taming/modules/discriminator/__pycache__/model.cpython-310.pyc
--------------------------------------------------------------------------------
/taming_transformers/taming/modules/discriminator/model.py:
--------------------------------------------------------------------------------
1 | import functools
2 | import torch.nn as nn
3 |
4 |
5 | from taming.modules.util import ActNorm
6 |
7 |
8 | def weights_init(m):
9 | classname = m.__class__.__name__
10 | if classname.find('Conv') != -1:
11 | nn.init.normal_(m.weight.data, 0.0, 0.02)
12 | elif classname.find('BatchNorm') != -1:
13 | nn.init.normal_(m.weight.data, 1.0, 0.02)
14 | nn.init.constant_(m.bias.data, 0)
15 |
16 |
17 | class NLayerDiscriminator(nn.Module):
18 | """Defines a PatchGAN discriminator as in Pix2Pix
19 | --> see https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix/blob/master/models/networks.py
20 | """
21 | def __init__(self, input_nc=3, ndf=64, n_layers=3, use_actnorm=False):
22 | """Construct a PatchGAN discriminator
23 | Parameters:
24 | input_nc (int) -- the number of channels in input images
25 | ndf (int) -- the number of filters in the last conv layer
26 | n_layers (int) -- the number of conv layers in the discriminator
27 | norm_layer -- normalization layer
28 | """
29 | super(NLayerDiscriminator, self).__init__()
30 | if not use_actnorm:
31 | norm_layer = nn.BatchNorm2d
32 | else:
33 | norm_layer = ActNorm
34 | if type(norm_layer) == functools.partial: # no need to use bias as BatchNorm2d has affine parameters
35 | use_bias = norm_layer.func != nn.BatchNorm2d
36 | else:
37 | use_bias = norm_layer != nn.BatchNorm2d
38 |
39 | kw = 4
40 | padw = 1
41 | sequence = [nn.Conv2d(input_nc, ndf, kernel_size=kw, stride=2, padding=padw), nn.LeakyReLU(0.2, True)]
42 | nf_mult = 1
43 | nf_mult_prev = 1
44 | for n in range(1, n_layers): # gradually increase the number of filters
45 | nf_mult_prev = nf_mult
46 | nf_mult = min(2 ** n, 8)
47 | sequence += [
48 | nn.Conv2d(ndf * nf_mult_prev, ndf * nf_mult, kernel_size=kw, stride=2, padding=padw, bias=use_bias),
49 | norm_layer(ndf * nf_mult),
50 | nn.LeakyReLU(0.2, True)
51 | ]
52 |
53 | nf_mult_prev = nf_mult
54 | nf_mult = min(2 ** n_layers, 8)
55 | sequence += [
56 | nn.Conv2d(ndf * nf_mult_prev, ndf * nf_mult, kernel_size=kw, stride=1, padding=padw, bias=use_bias),
57 | norm_layer(ndf * nf_mult),
58 | nn.LeakyReLU(0.2, True)
59 | ]
60 |
61 | sequence += [
62 | nn.Conv2d(ndf * nf_mult, 1, kernel_size=kw, stride=1, padding=padw)] # output 1 channel prediction map
63 | self.main = nn.Sequential(*sequence)
64 |
65 | def forward(self, input):
66 | """Standard forward."""
67 | return self.main(input)
68 |
--------------------------------------------------------------------------------
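NLayerDiscriminator outputs a grid of per-patch logits rather than a single real/fake score. A quick shape check on random data (no training, just a forward pass); the weights_init usage mirrors how the discriminator is typically set up in taming's VQ-perceptual loss:

    # Sketch only: forward a random batch through the PatchGAN discriminator defined above.
    import torch
    from taming.modules.discriminator.model import NLayerDiscriminator, weights_init

    disc = NLayerDiscriminator(input_nc=3, ndf=64, n_layers=3, use_actnorm=False)
    disc.apply(weights_init)        # DCGAN-style init for the conv / norm layers

    x = torch.randn(1, 3, 256, 256)
    logits = disc(x)
    print(logits.shape)             # torch.Size([1, 1, 30, 30]): one logit per image patch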
/taming_transformers/taming/modules/losses/__init__.py:
--------------------------------------------------------------------------------
1 | from taming.modules.losses.vqperceptual import DummyLoss
2 |
3 |
--------------------------------------------------------------------------------
/taming_transformers/taming/modules/losses/__pycache__/__init__.cpython-310.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/DCDmllm/HealthGPT/c044a13254b76c5eec8c2e6c55e3324318c27940/taming_transformers/taming/modules/losses/__pycache__/__init__.cpython-310.pyc
--------------------------------------------------------------------------------
/taming_transformers/taming/modules/losses/__pycache__/lpips.cpython-310.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/DCDmllm/HealthGPT/c044a13254b76c5eec8c2e6c55e3324318c27940/taming_transformers/taming/modules/losses/__pycache__/lpips.cpython-310.pyc
--------------------------------------------------------------------------------
/taming_transformers/taming/modules/losses/__pycache__/vqperceptual.cpython-310.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/DCDmllm/HealthGPT/c044a13254b76c5eec8c2e6c55e3324318c27940/taming_transformers/taming/modules/losses/__pycache__/vqperceptual.cpython-310.pyc
--------------------------------------------------------------------------------
/taming_transformers/taming/modules/losses/segmentation.py:
--------------------------------------------------------------------------------
1 | import torch.nn as nn
2 | import torch.nn.functional as F
3 |
4 |
5 | class BCELoss(nn.Module):
6 | def forward(self, prediction, target):
7 | loss = F.binary_cross_entropy_with_logits(prediction,target)
8 | return loss, {}
9 |
10 |
11 | class BCELossWithQuant(nn.Module):
12 | def __init__(self, codebook_weight=1.):
13 | super().__init__()
14 | self.codebook_weight = codebook_weight
15 |
16 | def forward(self, qloss, target, prediction, split):
17 | bce_loss = F.binary_cross_entropy_with_logits(prediction,target)
18 | loss = bce_loss + self.codebook_weight*qloss
19 | return loss, {"{}/total_loss".format(split): loss.clone().detach().mean(),
20 | "{}/bce_loss".format(split): bce_loss.detach().mean(),
21 | "{}/quant_loss".format(split): qloss.detach().mean()
22 | }
23 |
--------------------------------------------------------------------------------
/taming_transformers/taming/modules/misc/coord.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 | class CoordStage(object):
4 | def __init__(self, n_embed, down_factor):
5 | self.n_embed = n_embed
6 | self.down_factor = down_factor
7 |
8 | def eval(self):
9 | return self
10 |
11 | def encode(self, c):
12 | """fake vqmodel interface"""
13 | assert 0.0 <= c.min() and c.max() <= 1.0
14 | b,ch,h,w = c.shape
15 | assert ch == 1
16 |
17 | c = torch.nn.functional.interpolate(c, scale_factor=1/self.down_factor,
18 | mode="area")
19 | c = c.clamp(0.0, 1.0)
20 | c = self.n_embed*c
21 | c_quant = c.round()
22 | c_ind = c_quant.to(dtype=torch.long)
23 |
24 | info = None, None, c_ind
25 | return c_quant, None, info
26 |
27 | def decode(self, c):
28 | c = c/self.n_embed
29 | c = torch.nn.functional.interpolate(c, scale_factor=self.down_factor,
30 | mode="nearest")
31 | return c
32 |
--------------------------------------------------------------------------------
/taming_transformers/taming/modules/util.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 |
4 |
5 | def count_params(model):
6 | total_params = sum(p.numel() for p in model.parameters())
7 | return total_params
8 |
9 |
10 | class ActNorm(nn.Module):
11 | def __init__(self, num_features, logdet=False, affine=True,
12 | allow_reverse_init=False):
13 | assert affine
14 | super().__init__()
15 | self.logdet = logdet
16 | self.loc = nn.Parameter(torch.zeros(1, num_features, 1, 1))
17 | self.scale = nn.Parameter(torch.ones(1, num_features, 1, 1))
18 | self.allow_reverse_init = allow_reverse_init
19 |
20 | self.register_buffer('initialized', torch.tensor(0, dtype=torch.uint8))
21 |
22 | def initialize(self, input):
23 | with torch.no_grad():
24 | flatten = input.permute(1, 0, 2, 3).contiguous().view(input.shape[1], -1)
25 | mean = (
26 | flatten.mean(1)
27 | .unsqueeze(1)
28 | .unsqueeze(2)
29 | .unsqueeze(3)
30 | .permute(1, 0, 2, 3)
31 | )
32 | std = (
33 | flatten.std(1)
34 | .unsqueeze(1)
35 | .unsqueeze(2)
36 | .unsqueeze(3)
37 | .permute(1, 0, 2, 3)
38 | )
39 |
40 | self.loc.data.copy_(-mean)
41 | self.scale.data.copy_(1 / (std + 1e-6))
42 |
43 | def forward(self, input, reverse=False):
44 | if reverse:
45 | return self.reverse(input)
46 | if len(input.shape) == 2:
47 | input = input[:,:,None,None]
48 | squeeze = True
49 | else:
50 | squeeze = False
51 |
52 | _, _, height, width = input.shape
53 |
54 | if self.training and self.initialized.item() == 0:
55 | self.initialize(input)
56 | self.initialized.fill_(1)
57 |
58 | h = self.scale * (input + self.loc)
59 |
60 | if squeeze:
61 | h = h.squeeze(-1).squeeze(-1)
62 |
63 | if self.logdet:
64 | log_abs = torch.log(torch.abs(self.scale))
65 | logdet = height*width*torch.sum(log_abs)
66 | logdet = logdet * torch.ones(input.shape[0]).to(input)
67 | return h, logdet
68 |
69 | return h
70 |
71 | def reverse(self, output):
72 | if self.training and self.initialized.item() == 0:
73 | if not self.allow_reverse_init:
74 | raise RuntimeError(
75 | "Initializing ActNorm in reverse direction is "
76 | "disabled by default. Use allow_reverse_init=True to enable."
77 | )
78 | else:
79 | self.initialize(output)
80 | self.initialized.fill_(1)
81 |
82 | if len(output.shape) == 2:
83 | output = output[:,:,None,None]
84 | squeeze = True
85 | else:
86 | squeeze = False
87 |
88 | h = output / self.scale - self.loc
89 |
90 | if squeeze:
91 | h = h.squeeze(-1).squeeze(-1)
92 | return h
93 |
94 |
95 | class AbstractEncoder(nn.Module):
96 | def __init__(self):
97 | super().__init__()
98 |
99 | def encode(self, *args, **kwargs):
100 | raise NotImplementedError
101 |
102 |
103 | class Labelator(AbstractEncoder):
104 | """Net2Net Interface for Class-Conditional Model"""
105 | def __init__(self, n_classes, quantize_interface=True):
106 | super().__init__()
107 | self.n_classes = n_classes
108 | self.quantize_interface = quantize_interface
109 |
110 | def encode(self, c):
111 | c = c[:,None]
112 | if self.quantize_interface:
113 | return c, None, [None, None, c.long()]
114 | return c
115 |
116 |
117 | class SOSProvider(AbstractEncoder):
118 | # for unconditional training
119 | def __init__(self, sos_token, quantize_interface=True):
120 | super().__init__()
121 | self.sos_token = sos_token
122 | self.quantize_interface = quantize_interface
123 |
124 | def encode(self, x):
125 | # get batch size from data and replicate sos_token
126 | c = torch.ones(x.shape[0], 1)*self.sos_token
127 | c = c.long().to(x.device)
128 | if self.quantize_interface:
129 | return c, None, [None, None, c]
130 | return c
131 |
--------------------------------------------------------------------------------
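ActNorm is a per-channel affine normalization whose shift and scale are initialized from the statistics of the first batch seen in training (data-dependent init), with an optional log-determinant for flow-style objectives and an exact inverse. A forward/reverse sketch:

    # Sketch only: demonstrate ActNorm's data-dependent init and its inverse.
    import torch
    from taming.modules.util import ActNorm, count_params

    norm = ActNorm(num_features=16, logdet=True)
    print(count_params(norm))                   # 32 learnable values: per-channel loc and scale

    x = torch.randn(4, 16, 8, 8)
    h, logdet = norm(x)                         # first call in train mode initializes loc/scale from x
    print(h.mean().item(), h.std().item())      # roughly 0 and 1 after initialization

    x_rec = norm(h, reverse=True)               # invert the affine transform
    print(torch.allclose(x, x_rec, atol=1e-4))  # True, up to numerical error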
/taming_transformers/taming/modules/vqvae/__pycache__/quantize.cpython-310.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/DCDmllm/HealthGPT/c044a13254b76c5eec8c2e6c55e3324318c27940/taming_transformers/taming/modules/vqvae/__pycache__/quantize.cpython-310.pyc
--------------------------------------------------------------------------------