├── .dockerignore
├── .editorconfig
├── .gitattributes
├── .github
└── ISSUE_TEMPLATE
│ ├── 1-usage.yaml
│ ├── 2-feature-request.yaml
│ ├── 3-question.yaml
│ └── 4-discussion.yaml
├── .gitignore
├── LICENSE
├── README.md
├── cog.yaml
├── docs
├── Installation.md
├── all_float_paths.json
├── images
│ ├── black_img.jpg
│ ├── img_recon.gif
│ ├── teaser.png
│ ├── text_generation.png
│ ├── video_edit.gif
│ └── video_edit_2.gif
├── postprocess.md
└── prompts
│ ├── detailed_textbased_description.txt
│ ├── gpt4v_prompt_garment_sam.txt
│ ├── prompt_garment_editing.txt
│ ├── prompt_garment_part_inference.txt
│ └── smplified_image_description.txt
├── example_data
├── example_imgs
│ ├── 1aee14a8c7b4d56b4e8b6ddd575d1f561a72fdc75c43a4b6926f1655152193c6.png
│ ├── 1dde6afed43187fe927089a615e3f744724ef3defdf3f2ae4a6cede5ad71dcea.png
│ ├── 62bb809fc2dcd50409cb36163a0eb222f9aa1af0f256a3233b67b3ed4081dc71.png
│ ├── 6fe14e1f646513ee93714fbe8026a84c6a2897be4df2f3c936cb2be8dd2d1762.png
│ ├── 72b086429d2dfe2a8de6f4403a024b2bb17446021c9e8f9ebacfc7a990ac8434.png
│ ├── 80141ce740f489f1d2f57a03f32c7577a28b62a6ac790a0d9ed8a18d961c2918.png
│ ├── 8e3c458da20c290c216813ec07f1a2e8f9cfb4ee7e412a783a238ec353b346a0.png
│ ├── c2b582eb318455abaf8ed8e3126c1b423ade2704d810f7cd24428febda5632fa.png
│ ├── d77c6f5d4856831878eadb7fe3c8b180bfa9e9ad4a14936ac10a1697bb3c054f.png
│ └── e918651cc154a7570e47d8b8f6c0f0f93cfbb7d5129103a1bacd8299ba945f91.png
├── example_jsons
│ ├── example_edit_prompts.json
│ └── example_textgen_prompts.json
└── example_sewing_patterns
│ └── example_shirt
│ ├── design.yaml
│ └── valid_garment_upper_render_front.png
├── llava
├── __init__.py
├── close_utils.py
├── constants.py
├── conversation.py
├── eval
│ ├── eval_gpt_review.py
│ ├── eval_gpt_review_bench.py
│ ├── eval_gpt_review_visual.py
│ ├── eval_pope.py
│ ├── eval_science_qa.py
│ ├── eval_science_qa_gpt4.py
│ ├── eval_science_qa_gpt4_requery.py
│ ├── eval_textvqa.py
│ ├── generate_webpage_data_from_table.py
│ ├── m4c_evaluator.py
│ ├── model_qa.py
│ ├── model_vqa.py
│ ├── model_vqa_loader.py
│ ├── model_vqa_mmbench.py
│ ├── model_vqa_science.py
│ ├── qa_baseline_gpt35.py
│ ├── run_llava.py
│ ├── summarize_gpt_review.py
│ ├── table
│ │ ├── answer
│ │ │ ├── answer_alpaca-13b.jsonl
│ │ │ ├── answer_bard.jsonl
│ │ │ ├── answer_gpt35.jsonl
│ │ │ ├── answer_llama-13b.jsonl
│ │ │ └── answer_vicuna-13b.jsonl
│ │ ├── caps_boxes_coco2014_val_80.jsonl
│ │ ├── model.jsonl
│ │ ├── prompt.jsonl
│ │ ├── question.jsonl
│ │ ├── results
│ │ │ ├── test_sqa_llava_13b_v0.json
│ │ │ └── test_sqa_llava_lcs_558k_sqa_12e_vicuna_v1_3_13b.json
│ │ ├── review
│ │ │ ├── review_alpaca-13b_vicuna-13b.jsonl
│ │ │ ├── review_bard_vicuna-13b.jsonl
│ │ │ ├── review_gpt35_vicuna-13b.jsonl
│ │ │ └── review_llama-13b_vicuna-13b.jsonl
│ │ ├── reviewer.jsonl
│ │ └── rule.json
│ └── webpage
│ │ ├── figures
│ │ ├── alpaca.png
│ │ ├── bard.jpg
│ │ ├── chatgpt.svg
│ │ ├── llama.jpg
│ │ ├── swords_FILL0_wght300_GRAD0_opsz48.svg
│ │ └── vicuna.jpeg
│ │ ├── index.html
│ │ ├── script.js
│ │ └── styles.css
├── garment_inquire_utils.py
├── garment_lbs_utils.py
├── garment_utils_v2.py
├── garmentcodeRC_utils.py
├── garmentcode_utils.py
├── json_fixer.py
├── lisa_utils.py
├── mm_utils.py
├── model
│ ├── __init__.py
│ ├── apply_delta.py
│ ├── builder.py
│ ├── consolidate.py
│ ├── language_model
│ │ ├── llava_garment_float50.py
│ │ ├── llava_llama.py
│ │ ├── llava_mistral.py
│ │ └── llava_mpt.py
│ ├── llava_arch.py
│ ├── make_delta.py
│ ├── multimodal_encoder
│ │ ├── builder.py
│ │ └── clip_encoder.py
│ ├── multimodal_projector
│ │ └── builder.py
│ ├── smplx
│ │ ├── body_models.py
│ │ ├── joint_names.py
│ │ ├── lbs.py
│ │ ├── smplx_utils.py
│ │ ├── utils.py
│ │ ├── vertex_ids.py
│ │ └── vertex_joint_selector.py
│ └── utils.py
├── prompts_utils.py
├── pytorch3d_render_utils.py
├── serve
│ ├── __init__.py
│ ├── cli.py
│ ├── controller.py
│ ├── examples
│ │ ├── extreme_ironing.jpg
│ │ └── waterview.jpg
│ ├── gradio_web_server.py
│ ├── model_worker.py
│ ├── register_worker.py
│ ├── sglang_worker.py
│ └── test_message.py
├── train
│ ├── train_garmentcode_outfit.py
│ └── train_mem_garmentcode_outfit.py
└── utils.py
├── pyproject.toml
├── run_garmentcode_sim.py
└── scripts
├── evaluate_garment_v2_demo_edit_1float.py
├── evaluate_garment_v2_eva_edit_1float.py
├── evaluate_garment_v2_imggen_1float.py
├── evaluate_garment_v2_textgen_1float.py
├── evaluate_garment_v2_textgen_fromimg_1float.py
├── postprocess
├── grounding_sam.py
└── postprocess.py
├── v1_5
├── evaluate_garment_v2_demo_edit.sh
├── evaluate_garment_v2_eva_edit.sh
├── evaluate_garment_v2_imggen_2step.sh
├── evaluate_garment_v2_textgen.sh
├── evaluate_garment_v2_textgen_fromimg.sh
└── finetune_task_lora_garmentcode_outfit.sh
├── zero2.json
├── zero3.json
└── zero3_offload.json
/.dockerignore:
--------------------------------------------------------------------------------
1 | # The .dockerignore file excludes files from the container build process.
2 | #
3 | # https://docs.docker.com/engine/reference/builder/#dockerignore-file
4 |
5 | # Exclude Git files
6 | .git
7 | .github
8 | .gitignore
9 |
10 | # Exclude Python cache files
11 | __pycache__
12 | .mypy_cache
13 | .pytest_cache
14 | .ruff_cache
15 |
16 | # Exclude Python virtual environment
17 | /venv
18 |
19 | # Exclude some weights
20 | /openai
21 | /liuhaotian
22 |
--------------------------------------------------------------------------------
/.editorconfig:
--------------------------------------------------------------------------------
1 | root = true
2 |
3 | # Unix-style newlines with a newline ending every file
4 | [*]
5 | end_of_line = lf
6 | insert_final_newline = true
7 | trim_trailing_whitespace = true
8 | charset = utf-8
9 |
10 | # 4 space indentation
11 | [*.{py,json}]
12 | indent_style = space
13 | indent_size = 4
14 |
15 | # 2 space indentation
16 | [*.{md,sh,yaml,yml}]
17 | indent_style = space
18 | indent_size = 2
--------------------------------------------------------------------------------
/.gitattributes:
--------------------------------------------------------------------------------
1 | # https://git-scm.com/docs/gitattributes
2 |
3 | # Set the default behavior, in case people don't have core.autocrlf set.
4 | # https://git-scm.com/docs/gitattributes#_end_of_line_conversion
5 | * text=auto
6 |
7 | # common python attributes, taken from https://github.com/alexkaratarakis/gitattributes/blob/710900479a2bedeec7003d381719521ffbb18bf8/Python.gitattributes
8 | # Source files
9 | # ============
10 | *.pxd text diff=python
11 | *.py text diff=python
12 | *.py3 text diff=python
13 | *.pyw text diff=python
14 | *.pyx text diff=python
15 | *.pyz text diff=python
16 | *.pyi text diff=python
17 |
18 | # Binary files
19 | # ============
20 | *.db binary
21 | *.p binary
22 | *.pkl binary
23 | *.pickle binary
24 | *.pyc binary export-ignore
25 | *.pyo binary export-ignore
26 | *.pyd binary
27 |
28 | # Jupyter notebook
29 | *.ipynb text eol=lf
30 |
--------------------------------------------------------------------------------
/.github/ISSUE_TEMPLATE/1-usage.yaml:
--------------------------------------------------------------------------------
1 | name: Usage issues
2 | description: Report issues in usage.
3 | title: "[Usage] "
4 | body:
5 | - type: markdown
6 | attributes:
7 | value: |
8 | Thanks for taking the time to fill out this form. Please give as detailed a description as possible for us to better assist with the issue :)
9 | - type: textarea
10 | id: what-happened
11 | attributes:
12 | label: Describe the issue
13 | description: Please give as detailed a description as possible for us to better assist with the issue. Please paste the **FULL** error log here, so that we can better understand the issue. Wrap the log with ``` for better readability in GitHub.
14 | placeholder: Issue
15 | value: |
16 | Issue:
17 |
18 | Command:
19 | ```
20 | PASTE THE COMMANDS HERE.
21 | ```
22 |
23 | Log:
24 | ```
25 | PASTE THE LOGS HERE.
26 | ```
27 |
28 | Screenshots:
29 | You may attach screenshots if it better explains the issue.
30 | validations:
31 | required: true
32 |
--------------------------------------------------------------------------------
/.github/ISSUE_TEMPLATE/2-feature-request.yaml:
--------------------------------------------------------------------------------
1 | name: Feature Request
2 | description: Request for a new feature
3 | title: "[Feature request] "
4 | body:
5 | - type: markdown
6 | attributes:
7 | value: |
8 | Thanks for your interest in our work. Please share your thoughts on the new feature below.
9 | - type: textarea
10 | id: feature
11 | attributes:
12 | label: feature
13 | placeholder: Start your thoughts here...
--------------------------------------------------------------------------------
/.github/ISSUE_TEMPLATE/3-question.yaml:
--------------------------------------------------------------------------------
1 | name: Questions
2 | description: General questions about the work
3 | title: "[Question] "
4 | body:
5 | - type: markdown
6 | attributes:
7 | value: |
8 | Thanks for your interest in our work. For this type of question, it may be more suitable to post in the [discussions](https://github.com/haotian-liu/LLaVA/discussions) section. If you believe an issue would be a better fit for your request, please continue your post below :)
9 | - type: textarea
10 | id: question
11 | attributes:
12 | label: Question
13 | placeholder: Start question here...
--------------------------------------------------------------------------------
/.github/ISSUE_TEMPLATE/4-discussion.yaml:
--------------------------------------------------------------------------------
1 | name: Discussions
2 | description: General discussions about the work
3 | title: "[Discussion] "
4 | body:
5 | - type: markdown
6 | attributes:
7 | value: |
8 | Thanks for your interest in our work. For this type of question, it may be more suitable to post in the [discussions](https://github.com/haotian-liu/LLaVA/discussions) section. If you believe an issue would be a better fit for your request, please continue your post below :)
9 | - type: textarea
10 | id: discussion
11 | attributes:
12 | label: Discussion
13 | placeholder: Start discussion here...
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | # Python
2 | __pycache__
3 | *.pyc
4 | *.egg-info
5 | dist
6 |
7 | # Log
8 | *.log
9 | *.log.*
10 | *.jsonl
11 |
12 | # Data
13 | !**/alpaca-data-conversation.json
14 |
15 | # Editor
16 | .idea
17 | *.swp
18 |
19 | # Other
20 | .DS_Store
21 | wandb
22 | output
23 |
24 | checkpoints
25 | ckpts*
26 |
27 | .ipynb_checkpoints
28 | *.ipynb
29 |
30 | # DevContainer
31 | !.devcontainer/*
32 |
33 | # Demo
34 | serve_images/
35 |
36 | runs
37 | assets
38 | playground
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 | # ChatGarment: Garment Estimation, Generation and Editing via Large Language Models
5 |
6 |
7 |
8 |
9 |
10 |
11 |
12 | This is the implementation of ChatGarment. For more details, please check our
13 | [[Project Page](https://chatgarment.github.io/)].
14 |
15 | ChatGarment utilizes large vision-language models (VLMs) to automate the estimation, generation, and editing of 3D garments from images or text descriptions.
16 |
17 |
18 | ## Applications
19 |
20 | |  |  |
21 | | :--------------------: | :----------: |
22 | | Image-based Reconstruction | Text-based Generation |
23 | |  |  |
24 | | Text-based Editing | Text-based Editing |
25 |
26 |
27 | ## Relevant Repositories
28 | 1. [**GarmentCodeRC**](https://github.com/biansy000/GarmentCodeRC): A refined version of the original [GarmentCode](https://github.com/maria-korosteleva/GarmentCode), used by ChatGarment for garment generation.
29 |
30 | 2. [**ContourCraft-CG**](https://github.com/biansy000/ContourCraft-CG): A refined version of the original [ContourCraft](https://github.com/Dolorousrtur/ContourCraft), used by ChatGarment for garment simulation.
31 |
32 | 3. [**ChatGarmentDataset**](https://huggingface.co/datasets/sy000/ChatGarmentDataset): A Hugging Face dataset with training and inference data used in our paper.
33 |
34 |
35 | ## Installation
36 | The installation instructions are provided in ``docs/Installation.md``.
37 |
38 | ## Model Training
39 | The training data is available in [ChatGarmentDataset](https://huggingface.co/datasets/sy000/ChatGarmentDataset).
40 | ```Shell
41 | ./scripts/v1_5/finetune_task_lora_garmentcode_outfit.sh
42 | ```
43 |
44 | ## Model Inference
45 |
46 | #### 1. Image-based Reconstruction (CoT)
47 | ```Shell
48 | # Run image-based reconstruction with CoT for images in example_data/example_imgs/
49 | # Detailed steps of the script:
50 | # 1. Accepts an input image.
51 | # 2. Uses the ChatGarment model to generate text prompts based on the image.
52 | # 3. Sends the ChatGarment-generated text & input image to the ChatGarment model again.
53 | # 4. Outputs the final GarmentCode sewing patterns.
54 | ./scripts/v1_5/evaluate_garment_v2_imggen_2step.sh example_data/example_imgs/
55 | ```
56 |
57 |
58 | #### 2. Text-based Generation
59 | ```Shell
60 | # Run text-based generation for the prompts given in the input JSON file
61 | # Detailed steps of the script:
62 | # 1. Accepts an input JSON file.
63 | # 2. Uses GPT-4o to generate well-formed text descriptions based on the original prompts.
64 | # 3. Sends the GPT-generated text to the ChatGarment model.
65 | # 4. Outputs the final GarmentCode sewing patterns.
66 | ./scripts/v1_5/evaluate_garment_v2_textgen.sh example_data/example_jsons/example_textgen_prompts.json
67 | ```
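The input JSON is a list of prompt entries in the same format as ``example_data/example_jsons/example_textgen_prompts.json``. A minimal sketch for writing your own prompt file (the output filename and garment descriptions below are only illustrative):

```python
import json

# Each entry has an "id" plus either an "upperbody garment" / "lowerbody garment"
# pair or a single "wholebody garment", each with a "name" and a free-form "text".
prompts = [
    {
        "id": "001",
        "upperbody garment": {"name": "shirt", "text": "A crew-neck short-sleeve shirt"},
        "lowerbody garment": {"name": "pants", "text": "A pair of long, loose pants"},
    },
    {
        "id": "002",
        "wholebody garment": {"name": "dress", "text": "A long-sleeve dress"},
    },
]

# Illustrative filename; pass it to the script in place of the example JSON.
with open("my_textgen_prompts.json", "w") as f:
    json.dump(prompts, f, indent=4)
```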
68 |
69 |
70 | #### 3. Garment Editing
71 | ```Shell
72 | # Run text-based garment editing for the prompts given in the input JSON file
73 | # Detailed steps of the script:
74 | # 1. Accepts an input JSON file.
75 | # 2. Uses GPT-4o to generate well-formed editing prompts based on the original prompts.
76 | # 3. Sends the GPT-generated text to the ChatGarment model.
77 | # 4. Outputs the final GarmentCode sewing patterns.
78 | ./scripts/v1_5/evaluate_garment_v2_demo_edit.sh example_data/example_jsons/example_edit_prompts.json
79 | ```
80 |
81 | #### 4. Multi-turn Conversations
82 | (Coming Soon)
83 |
84 |
85 | ## After Inference
86 |
87 | #### 1. Generate 3D Garments Based on ChatGarment Output
88 | After inference, ChatGarment outputs 2D sewing patterns and JSON configurations in the specified ``$(OUTPUT_DIR)``. The 2D patterns can then be stitched together to generate the corresponding 3D garments using the following code:
89 |
90 | ```Shell
91 | # Run garment stitching to get draped 3D garments
92 | # For example, $(OUTPUT_DIR) = runs/try_7b_lr1e_4_v3_garmentcontrol_4h100_v4_final_textgen_exampleimg
93 | python run_garmentcode_sim.py --all_paths_json $(OUTPUT_DIR)
94 | ```
95 |
96 | #### 2. (Optional) Postprocessing for More Accurate Sizes
97 | ChatGarment may occasionally produce garments with incorrect lengths or widths from input images. To alleviate this, we provide a postprocessing method that refines garment sizes. Detailed instructions are available in ``docs/postprocess.md``.
98 |
99 |
100 |
101 | ## Citation
102 | ```bibtex
103 | @article{bian2024chatgarment,
104 | title={ChatGarment: Garment Estimation, Generation and Editing via Large Language Models},
105 | author={Bian, Siyuan and Xu, Chenghao and Xiu, Yuliang and Grigorev, Artur and Liu, Zhen and Lu, Cewu and Black, Michael J and Feng, Yao},
106 | journal={arXiv preprint arXiv:2412.17811},
107 | year={2024}
108 | }
109 | ```
110 |
111 | ## Acknowledgments
112 | This repository is built extensively on top of [LLaVA](https://github.com/haotian-liu/LLaVA) and [LISA](https://github.com/dvlab-research/LISA).
113 |
--------------------------------------------------------------------------------
/cog.yaml:
--------------------------------------------------------------------------------
1 | # Configuration for Cog ⚙️
2 | # Reference: https://github.com/replicate/cog/blob/main/docs/yaml.md
3 |
4 | build:
5 | gpu: true
6 |
7 | python_version: "3.11"
8 |
9 | python_packages:
10 | - "torch==2.0.1"
11 | - "accelerate==0.21.0"
12 | - "bitsandbytes==0.41.0"
13 | - "deepspeed==0.9.5"
14 | - "einops-exts==0.0.4"
15 | - "einops==0.6.1"
16 | - "gradio==3.35.2"
17 | - "gradio_client==0.2.9"
18 | - "httpx==0.24.0"
19 | - "markdown2==2.4.10"
20 | - "numpy==1.26.0"
21 | - "peft==0.4.0"
22 | - "scikit-learn==1.2.2"
23 | - "sentencepiece==0.1.99"
24 | - "shortuuid==1.0.11"
25 | - "timm==0.6.13"
26 | - "tokenizers==0.13.3"
27 | - "torch==2.0.1"
28 | - "torchvision==0.15.2"
29 | - "transformers==4.31.0"
30 | - "wandb==0.15.12"
31 | - "wavedrom==2.0.3.post3"
32 | - "Pygments==2.16.1"
33 | run:
34 | - curl -o /usr/local/bin/pget -L "https://github.com/replicate/pget/releases/download/v0.0.3/pget" && chmod +x /usr/local/bin/pget
35 |
36 | # predict.py defines how predictions are run on your model
37 | predict: "predict.py:Predictor"
38 |
--------------------------------------------------------------------------------
/docs/Installation.md:
--------------------------------------------------------------------------------
1 | ## Installation
2 |
3 | #### 1. Clone this repository
4 | ```bash
5 | git clone git@github.com:biansy000/ChatGarment.git
6 | cd ChatGarment
7 | ```
8 |
9 | #### 2. Install Dependencies
10 | If you are not using Linux, see instructions for [macOS](https://github.com/haotian-liu/LLaVA/blob/main/docs/macOS.md) and [Windows](https://github.com/haotian-liu/LLaVA/blob/main/docs/Windows.md).
11 |
12 | ```Shell
13 | conda create -n chatgarment python=3.10 -y
14 | conda activate chatgarment
15 | pip install --upgrade pip # enable PEP 660 support
16 | pip install -e ".[train]"
17 | pip install flash-attn --no-build-isolation
18 | ```
19 |
20 | #### 3. Install [GarmentCodeRC](https://github.com/biansy000/GarmentCodeRC)
21 | Follow installation instructions in its repository.
22 |
23 |
24 | #### 4. Download Pretrained Weights
25 | Download the [pretrained weights](https://sjtueducn-my.sharepoint.com/:u:/g/personal/biansiyuan_sjtu_edu_cn/EQayoB8ie7ZIsFrjLWdBASQBFexZHXcGjrS6ghgGCjIMzw?e=o60Y65) and place them at ``checkpoints/try_7b_lr1e_4_v3_garmentcontrol_4h100_v4_final/pytorch_model.bin``.
26 |
27 | #### 5. Update Paths in Code
28 | Modify the following lines in relevant Python files:
29 | ```Python
30 | sys.path.insert(1, '/is/cluster/fast/sbian/github/chatgarment_private') # path of the current ChatGarment repo
31 | sys.path.insert(1, '/is/cluster/fast/sbian/github/GarmentCodeV2/') # path of GarmentCodeRC repo
32 | ```
33 | Replace these with the actual local paths on your machine.
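If you prefer not to hard-code these paths, a minimal alternative sketch reads them from environment variables (the ``CHATGARMENT_ROOT`` and ``GARMENTCODE_ROOT`` names are hypothetical and not used elsewhere in this repo):

```Python
import os
import sys

# Hypothetical environment variables pointing to your local clones.
sys.path.insert(1, os.environ.get("CHATGARMENT_ROOT", "/path/to/ChatGarment"))
sys.path.insert(1, os.environ.get("GARMENTCODE_ROOT", "/path/to/GarmentCodeRC"))
```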
34 |
35 | #### 6. Add Soft Link
36 | Create a soft link to the ``assets`` folder of the ``GarmentCodeRC`` repo:
37 | ```Shell
38 | ln -s path_to_garmentcode_assets assets
39 | ```
40 |
--------------------------------------------------------------------------------
/docs/all_float_paths.json:
--------------------------------------------------------------------------------
1 | ["design.waistband.waist", "design.waistband.width", "design.shirt.length", "design.shirt.width", "design.shirt.flare", "design.collar.width", "design.collar.fc_depth", "design.collar.bc_depth", "design.collar.f_bezier_x", "design.collar.f_bezier_y", "design.collar.b_bezier_x", "design.collar.b_bezier_y", "design.collar.component.hood_depth", "design.collar.component.hood_length", "design.sleeve.length", "design.sleeve.connecting_width", "design.sleeve.end_width", "design.sleeve.opening_dir_mix", "design.sleeve.standing_shoulder_len", "design.sleeve.connect_ruffle", "design.sleeve.smoothing_coeff", "design.sleeve.cuff.top_ruffle", "design.sleeve.cuff.cuff_len", "design.sleeve.cuff.skirt_fraction", "design.sleeve.cuff.skirt_flare", "design.sleeve.cuff.skirt_ruffle", "design.left.shirt.width", "design.left.shirt.flare", "design.left.collar.width", "design.left.collar.f_bezier_x", "design.left.collar.f_bezier_y", "design.left.collar.b_bezier_x", "design.left.collar.b_bezier_y", "design.left.sleeve.length", "design.left.sleeve.connecting_width", "design.left.sleeve.end_width", "design.left.sleeve.opening_dir_mix", "design.left.sleeve.standing_shoulder_len", "design.left.sleeve.connect_ruffle", "design.left.sleeve.smoothing_coeff", "design.left.sleeve.cuff.top_ruffle", "design.left.sleeve.cuff.cuff_len", "design.left.sleeve.cuff.skirt_fraction", "design.left.sleeve.cuff.skirt_flare", "design.left.sleeve.cuff.skirt_ruffle", "design.skirt.length", "design.skirt.rise", "design.skirt.ruffle", "design.skirt.bottom_cut", "design.flare-skirt.length", "design.flare-skirt.rise", "design.flare-skirt.suns", "design.flare-skirt.asymm.front_length", "design.flare-skirt.cut.depth", "design.flare-skirt.cut.width", "design.flare-skirt.cut.place", "design.pencil-skirt.length", "design.pencil-skirt.rise", "design.pencil-skirt.flare", "design.pencil-skirt.front_slit", "design.pencil-skirt.back_slit", "design.pencil-skirt.left_slit", "design.pencil-skirt.right_slit", "design.levels-skirt.level_ruffle", "design.levels-skirt.length", "design.levels-skirt.rise", "design.levels-skirt.base_length_frac", "design.pants.length", "design.pants.width", "design.pants.flare", "design.pants.rise", "design.pants.cuff.top_ruffle", "design.pants.cuff.cuff_len", "design.pants.cuff.skirt_fraction", "design.pants.cuff.skirt_flare", "design.pants.cuff.skirt_ruffle"]
--------------------------------------------------------------------------------
/docs/images/black_img.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/biansy000/ChatGarment/6ea4192383067938f1d1c63dc0cc4134ccc91b31/docs/images/black_img.jpg
--------------------------------------------------------------------------------
/docs/images/img_recon.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/biansy000/ChatGarment/6ea4192383067938f1d1c63dc0cc4134ccc91b31/docs/images/img_recon.gif
--------------------------------------------------------------------------------
/docs/images/teaser.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/biansy000/ChatGarment/6ea4192383067938f1d1c63dc0cc4134ccc91b31/docs/images/teaser.png
--------------------------------------------------------------------------------
/docs/images/text_generation.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/biansy000/ChatGarment/6ea4192383067938f1d1c63dc0cc4134ccc91b31/docs/images/text_generation.png
--------------------------------------------------------------------------------
/docs/images/video_edit.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/biansy000/ChatGarment/6ea4192383067938f1d1c63dc0cc4134ccc91b31/docs/images/video_edit.gif
--------------------------------------------------------------------------------
/docs/images/video_edit_2.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/biansy000/ChatGarment/6ea4192383067938f1d1c63dc0cc4134ccc91b31/docs/images/video_edit_2.gif
--------------------------------------------------------------------------------
/docs/postprocess.md:
--------------------------------------------------------------------------------
1 | # Postprocessing after ChatGarment Inference
2 |
3 | ChatGarment may occasionally produce garments with incorrect lengths or widths from input images. To alleviate this, we provide a postprocessing method that refines garment sizes using a finite-difference-based approach. This process adjusts the garment length and width to better match the segmentation mask predicted by SAM (Segment Anything Model).
4 |
5 | Assume that the input images are placed in the folder ``example_data/example_imgs``.
6 |
7 | ### Step 1. Garment Segmentation with Grounding-SAM
8 | Install [GroundingDINO](https://github.com/IDEA-Research/GroundingDINO) and [segment-anything](https://github.com/facebookresearch/segment-anything) for segmentation. You can follow the installation instructions provided in [PuzzleAvatar](https://github.com/YuliangXiu/PuzzleAvatar/blob/main/scripts/install_dino_sam.sh).
9 |
10 | Run the segmentation script:
11 | ```bash
12 | python scripts/postprocess/grounding_sam.py --in_dir example_data/example_imgs --out_dir runs/example_eva_SAM
13 | ```
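Optionally, you can sanity-check that every input image received a mask (a sketch only; it assumes the ``imgs_upsampled/`` and ``mask/`` subfolders of ``--out_dir`` used in the later steps):

```python
import os

out_dir = "runs/example_eva_SAM"
stems = {os.path.splitext(f)[0] for f in os.listdir(os.path.join(out_dir, "imgs_upsampled"))}
masks = os.listdir(os.path.join(out_dir, "mask"))

# Report any image for which no mask file was produced.
for stem in sorted(stems):
    if not any(stem in m for m in masks):
        print(f"no mask found for {stem}")
```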
14 |
15 | ### Step 2. Human Pose and Shape Estimation with TokenHMR
16 | Install [TokenHMR](https://github.com/saidwivedi/TokenHMR) for human pose estimation. Navigate to the TokenHMR directory:
17 | ```bash
18 | cd PATH_TO_TOKENHMR
19 | ```
20 |
21 | Next, modify ``demo.py`` by inserting the following code after [this line](https://github.com/saidwivedi/TokenHMR/blob/198645f7784a27a4df0eac32478b1e7bc3e13574/tokenhmr/demo.py#L116) (make sure ``pickle`` is imported at the top of ``demo.py``):
22 | ```python
23 | out_saved = out.copy()
24 | out_saved['pred_cam_t_full'] = pred_cam_t_full[n]
25 | out_saved['scaled_focal_length'] = scaled_focal_length
26 | for k, v in out_saved['pred_smpl_params'].items():
27 | if isinstance(v, torch.Tensor):
28 | out_saved['pred_smpl_params'][k] = v.detach().cpu().numpy()
29 | with open(os.path.join(args.out_folder, f'{img_fn}_{person_id}.pkl'), 'wb') as f:
30 | pickle.dump(out_saved, f)
31 | ```
32 |
33 | Then, run TokenHMR with the following command:
34 | ```bash
35 | python tokenhmr/demo.py \
36 | --img_folder {PATH_TO_CHATGARMENT}/runs/example_eva_SAM/imgs_upsampled \
37 | --batch_size=1 \
38 | --full_frame \
39 | --checkpoint data/checkpoints/tokenhmr_model_latest.ckpt \
40 | --model_config data/checkpoints/model_config.yaml \
41 | --out_folder {PATH_TO_CHATGARMENT}/runs/example_eva_SAM/tokenhmr_output
42 | ```
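Each detected person should now have a ``{img_fn}_{person_id}.pkl`` file in ``tokenhmr_output``. A minimal sketch for inspecting one of them (the filename is illustrative; the keys follow the snippet inserted above):

```python
import pickle

# Illustrative filename: image stem plus person id, as written by the inserted code.
with open("runs/example_eva_SAM/tokenhmr_output/IMAGE_NAME_0.pkl", "rb") as f:
    out = pickle.load(f)

print("camera translation:", out["pred_cam_t_full"])
print("scaled focal length:", out["scaled_focal_length"])
for k, v in out["pred_smpl_params"].items():
    print(k, getattr(v, "shape", None))  # SMPL parameters saved as numpy arrays
```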
43 |
44 | ### Step 3. Install Extra Packages
45 | * Pytorch3D: Follow the official [installation guide](https://github.com/facebookresearch/pytorch3d/blob/main/INSTALL.md).
46 | * Chumpy: Install with pip: ``pip install chumpy``.
47 | Then, comment out the following line in ``chumpy/__init__.py``:
48 |
49 | ```python
50 | from numpy import bool, int, float, complex, object, unicode, str, nan, inf
51 | ```
52 |
53 |
54 | ### Step 4. Run the Postprocessing Script
55 | Assume your ChatGarment inference results are in ``runs/try_7b_lr1e_4_v3_garmentcontrol_4h100_v4_final_eva/vis_new/``. Download the required [extra data](https://drive.google.com/file/d/1QXezA3J6uXqWHGATmcw3jaYxRXY2Ctte/view?usp=sharing) and extract it to ``checkpoints/extra_data``. Then run the postprocessing script. For example, to process the image ``1aee14a8c7b4d56b4e8b6ddd575d1f561a72fdc75c43a4b6926f1655152193c6.png``, use:
56 | ```bash
57 | python scripts/postprocess/postprocess.py --imgname 1aee14a8c7b4d56b4e8b6ddd575d1f561a72fdc75c43a4b6926f1655152193c6 \
58 | --img_dir runs/example_eva_SAM/imgs_upsampled \
59 | --inp_pose_params_dir runs/example_eva_SAM/tokenhmr_output \
60 | --garmentcode_dir runs/try_7b_lr1e_4_v3_garmentcontrol_4h100_v4_final/example_imgs_img_recon/vis_new/ \
61 | --saved_dir runs/example_eva_SAM/postprocess \
62 | --garment_seg_dir runs/example_eva_SAM/mask/
63 | ```
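To postprocess every image in the folder, the same command can be looped over the image stems (a sketch using only the flags shown above):

```python
import os
import subprocess

img_dir = "runs/example_eva_SAM/imgs_upsampled"
for fname in sorted(os.listdir(img_dir)):
    stem = os.path.splitext(fname)[0]
    subprocess.run([
        "python", "scripts/postprocess/postprocess.py",
        "--imgname", stem,
        "--img_dir", img_dir,
        "--inp_pose_params_dir", "runs/example_eva_SAM/tokenhmr_output",
        "--garmentcode_dir", "runs/try_7b_lr1e_4_v3_garmentcontrol_4h100_v4_final/example_imgs_img_recon/vis_new/",
        "--saved_dir", "runs/example_eva_SAM/postprocess",
        "--garment_seg_dir", "runs/example_eva_SAM/mask/",
    ], check=True)
```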
64 |
--------------------------------------------------------------------------------
/docs/prompts/detailed_textbased_description.txt:
--------------------------------------------------------------------------------
1 | I will provide some text descriptions of a [TYPE]. Describe the garment based on these texts.
2 |
3 | You should generate a LIST of THREE strings.
4 |
5 | In the first string, describe the garment type (If THE SUBJECT HAS A NAME, INCLUDE ITS NAME FIRST!);
6 |
7 | Example phrases for the first string: "hood", "T-shirt", "jacket", "tuxedo", etc.
8 |
9 |
10 | In the second string, describe the structures of the garment (DO NOT INCLUDE ANY INFO ABOUT THE HUMAN MODEL AND THE COLOR OF THE GARMENT) in the format of a dict. You should include the most common structures of a [TYPE] even if they are not specified in the text descriptions.
11 |
12 | Select the keys from the following list:
13 | ['width', 'length', 'sleeves', 'pant legs', 'waist', 'dress', 'skirt hems', 'collar', 'hood', ... ]
14 |
15 | In the value of the dict, please use several different short phrases in a list with the following tips:
16 |
17 | Describe the width of the garment: wide, normal, narrow, etc.
18 | Describe the length of the garment: long, normal, short, etc.
19 | Describe the length and width of the sleeves: long, normal, short, tight, loose, sleeveless, etc.
20 | Describe the detailed structure of the sleeves. Example: "asymmetrical sleeves", "straight sleeves", "puff sleeves", "three-quarter sleeves", "accordion sleeves", etc.
21 | Describe the length and width of the legs of trousers: long, normal, short, tight, loose legs, etc.
22 | Describe the detailed structure of the pant legs. Example: "asymmetrical legs", "straight legs", "flared legs", "cropped legs", "cuffed legs", etc.
23 | Describe the length and width of the dress: long dress, normal dress, short dress, tight dress, loose dress, etc.
24 | Describe the detailed structure of the skirt hems. Example: "straight hem", "A-line hem", "pleated hem", "pencil hem", "slit hem", etc.
25 | Describe the detailed structure of the neck or collar. Example: "crew neck", "V-neck", "turtle neck", "collarless", etc.
26 | Describe the detailed structure of the hood. Example: "normal hood", "cape hood", "cowl hood", etc.
27 |
28 | An example of the dict description for a T-shirt is:
29 | {
30 | 'width': ['wide'],
31 | 'length': ['normal'],
32 | 'sleeves': ['elbow-length sleeves', 'tight sleeves', 'accordion sleeves'],
33 | 'collar': ['crew neck'],
34 | 'hood': ['no hood']
35 | }
36 |
37 | An example of the dict description for a skirt is:
38 | {
39 | 'width': ['wide'],
40 | 'length': ['knee-length'],
41 | 'waist': ['high waist'],
42 | 'skirt hems': ['pencil hem', 'pleated hem']
43 | }
44 |
45 | In the third string, describe the extra detailed structures of the garment (DO NOT INCLUDE ANY INFO ABOUT THE HUMAN MODEL AND THE COLOR OR PATTERN OF THE GARMENT) that are missing in the second string using several different short phrases split by ','. Example phrases for the third string: "pleated skirt", "high-waist", "zipper closure", "frayed hem", "mid-rise waist", etc. If there are no extra structures, return an empty string.
46 |
47 | Please strictly avoid mentioning color, texture, and material.
48 |
49 | Return the results in the following format: [garment type, garment geometric features, extra features]. Only return the JSON List in the above format without adding explanations.
50 |
51 | The text description is: [DESCRIPTION]
--------------------------------------------------------------------------------
/docs/prompts/gpt4v_prompt_garment_sam.txt:
--------------------------------------------------------------------------------
1 | Analyze the provided images, each featuring an individual. Identify and describe the individual's garments like shirts, outer coats, hats, pants, shoes, dresses, skirts, scarves, etc. Return the results in a dictionary format as follows: {"shirt": shirt description, "dress": dress description, "skirt": skirt description, "pants": pants description, "shoes": shoes description, "outer coat": outer coat description...}. The "description" should be one or two noun/adj words that describe the topological or geometric features, such as length (short/long), shape or style, without referencing color or texture pattern. Exclude accessories like belts, watches, badges, etc. Remove the key if the garment does not appear, or the value string is empty (""), only keep the visible garments, do not describe colors, and ensure no garment is described within the description of another (e.g., {"pants": "long dress"}). All strings should be enclosed in double quotes. The response should only contain the dictionary, without additional sentences, explanations, or markdown.
--------------------------------------------------------------------------------
/docs/prompts/prompt_garment_editing.txt:
--------------------------------------------------------------------------------
1 | I will provide text prompts to edit some specific garment parts of the [TYPE]. Based on the prompt and the image of the original garment, generate a structured garment part description in a Python dict format.
2 |
3 | The possible editable parts are: ['waistband', 'shirt main body panel', 'collar and neckline', 'sleeves', 'sleeve_cuff', 'skirt', 'pants', 'pant_cuff', ...]
4 |
5 | Text Prompt:
6 | [DESCRIPTION]
7 |
8 | Output Format:
9 | Only return a JSON dict in the format: ``{part-name-1: [geometry feature 1, geometry feature 2, geometry feature 3, ...], part-name-2: [...], ...}``, where ``part-name-1`` and ``part-name-2`` are names of the edited garment parts, and ``[geometry feature 1, geometry feature 2, geometry feature 3, ...]`` are features of the garment part After editing. Please ONLY focus on the geometric feature. Strictly avoid mentioning color, texture, seams, and material. Exclude garment parts that remain unchanged.
--------------------------------------------------------------------------------
/docs/prompts/prompt_garment_part_inference.txt:
--------------------------------------------------------------------------------
1 | I will provide an image of human models wearing the [GARMENT], and please focus on the [PART] on their [GARMENT].
2 |
3 | Please describe the geometry features of all the [PART] on the [GARMENT]. Please only describe geometries and structures of [PART]. Strictly avoid mentioning [DONOT] or other garment parts. Do not describe features not shared by all garments, and strictly avoid mentioning color, texture, seams, and material.
4 |
5 | Return a Json LIST of several phrases, each describing a geometric feature of the [PART], in the Json list format: [geometry feature 1, geometry feature 2, geometry feature 3, ...].
--------------------------------------------------------------------------------
/docs/prompts/smplified_image_description.txt:
--------------------------------------------------------------------------------
1 | I will provide an image of a human model wearing several garments. Describe the outer layer garments the model is wearing. In the image, the model may wear one upper garment and one lower garment, or the model may wear a single wholebody garment. Avoid describing extra accessories such as scarves, socks, watches, badges, etc.
2 |
3 | For each garment, you should generate TWO strings.
4 |
5 | In the first string, describe the garment type (If THE SUBJECT HAS A NAME, INCLUDE ITS NAME FIRST!);
6 |
7 | Example phrases for the first string: "hood", "T-shirt", "jacket", "tuxedo", etc.
8 |
9 |
10 | In the second string, describe the overall global geometric features of the garment (DO NOT INCLUDE ANY INFO ABOUT THE HUMAN MODEL AND THE COLOR INFO OF THE GARMENT) using several different short phrases split by ',' with the following tips:
11 |
12 | Example rules:
13 | Describe the length of the sleeves: long, normal, short, sleeveless, etc.
14 | Describe if it has a hood: with a hood, etc.
15 | Describe the length of the dress: long, normal, short, etc.
16 | Describe the width of the garment: wide, normal, narrow, etc.
17 | Describe the length of the garment: long, normal, short, etc.
18 | Describe the length of the legs of trousers: long, normal, short, etc.
19 |
20 | Please follow the example rules above (not limited to these examples) to describe the geometric features of the garment.
21 |
22 | Example phrases for the second string: "long sleeves", "wide garment", "with a hood", "deep collar", "sleeveless"...
23 |
24 |
25 | Please strictly avoid mentioning color, texture, and material.
26 |
27 | In the image, if the model is wearing one upper garment and one lower garment, return the results in the following format: {"upper garment": [upper garment type, upper garment geometric features], "lower garment": [lower garment type, lower garment geometric features]}. Otherwise, if the model is wearing a single wholebody garment, return the results in the following format: {"wholebody garment": [wholebody garment type, wholebody garment geometric features]}. Return only the JSON dictionary in the above format with a length of 1 or 2.
--------------------------------------------------------------------------------
/example_data/example_imgs/1aee14a8c7b4d56b4e8b6ddd575d1f561a72fdc75c43a4b6926f1655152193c6.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/biansy000/ChatGarment/6ea4192383067938f1d1c63dc0cc4134ccc91b31/example_data/example_imgs/1aee14a8c7b4d56b4e8b6ddd575d1f561a72fdc75c43a4b6926f1655152193c6.png
--------------------------------------------------------------------------------
/example_data/example_imgs/1dde6afed43187fe927089a615e3f744724ef3defdf3f2ae4a6cede5ad71dcea.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/biansy000/ChatGarment/6ea4192383067938f1d1c63dc0cc4134ccc91b31/example_data/example_imgs/1dde6afed43187fe927089a615e3f744724ef3defdf3f2ae4a6cede5ad71dcea.png
--------------------------------------------------------------------------------
/example_data/example_imgs/62bb809fc2dcd50409cb36163a0eb222f9aa1af0f256a3233b67b3ed4081dc71.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/biansy000/ChatGarment/6ea4192383067938f1d1c63dc0cc4134ccc91b31/example_data/example_imgs/62bb809fc2dcd50409cb36163a0eb222f9aa1af0f256a3233b67b3ed4081dc71.png
--------------------------------------------------------------------------------
/example_data/example_imgs/6fe14e1f646513ee93714fbe8026a84c6a2897be4df2f3c936cb2be8dd2d1762.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/biansy000/ChatGarment/6ea4192383067938f1d1c63dc0cc4134ccc91b31/example_data/example_imgs/6fe14e1f646513ee93714fbe8026a84c6a2897be4df2f3c936cb2be8dd2d1762.png
--------------------------------------------------------------------------------
/example_data/example_imgs/72b086429d2dfe2a8de6f4403a024b2bb17446021c9e8f9ebacfc7a990ac8434.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/biansy000/ChatGarment/6ea4192383067938f1d1c63dc0cc4134ccc91b31/example_data/example_imgs/72b086429d2dfe2a8de6f4403a024b2bb17446021c9e8f9ebacfc7a990ac8434.png
--------------------------------------------------------------------------------
/example_data/example_imgs/80141ce740f489f1d2f57a03f32c7577a28b62a6ac790a0d9ed8a18d961c2918.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/biansy000/ChatGarment/6ea4192383067938f1d1c63dc0cc4134ccc91b31/example_data/example_imgs/80141ce740f489f1d2f57a03f32c7577a28b62a6ac790a0d9ed8a18d961c2918.png
--------------------------------------------------------------------------------
/example_data/example_imgs/8e3c458da20c290c216813ec07f1a2e8f9cfb4ee7e412a783a238ec353b346a0.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/biansy000/ChatGarment/6ea4192383067938f1d1c63dc0cc4134ccc91b31/example_data/example_imgs/8e3c458da20c290c216813ec07f1a2e8f9cfb4ee7e412a783a238ec353b346a0.png
--------------------------------------------------------------------------------
/example_data/example_imgs/c2b582eb318455abaf8ed8e3126c1b423ade2704d810f7cd24428febda5632fa.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/biansy000/ChatGarment/6ea4192383067938f1d1c63dc0cc4134ccc91b31/example_data/example_imgs/c2b582eb318455abaf8ed8e3126c1b423ade2704d810f7cd24428febda5632fa.png
--------------------------------------------------------------------------------
/example_data/example_imgs/d77c6f5d4856831878eadb7fe3c8b180bfa9e9ad4a14936ac10a1697bb3c054f.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/biansy000/ChatGarment/6ea4192383067938f1d1c63dc0cc4134ccc91b31/example_data/example_imgs/d77c6f5d4856831878eadb7fe3c8b180bfa9e9ad4a14936ac10a1697bb3c054f.png
--------------------------------------------------------------------------------
/example_data/example_imgs/e918651cc154a7570e47d8b8f6c0f0f93cfbb7d5129103a1bacd8299ba945f91.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/biansy000/ChatGarment/6ea4192383067938f1d1c63dc0cc4134ccc91b31/example_data/example_imgs/e918651cc154a7570e47d8b8f6c0f0f93cfbb7d5129103a1bacd8299ba945f91.png
--------------------------------------------------------------------------------
/example_data/example_jsons/example_edit_prompts.json:
--------------------------------------------------------------------------------
1 | [
2 | {
3 | "id": "001",
4 | "garmenttype": "upperbody garment",
5 | "image": "example_data/example_sewing_patterns/example_shirt/valid_garment_upper_render_front.png",
6 | "prompt": "Adjust the neckline to a classic crew neck and change to a sleeveless shirt",
7 | "json_path": "example_data/example_sewing_patterns/example_shirt/design.yaml"
8 | }
9 | ]
--------------------------------------------------------------------------------
/example_data/example_jsons/example_textgen_prompts.json:
--------------------------------------------------------------------------------
1 | [
2 | {
3 | "id": "001",
4 | "upperbody garment": {
5 | "name": "shirt",
6 | "text": "A crew-neck short-sleeve shirt"
7 | },
8 | "lowerbody garment": {
9 | "name": "pants",
10 | "text": "A pair of long, loose pants"
11 | }
12 | },
13 | {
14 | "id": "002",
15 | "upperbody garment": {
16 | "name": "shirt",
17 | "text": "A V-neck sleeveless shirt"
18 | },
19 | "lowerbody garment": {
20 | "name": "pants",
21 | "text": "A pair of shorts"
22 | }
23 | },
24 | {
25 | "id": "003",
26 | "wholebody garment": {
27 | "name": "dress",
28 | "text": "A long-sleeve dress"
29 | }
30 | }
31 | ]
--------------------------------------------------------------------------------
/example_data/example_sewing_patterns/example_shirt/valid_garment_upper_render_front.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/biansy000/ChatGarment/6ea4192383067938f1d1c63dc0cc4134ccc91b31/example_data/example_sewing_patterns/example_shirt/valid_garment_upper_render_front.png
--------------------------------------------------------------------------------
/llava/__init__.py:
--------------------------------------------------------------------------------
1 | from .model import LlavaLlamaForCausalLM
2 |
--------------------------------------------------------------------------------
/llava/constants.py:
--------------------------------------------------------------------------------
1 | CONTROLLER_HEART_BEAT_EXPIRATION = 30
2 | WORKER_HEART_BEAT_INTERVAL = 15
3 |
4 | LOGDIR = "."
5 |
6 | # Model Constants
7 | IGNORE_INDEX = -100
8 | IMAGE_TOKEN_INDEX = -200
9 | DEFAULT_IMAGE_TOKEN = "<image>"
10 | DEFAULT_IMAGE_PATCH_TOKEN = "<im_patch>"
11 | DEFAULT_IM_START_TOKEN = "<im_start>"
12 | DEFAULT_IM_END_TOKEN = "<im_end>"
13 | IMAGE_PLACEHOLDER = "<image-placeholder>"
14 |
--------------------------------------------------------------------------------
/llava/eval/eval_gpt_review.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import json
3 | import os
4 |
5 | import openai
6 | import tqdm
7 | import ray
8 | import time
9 |
10 | NUM_SECONDS_TO_SLEEP = 3
11 |
12 | @ray.remote(num_cpus=4)
13 | def get_eval(content: str, max_tokens: int):
14 | while True:
15 | try:
16 | response = openai.ChatCompletion.create(
17 | model='gpt-4',
18 | messages=[{
19 | 'role': 'system',
20 | 'content': 'You are a helpful and precise assistant for checking the quality of the answer.'
21 | }, {
22 | 'role': 'user',
23 | 'content': content,
24 | }],
25 | temperature=0.2, # TODO: figure out which temperature is best for evaluation
26 | max_tokens=max_tokens,
27 | )
28 | break
29 | except openai.error.RateLimitError:
30 | pass
31 | except Exception as e:
32 | print(e)
33 | time.sleep(NUM_SECONDS_TO_SLEEP)
34 |
35 | print('success!')
36 | return response['choices'][0]['message']['content']
37 |
38 |
39 | def parse_score(review):
40 | try:
41 | score_pair = review.split('\n')[0]
42 | score_pair = score_pair.replace(',', ' ')
43 | sp = score_pair.split(' ')
44 | if len(sp) == 2:
45 | return [float(sp[0]), float(sp[1])]
46 | else:
47 | print('error', review)
48 | return [-1, -1]
49 | except Exception as e:
50 | print(e)
51 | print('error', review)
52 | return [-1, -1]
53 |
54 |
55 | if __name__ == '__main__':
56 | parser = argparse.ArgumentParser(description='ChatGPT-based QA evaluation.')
57 | parser.add_argument('-q', '--question')
58 | # parser.add_argument('-a', '--answer')
59 | parser.add_argument('-a', '--answer-list', nargs='+', default=[])
60 | parser.add_argument('-r', '--rule')
61 | parser.add_argument('-o', '--output')
62 | parser.add_argument('--max-tokens', type=int, default=1024, help='maximum number of tokens produced in the output')
63 | args = parser.parse_args()
64 |
65 | ray.init()
66 |
67 | f_q = open(os.path.expanduser(args.question))
68 | f_ans1 = open(os.path.expanduser(args.answer_list[0]))
69 | f_ans2 = open(os.path.expanduser(args.answer_list[1]))
70 | rule_dict = json.load(open(os.path.expanduser(args.rule), 'r'))
71 |
72 | review_file = open(f'{args.output}', 'w')
73 |
74 | js_list = []
75 | handles = []
76 | idx = 0
77 | for ques_js, ans1_js, ans2_js in zip(f_q, f_ans1, f_ans2):
78 | # if idx == 1:
79 | # break
80 |
81 | ques = json.loads(ques_js)
82 | ans1 = json.loads(ans1_js)
83 | ans2 = json.loads(ans2_js)
84 |
85 | category = json.loads(ques_js)['category']
86 | if category in rule_dict:
87 | rule = rule_dict[category]
88 | else:
89 | rule = rule_dict['default']
90 | prompt = rule['prompt']
91 | role = rule['role']
92 | content = (f'[Question]\n{ques["text"]}\n\n'
93 | f'[{role} 1]\n{ans1["text"]}\n\n[End of {role} 1]\n\n'
94 | f'[{role} 2]\n{ans2["text"]}\n\n[End of {role} 2]\n\n'
95 | f'[System]\n{prompt}\n\n')
96 | js_list.append({
97 | 'id': idx+1,
98 | 'question_id': ques['question_id'],
99 | 'answer1_id': ans1['answer_id'],
100 | 'answer2_id': ans2['answer_id'],
101 | 'category': category})
102 | idx += 1
103 | handles.append(get_eval.remote(content, args.max_tokens))
104 | # To avoid the rate limit set by OpenAI
105 | time.sleep(NUM_SECONDS_TO_SLEEP)
106 |
107 | reviews = ray.get(handles)
108 | for idx, review in enumerate(reviews):
109 | scores = parse_score(review)
110 | js_list[idx]['content'] = review
111 | js_list[idx]['tuple'] = scores
112 | review_file.write(json.dumps(js_list[idx]) + '\n')
113 | review_file.close()
114 |
--------------------------------------------------------------------------------
/llava/eval/eval_gpt_review_bench.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import json
3 | import os
4 |
5 | import openai
6 | import time
7 |
8 | NUM_SECONDS_TO_SLEEP = 0.5
9 |
10 |
11 | def get_eval(content: str, max_tokens: int):
12 | while True:
13 | try:
14 | response = openai.ChatCompletion.create(
15 | model='gpt-4-0314',
16 | messages=[{
17 | 'role': 'system',
18 | 'content': 'You are a helpful and precise assistant for checking the quality of the answer.'
19 | }, {
20 | 'role': 'user',
21 | 'content': content,
22 | }],
23 | temperature=0.2, # TODO: figure out which temperature is best for evaluation
24 | max_tokens=max_tokens,
25 | )
26 | break
27 | except openai.error.RateLimitError:
28 | pass
29 | except Exception as e:
30 | print(e)
31 | time.sleep(NUM_SECONDS_TO_SLEEP)
32 |
33 | return response['choices'][0]['message']['content']
34 |
35 |
36 | def parse_score(review):
37 | try:
38 | score_pair = review.split('\n')[0]
39 | score_pair = score_pair.replace(',', ' ')
40 | sp = score_pair.split(' ')
41 | if len(sp) == 2:
42 | return [float(sp[0]), float(sp[1])]
43 | else:
44 | print('error', review)
45 | return [-1, -1]
46 | except Exception as e:
47 | print(e)
48 | print('error', review)
49 | return [-1, -1]
50 |
51 |
52 | if __name__ == '__main__':
53 | parser = argparse.ArgumentParser(description='ChatGPT-based QA evaluation.')
54 | parser.add_argument('-q', '--question')
55 | parser.add_argument('-c', '--context')
56 | parser.add_argument('-a', '--answer-list', nargs='+', default=[])
57 | parser.add_argument('-r', '--rule')
58 | parser.add_argument('-o', '--output')
59 | parser.add_argument('--max-tokens', type=int, default=1024, help='maximum number of tokens produced in the output')
60 | args = parser.parse_args()
61 |
62 | f_q = open(os.path.expanduser(args.question))
63 | f_ans1 = open(os.path.expanduser(args.answer_list[0]))
64 | f_ans2 = open(os.path.expanduser(args.answer_list[1]))
65 | rule_dict = json.load(open(os.path.expanduser(args.rule), 'r'))
66 |
67 | if os.path.isfile(os.path.expanduser(args.output)):
68 | cur_reviews = [json.loads(line) for line in open(os.path.expanduser(args.output))]
69 | else:
70 | cur_reviews = []
71 |
72 | review_file = open(f'{args.output}', 'a')
73 |
74 | context_list = [json.loads(line) for line in open(os.path.expanduser(args.context))]
75 | image_to_context = {context['image']: context for context in context_list}
76 |
77 | handles = []
78 | idx = 0
79 | for ques_js, ans1_js, ans2_js in zip(f_q, f_ans1, f_ans2):
80 | ques = json.loads(ques_js)
81 | ans1 = json.loads(ans1_js)
82 | ans2 = json.loads(ans2_js)
83 |
84 | inst = image_to_context[ques['image']]
85 |
86 | if isinstance(inst['caption'], list):
87 | cap_str = '\n'.join(inst['caption'])
88 | else:
89 | cap_str = inst['caption']
90 |
91 | category = 'llava_bench_' + json.loads(ques_js)['category']
92 | if category in rule_dict:
93 | rule = rule_dict[category]
94 | else:
95 | assert False, f"Visual QA category not found in rule file: {category}."
96 | prompt = rule['prompt']
97 | role = rule['role']
98 | content = (f'[Context]\n{cap_str}\n\n'
99 | f'[Question]\n{ques["text"]}\n\n'
100 | f'[{role} 1]\n{ans1["text"]}\n\n[End of {role} 1]\n\n'
101 | f'[{role} 2]\n{ans2["text"]}\n\n[End of {role} 2]\n\n'
102 | f'[System]\n{prompt}\n\n')
103 | cur_js = {
104 | 'id': idx+1,
105 | 'question_id': ques['question_id'],
106 | 'answer1_id': ans1.get('answer_id', ans1['question_id']),
107 | 'answer2_id': ans2.get('answer_id', ans2['answer_id']),
108 | 'category': category
109 | }
110 | if idx >= len(cur_reviews):
111 | review = get_eval(content, args.max_tokens)
112 | scores = parse_score(review)
113 | cur_js['content'] = review
114 | cur_js['tuple'] = scores
115 | review_file.write(json.dumps(cur_js) + '\n')
116 | review_file.flush()
117 | else:
118 | print(f'Skipping {idx} as we already have it.')
119 | idx += 1
120 | print(idx)
121 | review_file.close()
122 |
--------------------------------------------------------------------------------
/llava/eval/eval_gpt_review_visual.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import json
3 | import os
4 |
5 | import openai
6 | import time
7 |
8 | NUM_SECONDS_TO_SLEEP = 0.5
9 |
10 |
11 | def get_eval(content: str, max_tokens: int):
12 | while True:
13 | try:
14 | response = openai.ChatCompletion.create(
15 | model='gpt-4-0314',
16 | messages=[{
17 | 'role': 'system',
18 | 'content': 'You are a helpful and precise assistant for checking the quality of the answer.'
19 | }, {
20 | 'role': 'user',
21 | 'content': content,
22 | }],
23 | temperature=0.2, # TODO: figure out which temperature is best for evaluation
24 | max_tokens=max_tokens,
25 | )
26 | break
27 | except openai.error.RateLimitError:
28 | pass
29 | except Exception as e:
30 | print(e)
31 | time.sleep(NUM_SECONDS_TO_SLEEP)
32 |
33 | return response['choices'][0]['message']['content']
34 |
35 |
36 | def parse_score(review):
37 | try:
38 | score_pair = review.split('\n')[0]
39 | score_pair = score_pair.replace(',', ' ')
40 | sp = score_pair.split(' ')
41 | if len(sp) == 2:
42 | return [float(sp[0]), float(sp[1])]
43 | else:
44 | print('error', review)
45 | return [-1, -1]
46 | except Exception as e:
47 | print(e)
48 | print('error', review)
49 | return [-1, -1]
50 |
51 |
52 | if __name__ == '__main__':
53 | parser = argparse.ArgumentParser(description='ChatGPT-based QA evaluation.')
54 | parser.add_argument('-q', '--question')
55 | parser.add_argument('-c', '--context')
56 | parser.add_argument('-a', '--answer-list', nargs='+', default=[])
57 | parser.add_argument('-r', '--rule')
58 | parser.add_argument('-o', '--output')
59 | parser.add_argument('--max-tokens', type=int, default=1024, help='maximum number of tokens produced in the output')
60 | args = parser.parse_args()
61 |
62 | f_q = open(os.path.expanduser(args.question))
63 | f_ans1 = open(os.path.expanduser(args.answer_list[0]))
64 | f_ans2 = open(os.path.expanduser(args.answer_list[1]))
65 | rule_dict = json.load(open(os.path.expanduser(args.rule), 'r'))
66 |
67 | if os.path.isfile(os.path.expanduser(args.output)):
68 | cur_reviews = [json.loads(line) for line in open(os.path.expanduser(args.output))]
69 | else:
70 | cur_reviews = []
71 |
72 | review_file = open(f'{args.output}', 'a')
73 |
74 | context_list = [json.loads(line) for line in open(os.path.expanduser(args.context))]
75 | image_to_context = {context['image']: context for context in context_list}
76 |
77 | handles = []
78 | idx = 0
79 | for ques_js, ans1_js, ans2_js in zip(f_q, f_ans1, f_ans2):
80 | ques = json.loads(ques_js)
81 | ans1 = json.loads(ans1_js)
82 | ans2 = json.loads(ans2_js)
83 |
84 | inst = image_to_context[ques['image']]
85 | cap_str = '\n'.join(inst['captions'])
86 | box_str = '\n'.join([f'{instance["category"]}: {instance["bbox"]}' for instance in inst['instances']])
87 |
88 | category = json.loads(ques_js)['category']
89 | if category in rule_dict:
90 | rule = rule_dict[category]
91 | else:
92 | assert False, f"Visual QA category not found in rule file: {category}."
93 | prompt = rule['prompt']
94 | role = rule['role']
95 | content = (f'[Context]\n{cap_str}\n\n{box_str}\n\n'
96 | f'[Question]\n{ques["text"]}\n\n'
97 | f'[{role} 1]\n{ans1["text"]}\n\n[End of {role} 1]\n\n'
98 | f'[{role} 2]\n{ans2["text"]}\n\n[End of {role} 2]\n\n'
99 | f'[System]\n{prompt}\n\n')
100 | cur_js = {
101 | 'id': idx+1,
102 | 'question_id': ques['question_id'],
103 | 'answer1_id': ans1.get('answer_id', ans1['question_id']),
104 | 'answer2_id': ans2.get('answer_id', ans2['answer_id']),
105 | 'category': category
106 | }
107 | if idx >= len(cur_reviews):
108 | review = get_eval(content, args.max_tokens)
109 | scores = parse_score(review)
110 | cur_js['content'] = review
111 | cur_js['tuple'] = scores
112 | review_file.write(json.dumps(cur_js) + '\n')
113 | review_file.flush()
114 | else:
115 | print(f'Skipping {idx} as we already have it.')
116 | idx += 1
117 | print(idx)
118 | review_file.close()
119 |
--------------------------------------------------------------------------------
/llava/eval/eval_pope.py:
--------------------------------------------------------------------------------
1 | import os
2 | import json
3 | import argparse
4 |
5 | def eval_pope(answers, label_file):
6 | label_list = [json.loads(q)['label'] for q in open(label_file, 'r')]
7 |
8 | for answer in answers:
9 | text = answer['text']
10 |
11 | # Only keep the first sentence
12 | if text.find('.') != -1:
13 | text = text.split('.')[0]
14 |
15 | text = text.replace(',', '')
16 | words = text.split(' ')
17 | if 'No' in words or 'not' in words or 'no' in words:
18 | answer['text'] = 'no'
19 | else:
20 | answer['text'] = 'yes'
21 |
22 | for i in range(len(label_list)):
23 | if label_list[i] == 'no':
24 | label_list[i] = 0
25 | else:
26 | label_list[i] = 1
27 |
28 | pred_list = []
29 | for answer in answers:
30 | if answer['text'] == 'no':
31 | pred_list.append(0)
32 | else:
33 | pred_list.append(1)
34 |
35 | pos = 1
36 | neg = 0
37 | yes_ratio = pred_list.count(1) / len(pred_list)
38 |
39 | TP, TN, FP, FN = 0, 0, 0, 0
40 | for pred, label in zip(pred_list, label_list):
41 | if pred == pos and label == pos:
42 | TP += 1
43 | elif pred == pos and label == neg:
44 | FP += 1
45 | elif pred == neg and label == neg:
46 | TN += 1
47 | elif pred == neg and label == pos:
48 | FN += 1
49 |
50 | print('TP\tFP\tTN\tFN\t')
51 | print('{}\t{}\t{}\t{}'.format(TP, FP, TN, FN))
52 |
53 | precision = float(TP) / float(TP + FP)
54 | recall = float(TP) / float(TP + FN)
55 | f1 = 2*precision*recall / (precision + recall)
56 | acc = (TP + TN) / (TP + TN + FP + FN)
57 | print('Accuracy: {}'.format(acc))
58 | print('Precision: {}'.format(precision))
59 | print('Recall: {}'.format(recall))
60 | print('F1 score: {}'.format(f1))
61 | print('Yes ratio: {}'.format(yes_ratio))
62 | print('%.3f, %.3f, %.3f, %.3f, %.3f' % (f1, acc, precision, recall, yes_ratio) )
63 |
64 | if __name__ == "__main__":
65 | parser = argparse.ArgumentParser()
66 | parser.add_argument("--annotation-dir", type=str)
67 | parser.add_argument("--question-file", type=str)
68 | parser.add_argument("--result-file", type=str)
69 | args = parser.parse_args()
70 |
71 | questions = [json.loads(line) for line in open(args.question_file)]
72 | questions = {question['question_id']: question for question in questions}
73 | answers = [json.loads(q) for q in open(args.result_file)]
74 | for file in os.listdir(args.annotation_dir):
75 | assert file.startswith('coco_pope_')
76 | assert file.endswith('.json')
77 |         category = file[10:-5]  # strip the 'coco_pope_' prefix and '.json' suffix
78 | cur_answers = [x for x in answers if questions[x['question_id']]['category'] == category]
79 | print('Category: {}, # samples: {}'.format(category, len(cur_answers)))
80 | eval_pope(cur_answers, os.path.join(args.annotation_dir, file))
81 | print("====================================")
82 |
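A small sketch of eval_pope() on toy data, assuming the function defined above is in scope; the answers, labels, and temporary label file are illustrative only.

import json
import os
import tempfile

answers = [{'text': 'Yes, there is a cat on the table.'}, {'text': 'No.'}]
labels = [{'label': 'yes'}, {'label': 'yes'}]

# Write the toy labels to a temporary file in the JSON-lines format the script
# expects, then run the metric computation defined above.
with tempfile.NamedTemporaryFile('w', suffix='.json', delete=False) as f:
    f.write('\n'.join(json.dumps(label) for label in labels))
    label_file = f.name

eval_pope(answers, label_file)   # prints TP/FP/TN/FN, accuracy, precision, recall, F1, yes ratio
os.remove(label_file)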
--------------------------------------------------------------------------------
/llava/eval/eval_science_qa.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import json
3 | import os
4 | import re
5 | import random
6 |
7 |
8 | def get_args():
9 | parser = argparse.ArgumentParser()
10 | parser.add_argument('--base-dir', type=str)
11 | parser.add_argument('--result-file', type=str)
12 | parser.add_argument('--output-file', type=str)
13 | parser.add_argument('--output-result', type=str)
14 | parser.add_argument('--split', type=str, default='test')
15 | parser.add_argument('--options', type=list, default=["A", "B", "C", "D", "E"])
16 | return parser.parse_args()
17 |
18 |
19 | def convert_caps(results):
20 | fakecaps = []
21 | for result in results:
22 | image_id = result['question_id']
23 | caption = result['text']
24 | fakecaps.append({"image_id": int(image_id), "caption": caption})
25 | return fakecaps
26 |
27 |
28 | def get_pred_idx(prediction, choices, options):
29 | """
30 | Get the index (e.g. 2) from the prediction (e.g. 'C')
31 | """
32 | if prediction in options[:len(choices)]:
33 | return options.index(prediction)
34 |     else:
35 |         # unparsable prediction: flag with -1 (the GPT-4 variants below fall back to a random choice instead)
36 |         return -1
37 |
38 |
39 | if __name__ == "__main__":
40 | args = get_args()
41 |
42 | base_dir = args.base_dir
43 | split_indices = json.load(open(os.path.join(base_dir, "pid_splits.json")))[args.split]
44 | problems = json.load(open(os.path.join(base_dir, "problems.json")))
45 | predictions = [json.loads(line) for line in open(args.result_file)]
46 | predictions = {pred['question_id']: pred for pred in predictions}
47 | split_problems = {idx: problems[idx] for idx in split_indices}
48 |
49 | results = {'correct': [], 'incorrect': []}
50 | sqa_results = {}
51 | sqa_results['acc'] = None
52 | sqa_results['correct'] = None
53 | sqa_results['count'] = None
54 | sqa_results['results'] = {}
55 | sqa_results['outputs'] = {}
56 |
57 | for prob_id, prob in split_problems.items():
58 | if prob_id not in predictions:
59 | pred = {'text': 'FAILED', 'prompt': 'Unknown'}
60 | pred_text = 'FAILED'
61 | else:
62 | pred = predictions[prob_id]
63 | pred_text = pred['text']
64 |
65 | if pred_text in args.options:
66 | answer = pred_text
67 | elif len(pred_text) >= 3 and pred_text[0] in args.options and pred_text[1:3] == ". ":
68 | answer = pred_text[0]
69 | else:
70 | pattern = re.compile(r'The answer is ([A-Z]).')
71 | res = pattern.findall(pred_text)
72 | if len(res) == 1:
73 | answer = res[0] # 'A', 'B', ...
74 | else:
75 | answer = "FAILED"
76 |
77 | pred_idx = get_pred_idx(answer, prob['choices'], args.options)
78 |
79 | analysis = {
80 | 'question_id': prob_id,
81 | 'parsed_ans': answer,
82 | 'ground_truth': args.options[prob['answer']],
83 | 'question': pred['prompt'],
84 | 'pred': pred_text,
85 |         'is_multimodal': '<image>' in pred['prompt'],
86 | }
87 |
88 | sqa_results['results'][prob_id] = get_pred_idx(answer, prob['choices'], args.options)
89 | sqa_results['outputs'][prob_id] = pred_text
90 |
91 | if pred_idx == prob['answer']:
92 | results['correct'].append(analysis)
93 | else:
94 | results['incorrect'].append(analysis)
95 |
96 | correct = len(results['correct'])
97 | total = len(results['correct']) + len(results['incorrect'])
98 |
99 | ###### IMG ######
100 | multimodal_correct = len([x for x in results['correct'] if x['is_multimodal']])
101 | multimodal_incorrect = len([x for x in results['incorrect'] if x['is_multimodal']])
102 | multimodal_total = multimodal_correct + multimodal_incorrect
103 | ###### IMG ######
104 |
105 | print(f'Total: {total}, Correct: {correct}, Accuracy: {correct / total * 100:.2f}%, IMG-Accuracy: {multimodal_correct / multimodal_total * 100:.2f}%')
106 |
107 | sqa_results['acc'] = correct / total * 100
108 | sqa_results['correct'] = correct
109 | sqa_results['count'] = total
110 |
111 | with open(args.output_file, 'w') as f:
112 | json.dump(results, f, indent=2)
113 | with open(args.output_result, 'w') as f:
114 | json.dump(sqa_results, f, indent=2)
115 |
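A short sketch of the three answer-parsing branches above, applied to hypothetical model outputs; it mirrors the logic in the main loop rather than adding anything new.

import re

options = ["A", "B", "C", "D", "E"]
pattern = re.compile(r'The answer is ([A-Z]).')

for pred_text in ["B", "C. Metals conduct heat well.", "The answer is D.", "I am not sure."]:
    if pred_text in options:
        answer = pred_text                                   # bare option letter
    elif len(pred_text) >= 3 and pred_text[0] in options and pred_text[1:3] == ". ":
        answer = pred_text[0]                                # "C. <explanation>"
    else:
        res = pattern.findall(pred_text)
        answer = res[0] if len(res) == 1 else "FAILED"       # sentence form, or give up
    print(f"{pred_text!r} -> {answer}")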
--------------------------------------------------------------------------------
/llava/eval/eval_science_qa_gpt4.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import json
3 | import os
4 | import re
5 | import random
6 | from collections import defaultdict
7 |
8 |
9 | def get_args():
10 | parser = argparse.ArgumentParser()
11 | parser.add_argument('--base-dir', type=str)
12 | parser.add_argument('--gpt4-result', type=str)
13 | parser.add_argument('--our-result', type=str)
14 | parser.add_argument('--split', type=str, default='test')
15 | parser.add_argument('--options', type=list, default=["A", "B", "C", "D", "E"])
16 | return parser.parse_args()
17 |
18 |
19 | def convert_caps(results):
20 | fakecaps = []
21 | for result in results:
22 | image_id = result['question_id']
23 | caption = result['text']
24 | fakecaps.append({"image_id": int(image_id), "caption": caption})
25 | return fakecaps
26 |
27 |
28 | def get_pred_idx(prediction, choices, options):
29 | """
30 | Get the index (e.g. 2) from the prediction (e.g. 'C')
31 | """
32 | if prediction in options[:len(choices)]:
33 | return options.index(prediction)
34 | else:
35 | return random.choice(range(len(choices)))
36 |
37 |
38 | if __name__ == "__main__":
39 | args = get_args()
40 |
41 | base_dir = args.base_dir
42 | split_indices = json.load(open(os.path.join(base_dir, "pid_splits.json")))[args.split]
43 | problems = json.load(open(os.path.join(base_dir, "problems.json")))
44 | our_predictions = [json.loads(line) for line in open(args.our_result)]
45 | our_predictions = {pred['question_id']: pred for pred in our_predictions}
46 | split_problems = {idx: problems[idx] for idx in split_indices}
47 |
48 | gpt4_predictions = json.load(open(args.gpt4_result))['outputs']
49 |
50 | results = defaultdict(lambda: 0)
51 |
52 | for prob_id, prob in split_problems.items():
53 | if prob_id not in our_predictions:
54 | continue
55 | if prob_id not in gpt4_predictions:
56 | continue
57 | our_pred = our_predictions[prob_id]['text']
58 | gpt4_pred = gpt4_predictions[prob_id]
59 |
60 | pattern = re.compile(r'The answer is ([A-Z]).')
61 | our_res = pattern.findall(our_pred)
62 | if len(our_res) == 1:
63 | our_answer = our_res[0] # 'A', 'B', ...
64 | else:
65 | our_answer = "FAILED"
66 | gpt4_res = pattern.findall(gpt4_pred)
67 | if len(gpt4_res) == 1:
68 | gpt4_answer = gpt4_res[0] # 'A', 'B', ...
69 | else:
70 | gpt4_answer = "FAILED"
71 |
72 | our_pred_idx = get_pred_idx(our_answer, prob['choices'], args.options)
73 | gpt4_pred_idx = get_pred_idx(gpt4_answer, prob['choices'], args.options)
74 |
75 | if gpt4_answer == 'FAILED':
76 | results['gpt4_failed'] += 1
77 | # continue
78 | gpt4_pred_idx = our_pred_idx
79 | # if our_pred_idx != prob['answer']:
80 | # print(our_predictions[prob_id]['prompt'])
81 | # print('-----------------')
82 | # print(f'LECTURE: {prob["lecture"]}')
83 | # print(f'SOLUTION: {prob["solution"]}')
84 | # print('=====================')
85 | else:
86 | # continue
87 | pass
88 | # gpt4_pred_idx = our_pred_idx
89 |
90 | if gpt4_pred_idx == prob['answer']:
91 | results['correct'] += 1
92 | else:
93 | results['incorrect'] += 1
94 |
95 |
96 | if gpt4_pred_idx == prob['answer'] or our_pred_idx == prob['answer']:
97 | results['correct_upperbound'] += 1
98 |
99 | correct = results['correct']
100 | total = results['correct'] + results['incorrect']
101 | print(f'Total: {total}, Correct: {correct}, Accuracy: {correct / total * 100:.2f}%')
102 | print(f'Total: {total}, Correct (upper): {results["correct_upperbound"]}, Accuracy: {results["correct_upperbound"] / total * 100:.2f}%')
103 | print(f'Total: {total}, GPT-4 NO-ANS (RANDOM): {results["gpt4_failed"]}, Percentage: {results["gpt4_failed"] / total * 100:.2f}%')
104 |
105 |
--------------------------------------------------------------------------------
/llava/eval/eval_science_qa_gpt4_requery.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import json
3 | import os
4 | import re
5 | import random
6 | from collections import defaultdict
7 |
8 |
9 | def get_args():
10 | parser = argparse.ArgumentParser()
11 | parser.add_argument('--base-dir', type=str)
12 | parser.add_argument('--gpt4-result', type=str)
13 | parser.add_argument('--requery-result', type=str)
14 | parser.add_argument('--our-result', type=str)
15 | parser.add_argument('--output-result', type=str)
16 | parser.add_argument('--split', type=str, default='test')
17 | parser.add_argument('--options', type=list, default=["A", "B", "C", "D", "E"])
18 | return parser.parse_args()
19 |
20 |
21 | def convert_caps(results):
22 | fakecaps = []
23 | for result in results:
24 | image_id = result['question_id']
25 | caption = result['text']
26 | fakecaps.append({"image_id": int(image_id), "caption": caption})
27 | return fakecaps
28 |
29 |
30 | def get_pred_idx(prediction, choices, options):
31 | """
32 | Get the index (e.g. 2) from the prediction (e.g. 'C')
33 | """
34 | if prediction in options[:len(choices)]:
35 | return options.index(prediction)
36 | else:
37 | return random.choice(range(len(choices)))
38 |
39 |
40 | if __name__ == "__main__":
41 | args = get_args()
42 |
43 | base_dir = args.base_dir
44 | split_indices = json.load(open(os.path.join(base_dir, "pid_splits.json")))[args.split]
45 | problems = json.load(open(os.path.join(base_dir, "problems.json")))
46 | our_predictions = [json.loads(line) for line in open(args.our_result)]
47 | our_predictions = {pred['question_id']: pred for pred in our_predictions}
48 | split_problems = {idx: problems[idx] for idx in split_indices}
49 |
50 | requery_predictions = [json.loads(line) for line in open(args.requery_result)]
51 | requery_predictions = {pred['question_id']: pred for pred in requery_predictions}
52 |
53 | gpt4_predictions = json.load(open(args.gpt4_result))['outputs']
54 |
55 | results = defaultdict(lambda: 0)
56 |
57 | sqa_results = {}
58 | sqa_results['acc'] = None
59 | sqa_results['correct'] = None
60 | sqa_results['count'] = None
61 | sqa_results['results'] = {}
62 | sqa_results['outputs'] = {}
63 |
64 | for prob_id, prob in split_problems.items():
65 | if prob_id not in our_predictions:
66 | assert False
67 | if prob_id not in gpt4_predictions:
68 | assert False
69 | our_pred = our_predictions[prob_id]['text']
70 | gpt4_pred = gpt4_predictions[prob_id]
71 | if prob_id not in requery_predictions:
72 | results['missing_requery'] += 1
73 | requery_pred = "MISSING"
74 | else:
75 | requery_pred = requery_predictions[prob_id]['text']
76 |
77 | pattern = re.compile(r'The answer is ([A-Z]).')
78 | our_res = pattern.findall(our_pred)
79 | if len(our_res) == 1:
80 | our_answer = our_res[0] # 'A', 'B', ...
81 | else:
82 | our_answer = "FAILED"
83 |
84 | requery_res = pattern.findall(requery_pred)
85 | if len(requery_res) == 1:
86 | requery_answer = requery_res[0] # 'A', 'B', ...
87 | else:
88 | requery_answer = "FAILED"
89 |
90 | gpt4_res = pattern.findall(gpt4_pred)
91 | if len(gpt4_res) == 1:
92 | gpt4_answer = gpt4_res[0] # 'A', 'B', ...
93 | else:
94 | gpt4_answer = "FAILED"
95 |
96 | our_pred_idx = get_pred_idx(our_answer, prob['choices'], args.options)
97 | gpt4_pred_idx = get_pred_idx(gpt4_answer, prob['choices'], args.options)
98 | requery_pred_idx = get_pred_idx(requery_answer, prob['choices'], args.options)
99 |
100 | results['total'] += 1
101 |
102 | if gpt4_answer == 'FAILED':
103 | results['gpt4_failed'] += 1
104 | if gpt4_pred_idx == prob['answer']:
105 | results['gpt4_correct'] += 1
106 | if our_pred_idx == prob['answer']:
107 | results['gpt4_ourvisual_correct'] += 1
108 | elif gpt4_pred_idx == prob['answer']:
109 | results['gpt4_correct'] += 1
110 | results['gpt4_ourvisual_correct'] += 1
111 |
112 | if our_pred_idx == prob['answer']:
113 | results['our_correct'] += 1
114 |
115 | if requery_answer == 'FAILED':
116 | sqa_results['results'][prob_id] = our_pred_idx
117 | if our_pred_idx == prob['answer']:
118 | results['requery_correct'] += 1
119 | else:
120 | sqa_results['results'][prob_id] = requery_pred_idx
121 | if requery_pred_idx == prob['answer']:
122 | results['requery_correct'] += 1
123 | else:
124 | print(f"""
125 | Question ({args.options[prob['answer']]}): {our_predictions[prob_id]['prompt']}
126 | Our ({our_answer}): {our_pred}
127 | GPT-4 ({gpt4_answer}): {gpt4_pred}
128 | Requery ({requery_answer}): {requery_pred}
129 | =====================================
130 | """)
131 |
132 | if gpt4_pred_idx == prob['answer'] or our_pred_idx == prob['answer']:
133 | results['correct_upperbound'] += 1
134 |
135 | total = results['total']
136 | print(f'Total: {total}, Our-Correct: {results["our_correct"]}, Accuracy: {results["our_correct"] / total * 100:.2f}%')
137 | print(f'Total: {total}, GPT-4-Correct: {results["gpt4_correct"]}, Accuracy: {results["gpt4_correct"] / total * 100:.2f}%')
138 | print(f'Total: {total}, GPT-4 NO-ANS (RANDOM): {results["gpt4_failed"]}, Percentage: {results["gpt4_failed"] / total * 100:.2f}%')
139 | print(f'Total: {total}, GPT-4-OursVisual-Correct: {results["gpt4_ourvisual_correct"]}, Accuracy: {results["gpt4_ourvisual_correct"] / total * 100:.2f}%')
140 | print(f'Total: {total}, Requery-Correct: {results["requery_correct"]}, Accuracy: {results["requery_correct"] / total * 100:.2f}%')
141 | print(f'Total: {total}, Correct upper: {results["correct_upperbound"]}, Accuracy: {results["correct_upperbound"] / total * 100:.2f}%')
142 |
143 | sqa_results['acc'] = results["requery_correct"] / total * 100
144 | sqa_results['correct'] = results["requery_correct"]
145 | sqa_results['count'] = total
146 |
147 | with open(args.output_result, 'w') as f:
148 | json.dump(sqa_results, f, indent=2)
149 |
150 |
--------------------------------------------------------------------------------
/llava/eval/eval_textvqa.py:
--------------------------------------------------------------------------------
1 | import os
2 | import argparse
3 | import json
4 | import re
5 |
6 | from llava.eval.m4c_evaluator import TextVQAAccuracyEvaluator
7 |
8 |
9 | def get_args():
10 | parser = argparse.ArgumentParser()
11 | parser.add_argument('--annotation-file', type=str)
12 | parser.add_argument('--result-file', type=str)
13 | parser.add_argument('--result-dir', type=str)
14 | return parser.parse_args()
15 |
16 |
17 | def prompt_processor(prompt):
18 | if prompt.startswith('OCR tokens: '):
19 | pattern = r"Question: (.*?) Short answer:"
20 | match = re.search(pattern, prompt, re.DOTALL)
21 | question = match.group(1)
22 | elif 'Reference OCR token: ' in prompt and len(prompt.split('\n')) == 3:
23 | if prompt.startswith('Reference OCR token:'):
24 | question = prompt.split('\n')[1]
25 | else:
26 | question = prompt.split('\n')[0]
27 | elif len(prompt.split('\n')) == 2:
28 | question = prompt.split('\n')[0]
29 | else:
30 | assert False
31 |
32 | return question.lower()
33 |
34 |
35 | def eval_single(annotation_file, result_file):
36 | experiment_name = os.path.splitext(os.path.basename(result_file))[0]
37 | print(experiment_name)
38 | annotations = json.load(open(annotation_file))['data']
39 | annotations = {(annotation['image_id'], annotation['question'].lower()): annotation for annotation in annotations}
40 | results = [json.loads(line) for line in open(result_file)]
41 |
42 | pred_list = []
43 | for result in results:
44 | annotation = annotations[(result['question_id'], prompt_processor(result['prompt']))]
45 | pred_list.append({
46 | "pred_answer": result['text'],
47 | "gt_answers": annotation['answers'],
48 | })
49 |
50 | evaluator = TextVQAAccuracyEvaluator()
51 | print('Samples: {}\nAccuracy: {:.2f}%\n'.format(len(pred_list), 100. * evaluator.eval_pred_list(pred_list)))
52 |
53 |
54 | if __name__ == "__main__":
55 | args = get_args()
56 |
57 | if args.result_file is not None:
58 | eval_single(args.annotation_file, args.result_file)
59 |
60 | if args.result_dir is not None:
61 | for result_file in sorted(os.listdir(args.result_dir)):
62 | if not result_file.endswith('.jsonl'):
63 | print(f'Skipping {result_file}')
64 | continue
65 | eval_single(args.annotation_file, os.path.join(args.result_dir, result_file))
66 |
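A brief sketch of the three prompt layouts that prompt_processor() above accepts, using hypothetical prompt text; it assumes the function defined above is in scope.

prompts = [
    "OCR tokens: stop, main st Question: What does the sign say? Short answer:",
    "Reference OCR token: stop, main st\nWhat does the sign say?\nAnswer the question using a single word or phrase.",
    "What does the sign say?\nAnswer the question using a single word or phrase.",
]
for p in prompts:
    print(prompt_processor(p))   # each prints the lower-cased question only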
--------------------------------------------------------------------------------
/llava/eval/generate_webpage_data_from_table.py:
--------------------------------------------------------------------------------
1 | """Generate json file for webpage."""
2 | import json
3 | import os
4 | import re
5 |
6 | # models = ['llama', 'alpaca', 'gpt35', 'bard']
7 | models = ['vicuna']
8 |
9 |
10 | def read_jsonl(path: str, key: str=None):
11 | data = []
12 | with open(os.path.expanduser(path)) as f:
13 | for line in f:
14 | if not line:
15 | continue
16 | data.append(json.loads(line))
17 | if key is not None:
18 | data.sort(key=lambda x: x[key])
19 | data = {item[key]: item for item in data}
20 | return data
21 |
22 |
23 | def trim_hanging_lines(s: str, n: int) -> str:
24 | s = s.strip()
25 | for _ in range(n):
26 | s = s.split('\n', 1)[1].strip()
27 | return s
28 |
29 |
30 | if __name__ == '__main__':
31 | questions = read_jsonl('table/question.jsonl', key='question_id')
32 |
33 | # alpaca_answers = read_jsonl('table/answer/answer_alpaca-13b.jsonl', key='question_id')
34 | # bard_answers = read_jsonl('table/answer/answer_bard.jsonl', key='question_id')
35 | # gpt35_answers = read_jsonl('table/answer/answer_gpt35.jsonl', key='question_id')
36 | # llama_answers = read_jsonl('table/answer/answer_llama-13b.jsonl', key='question_id')
37 | vicuna_answers = read_jsonl('table/answer/answer_vicuna-13b.jsonl', key='question_id')
38 | ours_answers = read_jsonl('table/results/llama-13b-hf-alpaca.jsonl', key='question_id')
39 |
40 | review_vicuna = read_jsonl('table/review/review_vicuna-13b_llama-13b-hf-alpaca.jsonl', key='question_id')
41 | # review_alpaca = read_jsonl('table/review/review_alpaca-13b_vicuna-13b.jsonl', key='question_id')
42 | # review_bard = read_jsonl('table/review/review_bard_vicuna-13b.jsonl', key='question_id')
43 | # review_gpt35 = read_jsonl('table/review/review_gpt35_vicuna-13b.jsonl', key='question_id')
44 | # review_llama = read_jsonl('table/review/review_llama-13b_vicuna-13b.jsonl', key='question_id')
45 |
46 | records = []
47 | for qid in questions.keys():
48 | r = {
49 | 'id': qid,
50 | 'category': questions[qid]['category'],
51 | 'question': questions[qid]['text'],
52 | 'answers': {
53 | # 'alpaca': alpaca_answers[qid]['text'],
54 | # 'llama': llama_answers[qid]['text'],
55 | # 'bard': bard_answers[qid]['text'],
56 | # 'gpt35': gpt35_answers[qid]['text'],
57 | 'vicuna': vicuna_answers[qid]['text'],
58 | 'ours': ours_answers[qid]['text'],
59 | },
60 | 'evaluations': {
61 | # 'alpaca': review_alpaca[qid]['text'],
62 | # 'llama': review_llama[qid]['text'],
63 | # 'bard': review_bard[qid]['text'],
64 | 'vicuna': review_vicuna[qid]['content'],
65 | # 'gpt35': review_gpt35[qid]['text'],
66 | },
67 | 'scores': {
68 | 'vicuna': review_vicuna[qid]['tuple'],
69 | # 'alpaca': review_alpaca[qid]['score'],
70 | # 'llama': review_llama[qid]['score'],
71 | # 'bard': review_bard[qid]['score'],
72 | # 'gpt35': review_gpt35[qid]['score'],
73 | },
74 | }
75 |
76 | # cleanup data
77 | cleaned_evals = {}
78 | for k, v in r['evaluations'].items():
79 | v = v.strip()
80 | lines = v.split('\n')
81 | # trim the first line if it's a pair of numbers
82 | if re.match(r'\d+[, ]+\d+', lines[0]):
83 | lines = lines[1:]
84 | v = '\n'.join(lines)
85 | cleaned_evals[k] = v.replace('Assistant 1', "**Assistant 1**").replace('Assistant 2', '**Assistant 2**')
86 |
87 | r['evaluations'] = cleaned_evals
88 | records.append(r)
89 |
90 | # Reorder the records, this is optional
91 | for r in records:
92 | if r['id'] <= 20:
93 | r['id'] += 60
94 | else:
95 | r['id'] -= 20
96 | for r in records:
97 | if r['id'] <= 50:
98 | r['id'] += 10
99 | elif 50 < r['id'] <= 60:
100 | r['id'] -= 50
101 | for r in records:
102 | if r['id'] == 7:
103 | r['id'] = 1
104 | elif r['id'] < 7:
105 | r['id'] += 1
106 |
107 | records.sort(key=lambda x: x['id'])
108 |
109 | # Write to file
110 | with open('webpage/data.json', 'w') as f:
111 | json.dump({'questions': records, 'models': models}, f, indent=2)
112 |
--------------------------------------------------------------------------------
/llava/eval/model_qa.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | from transformers import AutoTokenizer, AutoModelForCausalLM, StoppingCriteria
3 | import torch
4 | import os
5 | import json
6 | from tqdm import tqdm
7 | import shortuuid
8 |
9 | from llava.conversation import default_conversation
10 | from llava.utils import disable_torch_init
11 |
12 |
13 | @torch.inference_mode()
14 | def eval_model(model_name, questions_file, answers_file):
15 | # Model
16 | disable_torch_init()
17 | model_name = os.path.expanduser(model_name)
18 | tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=False)
19 | model = AutoModelForCausalLM.from_pretrained(model_name,
20 | torch_dtype=torch.float16).cuda()
21 |
22 |
23 | ques_file = open(os.path.expanduser(questions_file), "r")
24 | ans_file = open(os.path.expanduser(answers_file), "w")
25 | for i, line in enumerate(tqdm(ques_file)):
26 |         q = json.loads(line)
27 |         idx = q["question_id"]
28 |         qs, cat = q["text"], q["category"]
29 | conv = default_conversation.copy()
30 | conv.append_message(conv.roles[0], qs)
31 | prompt = conv.get_prompt()
32 | inputs = tokenizer([prompt])
33 | input_ids = torch.as_tensor(inputs.input_ids).cuda()
34 | output_ids = model.generate(
35 | input_ids,
36 | do_sample=True,
37 | use_cache=True,
38 | temperature=0.7,
39 | max_new_tokens=1024,)
40 | outputs = tokenizer.batch_decode(output_ids, skip_special_tokens=True)[0]
41 | try:
42 | index = outputs.index(conv.sep, len(prompt))
43 | except ValueError:
44 | outputs += conv.sep
45 | index = outputs.index(conv.sep, len(prompt))
46 |
47 | outputs = outputs[len(prompt) + len(conv.roles[1]) + 2:index].strip()
48 | ans_id = shortuuid.uuid()
49 | ans_file.write(json.dumps({"question_id": idx,
50 | "text": outputs,
51 | "answer_id": ans_id,
52 | "model_id": model_name,
53 | "metadata": {}}) + "\n")
54 | ans_file.flush()
55 | ans_file.close()
56 |
57 | if __name__ == "__main__":
58 | parser = argparse.ArgumentParser()
59 | parser.add_argument("--model-name", type=str, default="facebook/opt-350m")
60 | parser.add_argument("--question-file", type=str, default="tables/question.jsonl")
61 | parser.add_argument("--answers-file", type=str, default="answer.jsonl")
62 | args = parser.parse_args()
63 |
64 | eval_model(args.model_name, args.question_file, args.answers_file)
65 |
--------------------------------------------------------------------------------
/llava/eval/model_vqa.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import torch
3 | import os
4 | import json
5 | from tqdm import tqdm
6 | import shortuuid
7 |
8 | from llava.constants import IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN
9 | from llava.conversation import conv_templates, SeparatorStyle
10 | from llava.model.builder import load_pretrained_model
11 | from llava.utils import disable_torch_init
12 | from llava.mm_utils import tokenizer_image_token, process_images, get_model_name_from_path
13 |
14 | from PIL import Image
15 | import math
16 |
17 |
18 | def split_list(lst, n):
19 | """Split a list into n (roughly) equal-sized chunks"""
20 |     chunk_size = math.ceil(len(lst) / n)  # ceiling division, so every item is assigned to a chunk
21 | return [lst[i:i+chunk_size] for i in range(0, len(lst), chunk_size)]
22 |
23 |
24 | def get_chunk(lst, n, k):
25 | chunks = split_list(lst, n)
26 | return chunks[k]
27 |
28 |
29 | def eval_model(args):
30 | # Model
31 | disable_torch_init()
32 | model_path = os.path.expanduser(args.model_path)
33 | model_name = get_model_name_from_path(model_path)
34 | tokenizer, model, image_processor, context_len = load_pretrained_model(model_path, args.model_base, model_name)
35 |
36 | questions = [json.loads(q) for q in open(os.path.expanduser(args.question_file), "r")]
37 | questions = get_chunk(questions, args.num_chunks, args.chunk_idx)
38 | answers_file = os.path.expanduser(args.answers_file)
39 | os.makedirs(os.path.dirname(answers_file), exist_ok=True)
40 | ans_file = open(answers_file, "w")
41 | for line in tqdm(questions):
42 | idx = line["question_id"]
43 | image_file = line["image"]
44 | qs = line["text"]
45 | cur_prompt = qs
46 | if model.config.mm_use_im_start_end:
47 | qs = DEFAULT_IM_START_TOKEN + DEFAULT_IMAGE_TOKEN + DEFAULT_IM_END_TOKEN + '\n' + qs
48 | else:
49 | qs = DEFAULT_IMAGE_TOKEN + '\n' + qs
50 |
51 | conv = conv_templates[args.conv_mode].copy()
52 | conv.append_message(conv.roles[0], qs)
53 | conv.append_message(conv.roles[1], None)
54 | prompt = conv.get_prompt()
55 |
56 | input_ids = tokenizer_image_token(prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors='pt').unsqueeze(0).cuda()
57 |
58 | image = Image.open(os.path.join(args.image_folder, image_file)).convert('RGB')
59 | image_tensor = process_images([image], image_processor, model.config)[0]
60 |
61 | with torch.inference_mode():
62 | output_ids = model.generate(
63 | input_ids,
64 | images=image_tensor.unsqueeze(0).half().cuda(),
65 | image_sizes=[image.size],
66 | do_sample=True if args.temperature > 0 else False,
67 | temperature=args.temperature,
68 | top_p=args.top_p,
69 | num_beams=args.num_beams,
70 | # no_repeat_ngram_size=3,
71 | max_new_tokens=1024,
72 | use_cache=True)
73 |
74 | outputs = tokenizer.batch_decode(output_ids, skip_special_tokens=True)[0].strip()
75 |
76 | ans_id = shortuuid.uuid()
77 | ans_file.write(json.dumps({"question_id": idx,
78 | "prompt": cur_prompt,
79 | "text": outputs,
80 | "answer_id": ans_id,
81 | "model_id": model_name,
82 | "metadata": {}}) + "\n")
83 | ans_file.flush()
84 | ans_file.close()
85 |
86 | if __name__ == "__main__":
87 | parser = argparse.ArgumentParser()
88 | parser.add_argument("--model-path", type=str, default="facebook/opt-350m")
89 | parser.add_argument("--model-base", type=str, default=None)
90 | parser.add_argument("--image-folder", type=str, default="")
91 | parser.add_argument("--question-file", type=str, default="tables/question.jsonl")
92 | parser.add_argument("--answers-file", type=str, default="answer.jsonl")
93 | parser.add_argument("--conv-mode", type=str, default="llava_v1")
94 | parser.add_argument("--num-chunks", type=int, default=1)
95 | parser.add_argument("--chunk-idx", type=int, default=0)
96 | parser.add_argument("--temperature", type=float, default=0.2)
97 | parser.add_argument("--top_p", type=float, default=None)
98 | parser.add_argument("--num_beams", type=int, default=1)
99 | args = parser.parse_args()
100 |
101 | eval_model(args)
102 |
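The script is normally driven from the command line, but eval_model() can also be called directly; a sketch follows, where the checkpoint id and file paths are hypothetical placeholders and a CUDA GPU is assumed.

from argparse import Namespace

args = Namespace(
    model_path="liuhaotian/llava-v1.5-7b",         # hypothetical checkpoint id
    model_base=None,
    image_folder="./playground/images",            # hypothetical image folder
    question_file="./playground/question.jsonl",   # hypothetical question file
    answers_file="./playground/answer.jsonl",
    conv_mode="llava_v1",
    num_chunks=1,
    chunk_idx=0,
    temperature=0.2,
    top_p=None,
    num_beams=1,
)
eval_model(args)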
--------------------------------------------------------------------------------
/llava/eval/model_vqa_loader.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import torch
3 | import os
4 | import json
5 | from tqdm import tqdm
6 | import shortuuid
7 |
8 | from llava.constants import IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN
9 | from llava.conversation import conv_templates, SeparatorStyle
10 | from llava.model.builder import load_pretrained_model
11 | from llava.utils import disable_torch_init
12 | from llava.mm_utils import tokenizer_image_token, process_images, get_model_name_from_path
13 | from torch.utils.data import Dataset, DataLoader
14 |
15 | from PIL import Image
16 | import math
17 |
18 |
19 | def split_list(lst, n):
20 | """Split a list into n (roughly) equal-sized chunks"""
21 |     chunk_size = math.ceil(len(lst) / n)  # ceiling division, so every item is assigned to a chunk
22 | return [lst[i:i+chunk_size] for i in range(0, len(lst), chunk_size)]
23 |
24 |
25 | def get_chunk(lst, n, k):
26 | chunks = split_list(lst, n)
27 | return chunks[k]
28 |
29 |
30 | # Custom dataset class
31 | class CustomDataset(Dataset):
32 | def __init__(self, questions, image_folder, tokenizer, image_processor, model_config):
33 | self.questions = questions
34 | self.image_folder = image_folder
35 | self.tokenizer = tokenizer
36 | self.image_processor = image_processor
37 | self.model_config = model_config
38 |
39 | def __getitem__(self, index):
40 | line = self.questions[index]
41 | image_file = line["image"]
42 | qs = line["text"]
43 | if self.model_config.mm_use_im_start_end:
44 | qs = DEFAULT_IM_START_TOKEN + DEFAULT_IMAGE_TOKEN + DEFAULT_IM_END_TOKEN + '\n' + qs
45 | else:
46 | qs = DEFAULT_IMAGE_TOKEN + '\n' + qs
47 |
48 |         conv = conv_templates[args.conv_mode].copy()  # note: reads the module-level args parsed under __main__
49 | conv.append_message(conv.roles[0], qs)
50 | conv.append_message(conv.roles[1], None)
51 | prompt = conv.get_prompt()
52 |
53 | image = Image.open(os.path.join(self.image_folder, image_file)).convert('RGB')
54 | image_tensor = process_images([image], self.image_processor, self.model_config)[0]
55 |
56 | input_ids = tokenizer_image_token(prompt, self.tokenizer, IMAGE_TOKEN_INDEX, return_tensors='pt')
57 |
58 | return input_ids, image_tensor, image.size
59 |
60 | def __len__(self):
61 | return len(self.questions)
62 |
63 |
64 | def collate_fn(batch):
65 | input_ids, image_tensors, image_sizes = zip(*batch)
66 | input_ids = torch.stack(input_ids, dim=0)
67 | image_tensors = torch.stack(image_tensors, dim=0)
68 | return input_ids, image_tensors, image_sizes
69 |
70 |
71 | # DataLoader
72 | def create_data_loader(questions, image_folder, tokenizer, image_processor, model_config, batch_size=1, num_workers=4):
73 | assert batch_size == 1, "batch_size must be 1"
74 | dataset = CustomDataset(questions, image_folder, tokenizer, image_processor, model_config)
75 | data_loader = DataLoader(dataset, batch_size=batch_size, num_workers=num_workers, shuffle=False, collate_fn=collate_fn)
76 | return data_loader
77 |
78 |
79 | def eval_model(args):
80 | # Model
81 | disable_torch_init()
82 | model_path = os.path.expanduser(args.model_path)
83 | model_name = get_model_name_from_path(model_path)
84 | tokenizer, model, image_processor, context_len = load_pretrained_model(model_path, args.model_base, model_name)
85 |
86 | questions = [json.loads(q) for q in open(os.path.expanduser(args.question_file), "r")]
87 | questions = get_chunk(questions, args.num_chunks, args.chunk_idx)
88 | answers_file = os.path.expanduser(args.answers_file)
89 | os.makedirs(os.path.dirname(answers_file), exist_ok=True)
90 | ans_file = open(answers_file, "w")
91 |
92 | if 'plain' in model_name and 'finetune' not in model_name.lower() and 'mmtag' not in args.conv_mode:
93 | args.conv_mode = args.conv_mode + '_mmtag'
94 |         print(f'It seems that this is a plain model, but it is not using an mmtag prompt; auto-switching to {args.conv_mode}.')
95 |
96 | data_loader = create_data_loader(questions, args.image_folder, tokenizer, image_processor, model.config)
97 |
98 | for (input_ids, image_tensor, image_sizes), line in tqdm(zip(data_loader, questions), total=len(questions)):
99 | idx = line["question_id"]
100 | cur_prompt = line["text"]
101 |
102 | input_ids = input_ids.to(device='cuda', non_blocking=True)
103 |
104 | with torch.inference_mode():
105 | output_ids = model.generate(
106 | input_ids,
107 | images=image_tensor.to(dtype=torch.float16, device='cuda', non_blocking=True),
108 | image_sizes=image_sizes,
109 | do_sample=True if args.temperature > 0 else False,
110 | temperature=args.temperature,
111 | top_p=args.top_p,
112 | num_beams=args.num_beams,
113 | max_new_tokens=args.max_new_tokens,
114 | use_cache=True)
115 |
116 | outputs = tokenizer.batch_decode(output_ids, skip_special_tokens=True)[0].strip()
117 |
118 | ans_id = shortuuid.uuid()
119 | ans_file.write(json.dumps({"question_id": idx,
120 | "prompt": cur_prompt,
121 | "text": outputs,
122 | "answer_id": ans_id,
123 | "model_id": model_name,
124 | "metadata": {}}) + "\n")
125 | # ans_file.flush()
126 | ans_file.close()
127 |
128 | if __name__ == "__main__":
129 | parser = argparse.ArgumentParser()
130 | parser.add_argument("--model-path", type=str, default="facebook/opt-350m")
131 | parser.add_argument("--model-base", type=str, default=None)
132 | parser.add_argument("--image-folder", type=str, default="")
133 | parser.add_argument("--question-file", type=str, default="tables/question.jsonl")
134 | parser.add_argument("--answers-file", type=str, default="answer.jsonl")
135 | parser.add_argument("--conv-mode", type=str, default="llava_v1")
136 | parser.add_argument("--num-chunks", type=int, default=1)
137 | parser.add_argument("--chunk-idx", type=int, default=0)
138 | parser.add_argument("--temperature", type=float, default=0.2)
139 | parser.add_argument("--top_p", type=float, default=None)
140 | parser.add_argument("--num_beams", type=int, default=1)
141 | parser.add_argument("--max_new_tokens", type=int, default=128)
142 | args = parser.parse_args()
143 |
144 | eval_model(args)
145 |
--------------------------------------------------------------------------------
/llava/eval/model_vqa_mmbench.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import torch
3 | import os
4 | import json
5 | import pandas as pd
6 | from tqdm import tqdm
7 | import shortuuid
8 |
9 | from llava.constants import IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN
10 | from llava.conversation import conv_templates, SeparatorStyle
11 | from llava.model.builder import load_pretrained_model
12 | from llava.utils import disable_torch_init
13 | from llava.mm_utils import tokenizer_image_token, process_images, load_image_from_base64, get_model_name_from_path
14 |
15 | from PIL import Image
16 | import math
17 |
18 |
19 | all_options = ['A', 'B', 'C', 'D']
20 |
21 |
22 | def split_list(lst, n):
23 | """Split a list into n (roughly) equal-sized chunks"""
24 |     chunk_size = math.ceil(len(lst) / n)  # ceiling division, so every item is assigned to a chunk
25 | return [lst[i:i+chunk_size] for i in range(0, len(lst), chunk_size)]
26 |
27 |
28 | def get_chunk(lst, n, k):
29 | chunks = split_list(lst, n)
30 | return chunks[k]
31 |
32 |
33 | def is_none(value):
34 | if value is None:
35 | return True
36 | if type(value) is float and math.isnan(value):
37 | return True
38 | if type(value) is str and value.lower() == 'nan':
39 | return True
40 | if type(value) is str and value.lower() == 'none':
41 | return True
42 | return False
43 |
44 | def get_options(row, options):
45 | parsed_options = []
46 | for option in options:
47 | option_value = row[option]
48 | if is_none(option_value):
49 | break
50 | parsed_options.append(option_value)
51 | return parsed_options
52 |
53 |
54 | def eval_model(args):
55 | # Model
56 | disable_torch_init()
57 | model_path = os.path.expanduser(args.model_path)
58 | model_name = get_model_name_from_path(model_path)
59 | tokenizer, model, image_processor, context_len = load_pretrained_model(model_path, args.model_base, model_name)
60 |
61 | questions = pd.read_table(os.path.expanduser(args.question_file))
62 | questions = get_chunk(questions, args.num_chunks, args.chunk_idx)
63 | answers_file = os.path.expanduser(args.answers_file)
64 | os.makedirs(os.path.dirname(answers_file), exist_ok=True)
65 | ans_file = open(answers_file, "w")
66 |
67 | if 'plain' in model_name and 'finetune' not in model_name.lower() and 'mmtag' not in args.conv_mode:
68 | args.conv_mode = args.conv_mode + '_mmtag'
69 |         print(f'It seems that this is a plain model, but it is not using an mmtag prompt; auto-switching to {args.conv_mode}.')
70 |
71 | for index, row in tqdm(questions.iterrows(), total=len(questions)):
72 | options = get_options(row, all_options)
73 | cur_option_char = all_options[:len(options)]
74 |
75 | if args.all_rounds:
76 | num_rounds = len(options)
77 | else:
78 | num_rounds = 1
79 |
80 | for round_idx in range(num_rounds):
81 | idx = row['index']
82 | question = row['question']
83 | hint = row['hint']
84 | image = load_image_from_base64(row['image'])
85 | if not is_none(hint):
86 | question = hint + '\n' + question
87 | for option_char, option in zip(all_options[:len(options)], options):
88 | question = question + '\n' + option_char + '. ' + option
89 | qs = cur_prompt = question
90 | if model.config.mm_use_im_start_end:
91 | qs = DEFAULT_IM_START_TOKEN + DEFAULT_IMAGE_TOKEN + DEFAULT_IM_END_TOKEN + '\n' + qs
92 | else:
93 | qs = DEFAULT_IMAGE_TOKEN + '\n' + qs
94 |
95 | if args.single_pred_prompt:
96 | if args.lang == 'cn':
97 | qs = qs + '\n' + "请直接回答选项字母。"
98 | else:
99 | qs = qs + '\n' + "Answer with the option's letter from the given choices directly."
100 |
101 | conv = conv_templates[args.conv_mode].copy()
102 | conv.append_message(conv.roles[0], qs)
103 | conv.append_message(conv.roles[1], None)
104 | prompt = conv.get_prompt()
105 |
106 | input_ids = tokenizer_image_token(prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors='pt').unsqueeze(0).cuda()
107 |
108 | image_tensor = process_images([image], image_processor, model.config)[0]
109 |
110 | with torch.inference_mode():
111 | output_ids = model.generate(
112 | input_ids,
113 | images=image_tensor.unsqueeze(0).half().cuda(),
114 | image_sizes=[image.size],
115 | do_sample=True if args.temperature > 0 else False,
116 | temperature=args.temperature,
117 | top_p=args.top_p,
118 | num_beams=args.num_beams,
119 | # no_repeat_ngram_size=3,
120 | max_new_tokens=1024,
121 | use_cache=True)
122 |
123 | outputs = tokenizer.batch_decode(output_ids, skip_special_tokens=True)[0].strip()
124 |
125 | ans_id = shortuuid.uuid()
126 | ans_file.write(json.dumps({"question_id": idx,
127 | "round_id": round_idx,
128 | "prompt": cur_prompt,
129 | "text": outputs,
130 | "options": options,
131 | "option_char": cur_option_char,
132 | "answer_id": ans_id,
133 | "model_id": model_name,
134 | "metadata": {}}) + "\n")
135 | ans_file.flush()
136 |
137 | # rotate options
138 | options = options[1:] + options[:1]
139 | cur_option_char = cur_option_char[1:] + cur_option_char[:1]
140 | ans_file.close()
141 |
142 | if __name__ == "__main__":
143 | parser = argparse.ArgumentParser()
144 | parser.add_argument("--model-path", type=str, default="facebook/opt-350m")
145 | parser.add_argument("--model-base", type=str, default=None)
146 | parser.add_argument("--image-folder", type=str, default="")
147 | parser.add_argument("--question-file", type=str, default="tables/question.jsonl")
148 | parser.add_argument("--answers-file", type=str, default="answer.jsonl")
149 | parser.add_argument("--conv-mode", type=str, default="llava_v1")
150 | parser.add_argument("--num-chunks", type=int, default=1)
151 | parser.add_argument("--chunk-idx", type=int, default=0)
152 | parser.add_argument("--temperature", type=float, default=0.2)
153 | parser.add_argument("--top_p", type=float, default=None)
154 | parser.add_argument("--num_beams", type=int, default=1)
155 | parser.add_argument("--all-rounds", action="store_true")
156 | parser.add_argument("--single-pred-prompt", action="store_true")
157 | parser.add_argument("--lang", type=str, default="en")
158 | args = parser.parse_args()
159 |
160 | eval_model(args)
161 |
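A tiny sketch of the option rotation performed at the end of each round above, on a hypothetical three-option question; the displayed letters stay A/B/C while option_char records which original letter each position held.

all_options = ['A', 'B', 'C', 'D']
options = ["cat", "dog", "bird"]                  # hypothetical answer texts
cur_option_char = all_options[:len(options)]
for round_idx in range(len(options)):             # what --all-rounds iterates over
    shown = [f"{c}. {o}" for c, o in zip(all_options[:len(options)], options)]
    print(round_idx, shown, "option_char:", cur_option_char)
    # rotate options, same as the end of the round loop above
    options = options[1:] + options[:1]
    cur_option_char = cur_option_char[1:] + cur_option_char[:1]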
--------------------------------------------------------------------------------
/llava/eval/model_vqa_science.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import torch
3 | import os
4 | import json
5 | from tqdm import tqdm
6 | import shortuuid
7 |
8 | from llava.constants import IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN
9 | from llava.conversation import conv_templates, SeparatorStyle
10 | from llava.model.builder import load_pretrained_model
11 | from llava.utils import disable_torch_init
12 | from llava.mm_utils import tokenizer_image_token, process_images, get_model_name_from_path
13 |
14 | from PIL import Image
15 | import math
16 |
17 |
18 | def split_list(lst, n):
19 | """Split a list into n (roughly) equal-sized chunks"""
20 |     chunk_size = math.ceil(len(lst) / n)  # ceiling division, so every item is assigned to a chunk
21 | return [lst[i:i+chunk_size] for i in range(0, len(lst), chunk_size)]
22 |
23 |
24 | def get_chunk(lst, n, k):
25 | chunks = split_list(lst, n)
26 | return chunks[k]
27 |
28 |
29 | def eval_model(args):
30 | # Model
31 | disable_torch_init()
32 | model_path = os.path.expanduser(args.model_path)
33 | model_name = get_model_name_from_path(model_path)
34 | tokenizer, model, image_processor, context_len = load_pretrained_model(model_path, args.model_base, model_name)
35 |
36 | questions = json.load(open(os.path.expanduser(args.question_file), "r"))
37 | questions = get_chunk(questions, args.num_chunks, args.chunk_idx)
38 | answers_file = os.path.expanduser(args.answers_file)
39 | os.makedirs(os.path.dirname(answers_file), exist_ok=True)
40 | ans_file = open(answers_file, "w")
41 | for i, line in enumerate(tqdm(questions)):
42 | idx = line["id"]
43 | question = line['conversations'][0]
44 |         qs = question['value'].replace('<image>', '').strip()
45 | cur_prompt = qs
46 |
47 | if 'image' in line:
48 | image_file = line["image"]
49 | image = Image.open(os.path.join(args.image_folder, image_file))
50 | image_tensor = process_images([image], image_processor, model.config)[0]
51 | images = image_tensor.unsqueeze(0).half().cuda()
52 | image_sizes = [image.size]
53 | if getattr(model.config, 'mm_use_im_start_end', False):
54 | qs = DEFAULT_IM_START_TOKEN + DEFAULT_IMAGE_TOKEN + DEFAULT_IM_END_TOKEN + '\n' + qs
55 | else:
56 | qs = DEFAULT_IMAGE_TOKEN + '\n' + qs
57 |             cur_prompt = '<image>' + '\n' + cur_prompt
58 | else:
59 | images = None
60 | image_sizes = None
61 |
62 | if args.single_pred_prompt:
63 | qs = qs + '\n' + "Answer with the option's letter from the given choices directly."
64 | cur_prompt = cur_prompt + '\n' + "Answer with the option's letter from the given choices directly."
65 |
66 | conv = conv_templates[args.conv_mode].copy()
67 | conv.append_message(conv.roles[0], qs)
68 | conv.append_message(conv.roles[1], None)
69 | prompt = conv.get_prompt()
70 |
71 | input_ids = tokenizer_image_token(prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors='pt').unsqueeze(0).cuda()
72 |
73 | with torch.inference_mode():
74 | output_ids = model.generate(
75 | input_ids,
76 | images=images,
77 | image_sizes=image_sizes,
78 | do_sample=True if args.temperature > 0 else False,
79 | temperature=args.temperature,
80 | max_new_tokens=1024,
81 | use_cache=True,
82 | )
83 |
84 | outputs = tokenizer.batch_decode(output_ids, skip_special_tokens=True)[0].strip()
85 |
86 | ans_id = shortuuid.uuid()
87 | ans_file.write(json.dumps({"question_id": idx,
88 | "prompt": cur_prompt,
89 | "text": outputs,
90 | "answer_id": ans_id,
91 | "model_id": model_name,
92 | "metadata": {}}) + "\n")
93 | ans_file.flush()
94 | ans_file.close()
95 |
96 | if __name__ == "__main__":
97 | parser = argparse.ArgumentParser()
98 | parser.add_argument("--model-path", type=str, default="facebook/opt-350m")
99 | parser.add_argument("--model-base", type=str, default=None)
100 | parser.add_argument("--image-folder", type=str, default="")
101 | parser.add_argument("--question-file", type=str, default="tables/question.json")
102 | parser.add_argument("--answers-file", type=str, default="answer.jsonl")
103 | parser.add_argument("--conv-mode", type=str, default="llava_v0")
104 | parser.add_argument("--num-chunks", type=int, default=1)
105 | parser.add_argument("--chunk-idx", type=int, default=0)
106 | parser.add_argument("--temperature", type=float, default=0.2)
107 | parser.add_argument("--answer-prompter", action="store_true")
108 | parser.add_argument("--single-pred-prompt", action="store_true")
109 | args = parser.parse_args()
110 |
111 | eval_model(args)
112 |
--------------------------------------------------------------------------------
/llava/eval/qa_baseline_gpt35.py:
--------------------------------------------------------------------------------
1 | """Generate answers with GPT-3.5"""
2 | # Note: you need to be using OpenAI Python v0.27.0 for the code below to work
3 | import argparse
4 | import json
5 | import os
6 | import time
7 | import concurrent.futures
8 |
9 | import openai
10 | import tqdm
11 | import shortuuid
12 |
13 | MODEL = 'gpt-3.5-turbo'
14 | MODEL_ID = 'gpt-3.5-turbo:20230327'
15 |
16 | def get_answer(question_id: int, question: str, max_tokens: int):
17 | ans = {
18 | 'answer_id': shortuuid.uuid(),
19 | 'question_id': question_id,
20 | 'model_id': MODEL_ID,
21 | }
22 | for _ in range(3):
23 | try:
24 | response = openai.ChatCompletion.create(
25 | model=MODEL,
26 | messages=[{
27 | 'role': 'system',
28 | 'content': 'You are a helpful assistant.'
29 | }, {
30 | 'role': 'user',
31 | 'content': question,
32 | }],
33 | max_tokens=max_tokens,
34 | )
35 | ans['text'] = response['choices'][0]['message']['content']
36 | return ans
37 | except Exception as e:
38 | print('[ERROR]', e)
39 | ans['text'] = '#ERROR#'
40 | time.sleep(1)
41 | return ans
42 |
43 |
44 | if __name__ == '__main__':
45 | parser = argparse.ArgumentParser(description='ChatGPT answer generation.')
46 | parser.add_argument('-q', '--question')
47 | parser.add_argument('-o', '--output')
48 | parser.add_argument('--max-tokens', type=int, default=1024, help='maximum number of tokens produced in the output')
49 | args = parser.parse_args()
50 |
51 | questions_dict = {}
52 | with open(os.path.expanduser(args.question)) as f:
53 | for line in f:
54 | if not line:
55 | continue
56 | q = json.loads(line)
57 | questions_dict[q['question_id']] = q['text']
58 |
59 | answers = []
60 |
61 | with concurrent.futures.ThreadPoolExecutor(max_workers=32) as executor:
62 | futures = []
63 | for qid, question in questions_dict.items():
64 | future = executor.submit(get_answer, qid, question, args.max_tokens)
65 | futures.append(future)
66 |
67 | for future in tqdm.tqdm(concurrent.futures.as_completed(futures), total=len(futures)):
68 | answers.append(future.result())
69 |
70 | answers.sort(key=lambda x: x['question_id'])
71 |
72 | with open(os.path.expanduser(args.output), 'w') as f:
73 | table = [json.dumps(ans) for ans in answers]
74 | f.write('\n'.join(table))
75 |
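As the header comment notes, this script targets the legacy openai<1.0 client (ChatCompletion). A minimal sketch of calling get_answer() directly, assuming the function above is in scope and a valid API key is configured; the key and question are illustrative placeholders.

import openai

openai.api_key = "sk-..."   # placeholder; the legacy client also reads OPENAI_API_KEY
print(get_answer(1, "What is the capital of France?", max_tokens=32))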
--------------------------------------------------------------------------------
/llava/eval/run_llava.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import torch
3 |
4 | from llava.constants import (
5 | IMAGE_TOKEN_INDEX,
6 | DEFAULT_IMAGE_TOKEN,
7 | DEFAULT_IM_START_TOKEN,
8 | DEFAULT_IM_END_TOKEN,
9 | IMAGE_PLACEHOLDER,
10 | )
11 | from llava.conversation import conv_templates, SeparatorStyle
12 | from llava.model.builder import load_pretrained_model
13 | from llava.utils import disable_torch_init
14 | from llava.mm_utils import (
15 | process_images,
16 | tokenizer_image_token,
17 | get_model_name_from_path,
18 | )
19 |
20 | from PIL import Image
21 |
22 | import requests
23 | from PIL import Image
24 | from io import BytesIO
25 | import re
26 |
27 |
28 | def image_parser(args):
29 | out = args.image_file.split(args.sep)
30 | return out
31 |
32 |
33 | def load_image(image_file):
34 | if image_file.startswith("http") or image_file.startswith("https"):
35 | response = requests.get(image_file)
36 | image = Image.open(BytesIO(response.content)).convert("RGB")
37 | else:
38 | image = Image.open(image_file).convert("RGB")
39 | return image
40 |
41 |
42 | def load_images(image_files):
43 | out = []
44 | for image_file in image_files:
45 | image = load_image(image_file)
46 | out.append(image)
47 | return out
48 |
49 |
50 | def eval_model(args):
51 | # Model
52 | disable_torch_init()
53 |
54 | model_name = get_model_name_from_path(args.model_path)
55 | tokenizer, model, image_processor, context_len = load_pretrained_model(
56 | args.model_path, args.model_base, model_name
57 | )
58 |
59 | qs = args.query
60 | image_token_se = DEFAULT_IM_START_TOKEN + DEFAULT_IMAGE_TOKEN + DEFAULT_IM_END_TOKEN
61 | if IMAGE_PLACEHOLDER in qs:
62 | if model.config.mm_use_im_start_end:
63 | qs = re.sub(IMAGE_PLACEHOLDER, image_token_se, qs)
64 | else:
65 | qs = re.sub(IMAGE_PLACEHOLDER, DEFAULT_IMAGE_TOKEN, qs)
66 | else:
67 | if model.config.mm_use_im_start_end:
68 | qs = image_token_se + "\n" + qs
69 | else:
70 | qs = DEFAULT_IMAGE_TOKEN + "\n" + qs
71 |
72 | if "llama-2" in model_name.lower():
73 | conv_mode = "llava_llama_2"
74 | elif "mistral" in model_name.lower():
75 | conv_mode = "mistral_instruct"
76 | elif "v1.6-34b" in model_name.lower():
77 | conv_mode = "chatml_direct"
78 | elif "v1" in model_name.lower():
79 | conv_mode = "llava_v1"
80 | elif "mpt" in model_name.lower():
81 | conv_mode = "mpt"
82 | else:
83 | conv_mode = "llava_v0"
84 |
85 | if args.conv_mode is not None and conv_mode != args.conv_mode:
86 | print(
87 | "[WARNING] the auto inferred conversation mode is {}, while `--conv-mode` is {}, using {}".format(
88 | conv_mode, args.conv_mode, args.conv_mode
89 | )
90 | )
91 | else:
92 | args.conv_mode = conv_mode
93 |
94 | conv = conv_templates[args.conv_mode].copy()
95 | conv.append_message(conv.roles[0], qs)
96 | conv.append_message(conv.roles[1], None)
97 | prompt = conv.get_prompt()
98 |
99 | image_files = image_parser(args)
100 | images = load_images(image_files)
101 | image_sizes = [x.size for x in images]
102 | images_tensor = process_images(
103 | images,
104 | image_processor,
105 | model.config
106 | ).to(model.device, dtype=torch.float16)
107 |
108 | input_ids = (
109 | tokenizer_image_token(prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors="pt")
110 | .unsqueeze(0)
111 | .cuda()
112 | )
113 |
114 | with torch.inference_mode():
115 | output_ids = model.generate(
116 | input_ids,
117 | images=images_tensor,
118 | image_sizes=image_sizes,
119 | do_sample=True if args.temperature > 0 else False,
120 | temperature=args.temperature,
121 | top_p=args.top_p,
122 | num_beams=args.num_beams,
123 | max_new_tokens=args.max_new_tokens,
124 | use_cache=True,
125 | )
126 |
127 | outputs = tokenizer.batch_decode(output_ids, skip_special_tokens=True)[0].strip()
128 | print(outputs)
129 |
130 |
131 | if __name__ == "__main__":
132 | parser = argparse.ArgumentParser()
133 | parser.add_argument("--model-path", type=str, default="facebook/opt-350m")
134 | parser.add_argument("--model-base", type=str, default=None)
135 | parser.add_argument("--image-file", type=str, required=True)
136 | parser.add_argument("--query", type=str, required=True)
137 | parser.add_argument("--conv-mode", type=str, default=None)
138 | parser.add_argument("--sep", type=str, default=",")
139 | parser.add_argument("--temperature", type=float, default=0.2)
140 | parser.add_argument("--top_p", type=float, default=None)
141 | parser.add_argument("--num_beams", type=int, default=1)
142 | parser.add_argument("--max_new_tokens", type=int, default=512)
143 | args = parser.parse_args()
144 |
145 | eval_model(args)
146 |
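As with the other eval scripts, eval_model() can be invoked programmatically; a sketch follows in which the checkpoint id, image URL, and query are illustrative placeholders and a CUDA GPU is assumed.

from argparse import Namespace

args = Namespace(
    model_path="liuhaotian/llava-v1.5-7b",   # hypothetical checkpoint id
    model_base=None,
    image_file="https://llava-vl.github.io/static/images/view.jpg",   # illustrative URL
    query="What is shown in this image?",
    conv_mode=None,          # let the script auto-select from the model name
    sep=",",
    temperature=0.2,
    top_p=None,
    num_beams=1,
    max_new_tokens=512,
)
eval_model(args)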
--------------------------------------------------------------------------------
/llava/eval/summarize_gpt_review.py:
--------------------------------------------------------------------------------
1 | import json
2 | import os
3 | from collections import defaultdict
4 |
5 | import numpy as np
6 |
7 | import argparse
8 |
9 | def parse_args():
10 | parser = argparse.ArgumentParser(description='ChatGPT-based QA evaluation.')
11 | parser.add_argument('-d', '--dir', default=None)
12 | parser.add_argument('-v', '--version', default=None)
13 | parser.add_argument('-s', '--select', nargs='*', default=None)
14 | parser.add_argument('-f', '--files', nargs='*', default=[])
15 | parser.add_argument('-i', '--ignore', nargs='*', default=[])
16 | return parser.parse_args()
17 |
18 |
19 | if __name__ == '__main__':
20 | args = parse_args()
21 |
22 | if args.ignore is not None:
23 | args.ignore = [int(x) for x in args.ignore]
24 |
25 | if len(args.files) > 0:
26 | review_files = args.files
27 | else:
28 | review_files = [x for x in os.listdir(args.dir) if x.endswith('.jsonl') and (x.startswith('gpt4_text') or x.startswith('reviews_') or x.startswith('review_') or 'review' in args.dir)]
29 |
30 | for review_file in sorted(review_files):
31 | config = os.path.basename(review_file).replace('gpt4_text_', '').replace('.jsonl', '')
32 | if args.select is not None and any(x not in config for x in args.select):
33 | continue
34 | if '0613' in config:
35 | version = '0613'
36 | else:
37 | version = '0314'
38 | if args.version is not None and args.version != version:
39 | continue
40 | scores = defaultdict(list)
41 | print(config)
42 | with open(os.path.join(args.dir, review_file) if args.dir is not None else review_file) as f:
43 | for review_str in f:
44 | review = json.loads(review_str)
45 | if review['question_id'] in args.ignore:
46 | continue
47 | if 'category' in review:
48 | scores[review['category']].append(review['tuple'])
49 | scores['all'].append(review['tuple'])
50 | else:
51 | if 'tuple' in review:
52 | scores['all'].append(review['tuple'])
53 | else:
54 | scores['all'].append(review['score'])
55 | for k, v in sorted(scores.items()):
56 | stats = np.asarray(v).mean(0).tolist()
57 | stats = [round(x, 3) for x in stats]
58 | # print(k, stats, round(stats[1]/stats[0]*100, 1))
59 | print(k, round(stats[1]/stats[0]*100, 1), round(stats[0] * 10, 1), round(stats[1] * 10, 1))
60 | print('=================================')
61 |
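A small sketch of how the printed numbers relate to the stored review tuples, using hypothetical (reference score, candidate score) pairs; it reproduces the arithmetic of the summary loop above.

import numpy as np

tuples = [[8.0, 7.0], [9.0, 6.5], [7.5, 7.5]]      # hypothetical [assistant 1, assistant 2] scores
stats = np.asarray(tuples).mean(0).tolist()        # per-assistant mean score
stats = [round(x, 3) for x in stats]
print('all', round(stats[1] / stats[0] * 100, 1),  # candidate relative to reference, in %
      round(stats[0] * 10, 1), round(stats[1] * 10, 1))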
--------------------------------------------------------------------------------
/llava/eval/table/model.jsonl:
--------------------------------------------------------------------------------
1 | {"model_id": "vicuna-13b:20230322-clean-lang", "model_name": "vicuna-13b", "model_version": "20230322-clean-lang", "model_metadata": "vicuna-13b-20230322-clean-lang"}
2 | {"model_id": "alpaca-13b:v1", "model_name": "alpaca-13b", "model_version": "v1", "model_metadata": "alpaca-13b"}
3 | {"model_id": "llama-13b:v1", "model_name": "llama-13b", "model_version": "v1", "model_metadata": "hf-llama-13b"}
4 | {"model_id": "bard:20230327", "model_name": "bard", "model_version": "20230327", "model_metadata": "Google Bard 20230327"}
5 | {"model_id": "gpt-3.5-turbo:20230327", "model_name": "gpt-3.5-turbo", "model_version": "20230327", "model_metadata": "OpenAI ChatGPT gpt-3.5-turbo Chat Completion"}
6 |
--------------------------------------------------------------------------------
/llava/eval/table/prompt.jsonl:
--------------------------------------------------------------------------------
1 | {"prompt_id": 1, "system_prompt": "You are a helpful and precise assistant for checking the quality of the answer.", "prompt_template": "[Question]\n{question}\n\n[Assistant 1]\n{answer_1}\n\n[End of Assistant 1]\n\n[Assistant 2]\n{answer_2}\n\n[End of Assistant 2]\n\n[System]\n{prompt}\n\n", "defaults": {"prompt": "We would like to request your feedback on the performance of two AI assistants in response to the user question displayed above.\nPlease rate the helpfulness, relevance, accuracy, level of details of their responses. Each assistant receives an overall score on a scale of 1 to 10, where a higher score indicates better overall performance.\nPlease first output a single line containing only two values indicating the scores for Assistant 1 and 2, respectively. The two scores are separated by a space.\nIn the subsequent line, please provide a comprehensive explanation of your evaluation, avoiding any potential bias and ensuring that the order in which the responses were presented does not affect your judgment."}, "description": "Prompt for general questions"}
2 | {"prompt_id": 2, "system_prompt": "You are a helpful and precise assistant for checking the quality of the answer.", "prompt_template": "[Question]\n{question}\n\n[Assistant 1]\n{answer_1}\n\n[End of Assistant 1]\n\n[Assistant 2]\n{answer_2}\n\n[End of Assistant 2]\n\n[System]\n{prompt}\n\n", "defaults": {"prompt": "Your task is to evaluate the coding abilities of the above two assistants. They have been asked to implement a program to solve a given problem. Please review their code submissions, paying close attention to their problem-solving approach, code structure, readability, and the inclusion of helpful comments.\n\nPlease ensure that the assistants' submissions:\n\n1. Correctly implement the given problem statement.\n2. Contain accurate and efficient code.\n3. Include clear and concise comments that explain the code's logic and functionality.\n4. Adhere to proper coding standards and best practices.\n\nOnce you have carefully reviewed both submissions, provide detailed feedback on their strengths and weaknesses, along with any suggestions for improvement. You should first output a single line containing two scores on the scale of 1-10 (1: no code/no sense; 10: perfect) for Assistant 1 and 2, respectively. Then give extra comments starting from the next line."}, "description": "Prompt for coding questions"}
3 | {"prompt_id": 3, "system_prompt": "You are a helpful and precise assistant for checking the quality of the answer.", "prompt_template": "[Question]\n{question}\n\n[Assistant 1]\n{answer_1}\n\n[End of Assistant 1]\n\n[Assistant 2]\n{answer_2}\n\n[End of Assistant 2]\n\n[System]\n{prompt}\n\n", "defaults": {"prompt": "We would like to request your feedback on the mathematical proficiency of two AI assistants regarding the given user question.\nFirstly, please solve the problem independently, without referring to the answers provided by Assistant 1 and Assistant 2.\nAfterward, please examine the problem-solving process of Assistant 1 and Assistant 2 step-by-step to ensure their correctness, identifying any incorrect steps if present. Your evaluation should take into account not only the answer but also the problem-solving steps.\nFinally, please output a Python tuple containing two numerical scores for Assistant 1 and Assistant 2, ranging from 1 to 10, respectively. If applicable, explain the reasons for any variations in their scores and determine which assistant performed better."}, "description": "Prompt for math questions"}
4 | {"prompt_id": 4, "system_prompt": "You are a helpful and precise assistant for checking the quality of the answer.", "prompt_template": "[Visual Context]\n{context}\n[Question]\n{question}\n\n[Assistant 1]\n{answer_1}\n\n[End of Assistant 1]\n\n[Assistant 2]\n{answer_2}\n\n[End of Assistant 2]\n\n[System]\n{prompt}\n\n", "defaults": {"prompt": "We would like to request your feedback on the performance of two AI assistants in response to the user question displayed above. The user asks the question on observing an image. For your reference, the visual content in the image is represented with five descriptive sentences describing the same image and the bounding box coordinates of each object in the scene. These coordinates are in the form of bounding boxes, represented as (x1, y1, x2, y2) with floating numbers ranging from 0 to 1. These values correspond to the top left x, top left y, bottom right x, and bottom right y. \nPlease rate the helpfulness, relevance, accuracy, level of details of their responses. Each assistant receives an overall score on a scale of 1 to 10, where a higher score indicates better overall performance.\nPlease first output a single line containing only two values indicating the scores for Assistant 1 and 2, respectively. The two scores are separated by a space.\nIn the subsequent line, please provide a comprehensive explanation of your evaluation, avoiding any potential bias and ensuring that the order in which the responses were presented does not affect your judgment."}, "description": "Prompt for visual questions"}
5 |
--------------------------------------------------------------------------------
/llava/eval/table/reviewer.jsonl:
--------------------------------------------------------------------------------
1 | {"reviewer_id": "gpt-4-0328-default", "prompt_id": 1, "metadata": {"temperature": 0.2, "max_tokens": 1024}, "description": "GPT-4 for general questions"}
2 | {"reviewer_id": "gpt-4-0328-coding", "prompt_id": 2, "metadata": {"temperature": 0.2, "max_tokens": 1024}, "description": "GPT-4 for coding questions"}
3 | {"reviewer_id": "gpt-4-0328-math", "prompt_id": 3, "metadata": {"temperature": 0.2, "max_tokens": 1024}, "description": "GPT-4 for math questions"}
4 | {"reviewer_id": "gpt-4-0417-visual", "prompt_id": 4, "metadata": {"temperature": 0.2, "max_tokens": 1024}, "description": "GPT-4 for math questions"}
5 |
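
A hedged sketch of how these two tables might be wired together at review time, assuming the repository root is the working directory and that the judge's reply puts the two scores on its first line; the question, answers, and judge reply below are stand-ins:

import json

def load_jsonl(path):
    with open(path) as f:
        return [json.loads(line) for line in f if line.strip()]

# Join a reviewer entry to its prompt via prompt_id.
prompts = {p["prompt_id"]: p for p in load_jsonl("llava/eval/table/prompt.jsonl")}
reviewer = load_jsonl("llava/eval/table/reviewer.jsonl")[0]        # gpt-4-0328-default
rule = prompts[reviewer["prompt_id"]]

# Fill the template with a toy question/answer pair plus the default system prompt.
content = rule["prompt_template"].format(
    question="What is shown in the image?",
    answer_1="A red knee-length dress.",
    answer_2="A garment.",
    prompt=rule["defaults"]["prompt"],
)

# The judge is instructed to emit "score_1 score_2" on the first line of its reply.
judge_reply = "8 6\nAssistant 1 is more specific and detailed."    # stand-in reply
score_1, score_2 = (float(s) for s in judge_reply.splitlines()[0].split())
print(score_1, score_2)                                            # 8.0 6.0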
--------------------------------------------------------------------------------
/llava/eval/webpage/figures/alpaca.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/biansy000/ChatGarment/6ea4192383067938f1d1c63dc0cc4134ccc91b31/llava/eval/webpage/figures/alpaca.png
--------------------------------------------------------------------------------
/llava/eval/webpage/figures/bard.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/biansy000/ChatGarment/6ea4192383067938f1d1c63dc0cc4134ccc91b31/llava/eval/webpage/figures/bard.jpg
--------------------------------------------------------------------------------
/llava/eval/webpage/figures/chatgpt.svg:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/llava/eval/webpage/figures/llama.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/biansy000/ChatGarment/6ea4192383067938f1d1c63dc0cc4134ccc91b31/llava/eval/webpage/figures/llama.jpg
--------------------------------------------------------------------------------
/llava/eval/webpage/figures/swords_FILL0_wght300_GRAD0_opsz48.svg:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/llava/eval/webpage/figures/vicuna.jpeg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/biansy000/ChatGarment/6ea4192383067938f1d1c63dc0cc4134ccc91b31/llava/eval/webpage/figures/vicuna.jpeg
--------------------------------------------------------------------------------
/llava/eval/webpage/styles.css:
--------------------------------------------------------------------------------
1 | body {
2 | font-family: 'Segoe UI', Tahoma, Geneva, Verdana, sans-serif;
3 | background-color: #f8f9fa;
4 | }
5 |
6 | .navbar-dark .navbar-nav .nav-link {
7 | color: #f1cf68;
8 | font-size: 1.1rem;
9 | padding: 0.5rem 0.6rem;
10 | }
11 |
12 | .card-header {
13 | font-weight: bold;
14 | }
15 |
16 | .card {
17 | box-shadow: 0 4px 8px rgba(0, 0, 0, 0.1);
18 | transition: 0.3s;
19 | }
20 |
21 | .card:hover {
22 | box-shadow: 0 8px 16px rgba(0, 0, 0, 0.2);
23 | }
24 |
25 | button {
26 | transition: background-color 0.3s;
27 | }
28 |
29 | button:hover {
30 | background-color: #007bff;
31 | }
32 |
33 | @media (max-width: 767px) {
34 | .form-row .form-group {
35 | margin-bottom: 10px;
36 | }
37 | }
38 |
39 | /* Extra styles */
40 |
41 | .expandable-card .card-text-container {
42 | max-height: 200px;
43 | overflow-y: hidden;
44 | position: relative;
45 | }
46 |
47 | .expandable-card.expanded .card-text-container {
48 | max-height: none;
49 | }
50 |
51 | .expand-btn {
52 | position: relative;
53 | display: none;
54 | background-color: rgba(255, 255, 255, 0.8);
55 | color: #510c75;
56 | border-color: transparent;
57 | }
58 |
59 | .expand-btn:hover {
60 | background-color: rgba(200, 200, 200, 0.8);
61 | text-decoration: none;
62 | border-color: transparent;
63 | color: #510c75;
64 | }
65 |
66 | .expand-btn:focus {
67 | outline: none;
68 | text-decoration: none;
69 | }
70 |
71 | .expandable-card:not(.expanded) .card-text-container:after {
72 | content: "";
73 | position: absolute;
74 | bottom: 0;
75 | left: 0;
76 | width: 100%;
77 | height: 90px;
78 | background: linear-gradient(rgba(255, 255, 255, 0.2), rgba(255, 255, 255, 1));
79 | }
80 |
81 | .expandable-card:not(.expanded) .expand-btn {
82 | margin-top: -40px;
83 | }
84 |
85 | .card-body {
86 | padding-bottom: 5px;
87 | }
88 |
89 | .vertical-flex-layout {
90 | justify-content: center;
91 | align-items: center;
92 | height: 100%;
93 | display: flex;
94 | flex-direction: column;
95 | gap: 5px;
96 | }
97 |
98 | .figure-img {
99 | max-width: 100%;
100 | height: auto;
101 | }
102 |
103 | .adjustable-font-size {
104 | font-size: calc(0.5rem + 2vw);
105 | }
106 |
--------------------------------------------------------------------------------
/llava/garmentcodeRC_utils.py:
--------------------------------------------------------------------------------
1 | import sys
2 | import os
3 | import yaml
4 | from pathlib import Path
5 | from collections import OrderedDict
6 | import pickle as pkl
7 | import argparse
8 | import json
9 | import re
10 | import copy
11 | import torch
12 | import numpy as np
13 |
14 | wb_config_name = 'waistband'
15 | skirt_configs = {
16 | 'SkirtCircle': 'flare-skirt',
17 | 'AsymmSkirtCircle': 'flare-skirt',
18 | 'GodetSkirt': 'godet-skirt',
19 | 'Pants': 'pants',
20 | 'Skirt2': 'skirt',
21 | 'SkirtManyPanels': 'flare-skirt',
22 | 'PencilSkirt': 'pencil-skirt',
23 | 'SkirtLevels': 'levels-skirt',
24 | }
25 | all_skirt_configs = ['skirt', 'flare-skirt', 'godet-skirt', 'pencil-skirt', 'levels-skirt', 'pants']
26 |
27 |
28 |
29 | def ordered(d, desired_key_order):
30 | return OrderedDict([(key, d[key]) for key in desired_key_order])
31 |
32 |
33 | def recursive_simplify_params(cfg, is_used=True, unused_configs=[], parent_path='design', device='cpu'):
34 |     # normalize ranged float parameters to [0, 1] tensors; other leaf values pass through unchanged
35 | if cfg is None:
36 | print(parent_path)
37 |
38 | cfg_new = {}
39 | if ('type' not in cfg) or not isinstance(cfg['type'], str):
40 |
41 | if 'enable_asym' in cfg: ############################################
42 | enable_asym = bool(cfg['enable_asym']['v'])
43 | if not enable_asym:
44 | cfg_new['enable_asym'] = cfg['enable_asym']['v']
45 | return cfg_new
46 |
47 | if parent_path == 'design.sleeve.cuff' and cfg['type']['v'] is None:
48 | return {'type': None}
49 |
50 | if parent_path == 'design.left.sleeve.cuff' and cfg['type']['v'] is None:
51 | return {'type': None}
52 |
53 | if parent_path == 'design.pants.cuff' and cfg['type']['v'] is None:
54 | return {'type': None}
55 |
56 | # if parent_path == 'design.sleeve' and cfg['sleeveless']['v']:
57 | # return {'type': None}
58 |
59 | # if parent_path == 'design.sleeve'
60 |
61 | for subpattern_n, subpattern_cfg in cfg.items():
62 | if (subpattern_n in unused_configs) and ('meta' in cfg):
63 | continue
64 | else:
65 | subconfig = recursive_simplify_params(subpattern_cfg, is_used=is_used, parent_path=parent_path + '.' + subpattern_n, device=device)
66 |
67 | cfg_new[subpattern_n] = subconfig
68 |
69 | else:
70 | type_now = cfg['type']
71 | if type_now == 'float':
72 | lower_bd = float(cfg['range'][0])
73 | upper_bd = float(cfg['range'][1])
74 |
75 | float_val = cfg['v']
76 | float_val_normed = (float_val - lower_bd) / (upper_bd - lower_bd)
77 | cfg_new = torch.tensor([float_val_normed]).float().to(device)
78 |
79 | else:
80 | cfg_new = cfg['v']
81 |
82 | return cfg_new
83 |
84 |
85 | def GarmentCodeRC_simplify_params(new_config, device='cpu'):
86 | if 'design' in new_config:
87 | new_config = new_config['design']
88 |
89 | ################ get unused_configs
90 | unused_configs = []
91 | ub_garment = new_config['meta']['upper']['v']
92 | if ub_garment is None:
93 | unused_configs += ['shirt', 'collar', 'sleeve', 'left']
94 |
95 | wb_garment = new_config['meta']['wb']['v']
96 | if not wb_garment:
97 | unused_configs.append(wb_config_name)
98 |
99 | lower_garment = new_config['meta']['bottom']['v']
100 | assert lower_garment != 'null', (lower_garment)
101 | if lower_garment is None:
102 | unused_configs += all_skirt_configs
103 | else:
104 | unused_configs += copy.deepcopy(all_skirt_configs)
105 | unused_configs.remove(skirt_configs[lower_garment])
106 |
107 | if 'base' in new_config[skirt_configs[lower_garment]]:
108 | base_garment = new_config[skirt_configs[lower_garment]]['base']['v']
109 | unused_configs.remove(skirt_configs[base_garment])
110 |
111 | new_config = recursive_simplify_params(new_config, is_used=True, unused_configs=unused_configs, device=device)
112 |
113 | return new_config
114 |
115 |
116 | def update_design_ranges():
117 | return
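
A minimal sketch of what recursive_simplify_params does to leaf nodes, using a hypothetical two-entry design fragment (the 'select' leaf type and the values below are made up; the llava package is assumed to be importable):

import torch
from llava.garmentcodeRC_utils import recursive_simplify_params

# Hypothetical design fragment: one ranged float leaf, one non-float leaf.
toy_cfg = {
    'length': {'type': 'float', 'v': 30.0, 'range': [0.0, 60.0]},
    'style':  {'type': 'select', 'v': 'straight', 'range': ['straight', 'flared']},
}

out = recursive_simplify_params(toy_cfg, parent_path='design.toy')
print(out['length'])   # tensor([0.5000]) -- (30 - 0) / (60 - 0), normalized to [0, 1]
print(out['style'])    # 'straight' -- non-float leaves keep their raw value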
--------------------------------------------------------------------------------
/llava/lisa_utils.py:
--------------------------------------------------------------------------------
1 | from enum import Enum
2 |
3 | import numpy as np
4 | import torch
5 | import torch.distributed as dist
6 |
7 | IGNORE_INDEX = -100
8 | IMAGE_TOKEN_INDEX = -200
9 | DEFAULT_IMAGE_TOKEN = "<image>"
10 | DEFAULT_IMAGE_PATCH_TOKEN = "<im_patch>"
11 | DEFAULT_IM_START_TOKEN = "<im_start>"
12 | DEFAULT_IM_END_TOKEN = "<im_end>"
13 |
14 | SHORT_QUESTION_LIST = [
15 | DEFAULT_IMAGE_TOKEN + "\n" + "Can you predict the SMPL pose of the person in this image?",
16 | DEFAULT_IMAGE_TOKEN + "\n" + "There is a person in the middle of the image; please output this person's SMPL pose.",
17 | DEFAULT_IMAGE_TOKEN
18 | + "\n"
19 | + "What is the human pose in this image? Please respond with SMPL pose.",
20 | DEFAULT_IMAGE_TOKEN
21 | + "\n"
22 | + "What is the person doing in this image? Please output SMPL pose.",
23 | DEFAULT_IMAGE_TOKEN + "\n" + "There is a person in the middle of the image; please output this person's SMPL pose.",
24 | ]
25 |
26 | LONG_QUESTION_LIST = [
27 | DEFAULT_IMAGE_TOKEN + "\n" + "{sent} Please respond with SMPL pose.",
28 | DEFAULT_IMAGE_TOKEN + "\n" + "{sent} Please output SMPL pose.",
29 | ]
30 |
31 | EXPLANATORY_QUESTION_LIST = [
32 | "Please output SMPL pose and explain the pose.",
33 | "Please output SMPL pose and explain the reason.",
34 | "Please output SMPL pose and give some explanation.",
35 | ]
36 |
37 | ANSWER_LIST = [
38 | "It is [SEG].",
39 | "Sure, [SEG].",
40 | "Sure, it is [SEG].",
41 | "Sure, the SMPL pose is [SEG].",
42 | "[SEG].",
43 | ]
44 |
45 |
46 | class Summary(Enum):
47 | NONE = 0
48 | AVERAGE = 1
49 | SUM = 2
50 | COUNT = 3
51 |
52 |
53 | class AverageMeter(object):
54 | """Computes and stores the average and current value"""
55 |
56 | def __init__(self, name, fmt=":f", summary_type=Summary.AVERAGE):
57 | self.name = name
58 | self.fmt = fmt
59 | self.summary_type = summary_type
60 | self.reset()
61 |
62 | def reset(self):
63 | self.val = 0
64 | self.avg = 0
65 | self.sum = 0
66 | self.count = 0
67 |
68 | def update(self, val, n=1):
69 | self.val = val
70 | self.sum += val * n
71 | self.count += n
72 | self.avg = self.sum / self.count
73 |
74 | def all_reduce(self):
75 | device = "cuda" if torch.cuda.is_available() else "cpu"
76 | if isinstance(self.sum, np.ndarray):
77 | total = torch.tensor(
78 | self.sum.tolist()
79 | + [
80 | self.count,
81 | ],
82 | dtype=torch.float32,
83 | device=device,
84 | )
85 | else:
86 | total = torch.tensor(
87 | [self.sum, self.count], dtype=torch.float32, device=device
88 | )
89 |
90 | dist.all_reduce(total, dist.ReduceOp.SUM, async_op=False)
91 | if total.shape[0] > 2:
92 | self.sum, self.count = total[:-1].cpu().numpy(), total[-1].cpu().item()
93 | else:
94 | self.sum, self.count = total.tolist()
95 | self.avg = self.sum / (self.count + 1e-5)
96 |
97 | def __str__(self):
98 | fmtstr = "{name} {val" + self.fmt + "} ({avg" + self.fmt + "})"
99 | return fmtstr.format(**self.__dict__)
100 |
101 | def summary(self):
102 | fmtstr = ""
103 | if self.summary_type is Summary.NONE:
104 | fmtstr = ""
105 | elif self.summary_type is Summary.AVERAGE:
106 | fmtstr = "{name} {avg:.3f}"
107 | elif self.summary_type is Summary.SUM:
108 | fmtstr = "{name} {sum:.3f}"
109 | elif self.summary_type is Summary.COUNT:
110 | fmtstr = "{name} {count:.3f}"
111 | else:
112 | raise ValueError("invalid summary type %r" % self.summary_type)
113 |
114 | return fmtstr.format(**self.__dict__)
115 |
116 |
117 | def intersectionAndUnionGPU(output, target, K, ignore_index=255):
118 | # 'K' classes, output and target sizes are N or N * L or N * H * W, each value in range 0 to K - 1.
119 | assert output.dim() in [1, 2, 3]
120 | assert output.shape == target.shape
121 | output = output.view(-1)
122 | target = target.view(-1)
123 | output[target == ignore_index] = ignore_index
124 | intersection = output[output == target]
125 | area_intersection = torch.histc(intersection, bins=K, min=0, max=K - 1)
126 | area_output = torch.histc(output, bins=K, min=0, max=K - 1)
127 | area_target = torch.histc(target, bins=K, min=0, max=K - 1)
128 | area_union = area_output + area_target - area_intersection
129 | return area_intersection, area_union, area_target
130 |
131 |
132 | class ProgressMeter(object):
133 | def __init__(self, num_batches, meters, prefix=""):
134 | self.batch_fmtstr = self._get_batch_fmtstr(num_batches)
135 | self.meters = meters
136 | self.prefix = prefix
137 |
138 | def display(self, batch):
139 | entries = [self.prefix + self.batch_fmtstr.format(batch)]
140 | entries += [str(meter) for meter in self.meters]
141 | print("\t".join(entries))
142 |
143 | def display_summary(self):
144 | entries = [" *"]
145 | entries += [meter.summary() for meter in self.meters]
146 | print(" ".join(entries))
147 |
148 | def _get_batch_fmtstr(self, num_batches):
149 | num_digits = len(str(num_batches // 1))
150 | fmt = "{:" + str(num_digits) + "d}"
151 | return "[" + fmt + "/" + fmt.format(num_batches) + "]"
152 |
153 |
154 | def dict_to_cuda(input_dict):
155 | for k, v in input_dict.items():
156 | if isinstance(input_dict[k], torch.Tensor):
157 | input_dict[k] = v.cuda(non_blocking=True)
158 | elif (
159 | isinstance(input_dict[k], list)
160 | and len(input_dict[k]) > 0
161 | and isinstance(input_dict[k][0], torch.Tensor)
162 | ):
163 | input_dict[k] = [ele.cuda(non_blocking=True) for ele in v]
164 | return input_dict
165 |
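
A small single-process illustration of the AverageMeter bookkeeping above (all_reduce is only needed once torch.distributed has been initialized); the loss values and batch size are made up:

from llava.lisa_utils import AverageMeter, Summary

loss_meter = AverageMeter("loss", ":.4f", Summary.AVERAGE)
for batch_loss in [0.9, 0.7, 0.5]:
    loss_meter.update(batch_loss, n=8)   # n = number of samples in the batch

print(loss_meter)            # "loss 0.5000 (0.7000)" -- current value and running average
print(loss_meter.summary())  # "loss 0.700"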
--------------------------------------------------------------------------------
/llava/model/__init__.py:
--------------------------------------------------------------------------------
1 | try:
2 | from .language_model.llava_llama import LlavaLlamaForCausalLM, LlavaConfig
3 | from .language_model.llava_mpt import LlavaMptForCausalLM, LlavaMptConfig
4 | from .language_model.llava_mistral import LlavaMistralForCausalLM, LlavaMistralConfig
5 | from .language_model.llava_garment_float50 import GarmentGPTFloat50ForCausalLM
6 | except:  # tolerate missing or broken optional model variants; failed imports are silently skipped
7 | pass
8 |
--------------------------------------------------------------------------------
/llava/model/apply_delta.py:
--------------------------------------------------------------------------------
1 | """
2 | Usage:
3 | python3 -m llava.model.apply_delta --base ~/model_weights/llama-7b --target ~/model_weights/vicuna-7b --delta lmsys/vicuna-7b-delta
4 | """
5 | import argparse
6 |
7 | import torch
8 | from tqdm import tqdm
9 | from transformers import AutoTokenizer, AutoModelForCausalLM
10 | from llava import LlavaLlamaForCausalLM
11 |
12 |
13 | def apply_delta(base_model_path, target_model_path, delta_path):
14 | print("Loading base model")
15 | base = AutoModelForCausalLM.from_pretrained(
16 | base_model_path, torch_dtype=torch.float16, low_cpu_mem_usage=True)
17 |
18 | print("Loading delta")
19 | delta = LlavaLlamaForCausalLM.from_pretrained(delta_path, torch_dtype=torch.float16, low_cpu_mem_usage=True)
20 | delta_tokenizer = AutoTokenizer.from_pretrained(delta_path)
21 |
22 | print("Applying delta")
23 | for name, param in tqdm(delta.state_dict().items(), desc="Applying delta"):
24 | if name not in base.state_dict():
25 | assert name in ['model.mm_projector.weight', 'model.mm_projector.bias'], f'{name} not in base model'
26 | continue
27 | if param.data.shape == base.state_dict()[name].shape:
28 | param.data += base.state_dict()[name]
29 | else:
30 | assert name in ['model.embed_tokens.weight', 'lm_head.weight'], \
31 | f'{name} dimension mismatch: {param.data.shape} vs {base.state_dict()[name].shape}'
32 | bparam = base.state_dict()[name]
33 | param.data[:bparam.shape[0], :bparam.shape[1]] += bparam
34 |
35 | print("Saving target model")
36 | delta.save_pretrained(target_model_path)
37 | delta_tokenizer.save_pretrained(target_model_path)
38 |
39 |
40 | if __name__ == "__main__":
41 | parser = argparse.ArgumentParser()
42 | parser.add_argument("--base-model-path", type=str, required=True)
43 | parser.add_argument("--target-model-path", type=str, required=True)
44 | parser.add_argument("--delta-path", type=str, required=True)
45 |
46 | args = parser.parse_args()
47 |
48 | apply_delta(args.base_model_path, args.target_model_path, args.delta_path)
49 |
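
A toy, self-contained illustration of the delta arithmetic used here and in the companion make_delta.py further below; the tensors are random stand-ins for model weights:

import torch

base = torch.randn(4, 4)       # stand-in base-model weight
target = torch.randn(4, 4)     # stand-in fine-tuned weight

delta = target - base          # what make_delta writes out
restored = delta + base        # what apply_delta computes in place on the delta weights
assert torch.allclose(restored, target)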
--------------------------------------------------------------------------------
/llava/model/consolidate.py:
--------------------------------------------------------------------------------
1 | """
2 | Usage:
3 | python3 -m llava.model.consolidate --src ~/model_weights/llava-7b --dst ~/model_weights/llava-7b_consolidate
4 | """
5 | import argparse
6 |
7 | import torch
8 | from transformers import AutoTokenizer, AutoModelForCausalLM
9 | from llava.model import *
10 | from llava.model.utils import auto_upgrade
11 |
12 |
13 | def consolidate_ckpt(src_path, dst_path):
14 | print("Loading model")
15 | auto_upgrade(src_path)
16 | src_model = AutoModelForCausalLM.from_pretrained(src_path, torch_dtype=torch.float16, low_cpu_mem_usage=True)
17 | src_tokenizer = AutoTokenizer.from_pretrained(src_path, use_fast=False)
18 | src_model.save_pretrained(dst_path)
19 | src_tokenizer.save_pretrained(dst_path)
20 |
21 |
22 | if __name__ == "__main__":
23 | parser = argparse.ArgumentParser()
24 | parser.add_argument("--src", type=str, required=True)
25 | parser.add_argument("--dst", type=str, required=True)
26 |
27 | args = parser.parse_args()
28 |
29 | consolidate_ckpt(args.src, args.dst)
30 |
--------------------------------------------------------------------------------
/llava/model/language_model/llava_llama.py:
--------------------------------------------------------------------------------
1 | # Copyright 2023 Haotian Liu
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 |
16 | from typing import List, Optional, Tuple, Union
17 |
18 | import torch
19 | import torch.nn as nn
20 |
21 | from transformers import AutoConfig, AutoModelForCausalLM, \
22 | LlamaConfig, LlamaModel, LlamaForCausalLM
23 |
24 | from transformers.modeling_outputs import CausalLMOutputWithPast
25 | from transformers.generation.utils import GenerateOutput
26 |
27 | from ..llava_arch import LlavaMetaModel, LlavaMetaForCausalLM
28 |
29 |
30 | class LlavaConfig(LlamaConfig):
31 | model_type = "llava_llama"
32 |
33 |
34 | class LlavaLlamaModel(LlavaMetaModel, LlamaModel):
35 | config_class = LlavaConfig
36 |
37 | def __init__(self, config: LlamaConfig):
38 | super(LlavaLlamaModel, self).__init__(config)
39 |
40 |
41 | class LlavaLlamaForCausalLM(LlamaForCausalLM, LlavaMetaForCausalLM):
42 | config_class = LlavaConfig
43 |
44 | def __init__(self, config):
45 | super(LlamaForCausalLM, self).__init__(config)
46 | self.model = LlavaLlamaModel(config)
47 | self.pretraining_tp = config.pretraining_tp
48 | self.vocab_size = config.vocab_size
49 | self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
50 |
51 | # Initialize weights and apply final processing
52 | self.post_init()
53 |
54 | def get_model(self):
55 | return self.model
56 |
57 | def forward(
58 | self,
59 | input_ids: torch.LongTensor = None,
60 | attention_mask: Optional[torch.Tensor] = None,
61 | position_ids: Optional[torch.LongTensor] = None,
62 | past_key_values: Optional[List[torch.FloatTensor]] = None,
63 | inputs_embeds: Optional[torch.FloatTensor] = None,
64 | labels: Optional[torch.LongTensor] = None,
65 | use_cache: Optional[bool] = None,
66 | output_attentions: Optional[bool] = None,
67 | output_hidden_states: Optional[bool] = None,
68 | images: Optional[torch.FloatTensor] = None,
69 | image_sizes: Optional[List[List[int]]] = None,
70 | return_dict: Optional[bool] = None,
71 | ) -> Union[Tuple, CausalLMOutputWithPast]:
72 |
73 | if inputs_embeds is None:
74 | (
75 | input_ids,
76 | position_ids,
77 | attention_mask,
78 | past_key_values,
79 | inputs_embeds,
80 | labels
81 | ) = self.prepare_inputs_labels_for_multimodal(
82 | input_ids,
83 | position_ids,
84 | attention_mask,
85 | past_key_values,
86 | labels,
87 | images,
88 | image_sizes
89 | )
90 |
91 | return super().forward(
92 | input_ids=input_ids,
93 | attention_mask=attention_mask,
94 | position_ids=position_ids,
95 | past_key_values=past_key_values,
96 | inputs_embeds=inputs_embeds,
97 | labels=labels,
98 | use_cache=use_cache,
99 | output_attentions=output_attentions,
100 | output_hidden_states=output_hidden_states,
101 | return_dict=return_dict
102 | )
103 |
104 | @torch.no_grad()
105 | def generate(
106 | self,
107 | inputs: Optional[torch.Tensor] = None,
108 | images: Optional[torch.Tensor] = None,
109 | image_sizes: Optional[torch.Tensor] = None,
110 | **kwargs,
111 | ) -> Union[GenerateOutput, torch.LongTensor]:
112 | position_ids = kwargs.pop("position_ids", None)
113 | attention_mask = kwargs.pop("attention_mask", None)
114 | if "inputs_embeds" in kwargs:
115 | raise NotImplementedError("`inputs_embeds` is not supported")
116 |
117 | if images is not None:
118 | (
119 | inputs,
120 | position_ids,
121 | attention_mask,
122 | _,
123 | inputs_embeds,
124 | _
125 | ) = self.prepare_inputs_labels_for_multimodal(
126 | inputs,
127 | position_ids,
128 | attention_mask,
129 | None,
130 | None,
131 | images,
132 | image_sizes=image_sizes
133 | )
134 | else:
135 | inputs_embeds = self.get_model().embed_tokens(inputs)
136 |
137 | return super().generate(
138 | position_ids=position_ids,
139 | attention_mask=attention_mask,
140 | inputs_embeds=inputs_embeds,
141 | **kwargs
142 | )
143 |
144 | def prepare_inputs_for_generation(self, input_ids, past_key_values=None,
145 | inputs_embeds=None, **kwargs):
146 | images = kwargs.pop("images", None)
147 | image_sizes = kwargs.pop("image_sizes", None)
148 | inputs = super().prepare_inputs_for_generation(
149 | input_ids, past_key_values=past_key_values, inputs_embeds=inputs_embeds, **kwargs
150 | )
151 | if images is not None:
152 | inputs['images'] = images
153 | if image_sizes is not None:
154 | inputs['image_sizes'] = image_sizes
155 | return inputs
156 |
157 | AutoConfig.register("llava_llama", LlavaConfig)
158 | AutoModelForCausalLM.register(LlavaConfig, LlavaLlamaForCausalLM)
159 |
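
A hedged sketch of image-conditioned generation with LlavaLlamaForCausalLM; the checkpoint and image paths are placeholders, and tokenizer_image_token / process_images are assumed to behave as in upstream LLaVA's llava.mm_utils:

import torch
from PIL import Image
from transformers import AutoTokenizer
from llava.constants import IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN
from llava.mm_utils import tokenizer_image_token, process_images
from llava.model.language_model.llava_llama import LlavaLlamaForCausalLM

model_path = "path/to/llava-checkpoint"                       # placeholder
tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=False)
model = LlavaLlamaForCausalLM.from_pretrained(model_path, torch_dtype=torch.float16).cuda()

image = Image.open("path/to/image.png").convert("RGB")        # placeholder
image_tensor = process_images([image], model.get_vision_tower().image_processor, model.config)
image_tensor = image_tensor.to(model.device, dtype=torch.float16)

prompt = DEFAULT_IMAGE_TOKEN + "\nDescribe the garment in this image."
input_ids = tokenizer_image_token(prompt, tokenizer, IMAGE_TOKEN_INDEX,
                                  return_tensors="pt").unsqueeze(0).cuda()

with torch.inference_mode():
    output_ids = model.generate(inputs=input_ids, images=image_tensor,
                                image_sizes=[image.size], max_new_tokens=128)
print(tokenizer.decode(output_ids[0], skip_special_tokens=True))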
--------------------------------------------------------------------------------
/llava/model/language_model/llava_mistral.py:
--------------------------------------------------------------------------------
1 | # Copyright 2023 Haotian Liu
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 |
16 | from typing import List, Optional, Tuple, Union
17 |
18 | import torch
19 | import torch.nn as nn
20 | from torch.nn import CrossEntropyLoss
21 |
22 | from transformers import AutoConfig, AutoModelForCausalLM, \
23 | MistralConfig, MistralModel, MistralForCausalLM
24 |
25 | from transformers.modeling_outputs import CausalLMOutputWithPast
26 | from transformers.generation.utils import GenerateOutput
27 |
28 | from ..llava_arch import LlavaMetaModel, LlavaMetaForCausalLM
29 |
30 |
31 | class LlavaMistralConfig(MistralConfig):
32 | model_type = "llava_mistral"
33 |
34 |
35 | class LlavaMistralModel(LlavaMetaModel, MistralModel):
36 | config_class = LlavaMistralConfig
37 |
38 | def __init__(self, config: MistralConfig):
39 | super(LlavaMistralModel, self).__init__(config)
40 |
41 |
42 | class LlavaMistralForCausalLM(MistralForCausalLM, LlavaMetaForCausalLM):
43 | config_class = LlavaMistralConfig
44 |
45 | def __init__(self, config):
46 | super(MistralForCausalLM, self).__init__(config)
47 | self.model = LlavaMistralModel(config)
48 |
49 | self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
50 |
51 | # Initialize weights and apply final processing
52 | self.post_init()
53 |
54 | def get_model(self):
55 | return self.model
56 |
57 | def forward(
58 | self,
59 | input_ids: torch.LongTensor = None,
60 | attention_mask: Optional[torch.Tensor] = None,
61 | position_ids: Optional[torch.LongTensor] = None,
62 | past_key_values: Optional[List[torch.FloatTensor]] = None,
63 | inputs_embeds: Optional[torch.FloatTensor] = None,
64 | labels: Optional[torch.LongTensor] = None,
65 | use_cache: Optional[bool] = None,
66 | output_attentions: Optional[bool] = None,
67 | output_hidden_states: Optional[bool] = None,
68 | images: Optional[torch.FloatTensor] = None,
69 | image_sizes: Optional[List[List[int]]] = None,
70 | return_dict: Optional[bool] = None,
71 | ) -> Union[Tuple, CausalLMOutputWithPast]:
72 |
73 | if inputs_embeds is None:
74 | (
75 | input_ids,
76 | position_ids,
77 | attention_mask,
78 | past_key_values,
79 | inputs_embeds,
80 | labels
81 | ) = self.prepare_inputs_labels_for_multimodal(
82 | input_ids,
83 | position_ids,
84 | attention_mask,
85 | past_key_values,
86 | labels,
87 | images,
88 | image_sizes
89 | )
90 |
91 | return super().forward(
92 | input_ids=input_ids,
93 | attention_mask=attention_mask,
94 | position_ids=position_ids,
95 | past_key_values=past_key_values,
96 | inputs_embeds=inputs_embeds,
97 | labels=labels,
98 | use_cache=use_cache,
99 | output_attentions=output_attentions,
100 | output_hidden_states=output_hidden_states,
101 | return_dict=return_dict
102 | )
103 |
104 | @torch.no_grad()
105 | def generate(
106 | self,
107 | inputs: Optional[torch.Tensor] = None,
108 | images: Optional[torch.Tensor] = None,
109 | image_sizes: Optional[torch.Tensor] = None,
110 | **kwargs,
111 | ) -> Union[GenerateOutput, torch.LongTensor]:
112 | position_ids = kwargs.pop("position_ids", None)
113 | attention_mask = kwargs.pop("attention_mask", None)
114 | if "inputs_embeds" in kwargs:
115 | raise NotImplementedError("`inputs_embeds` is not supported")
116 |
117 | if images is not None:
118 | (
119 | inputs,
120 | position_ids,
121 | attention_mask,
122 | _,
123 | inputs_embeds,
124 | _
125 | ) = self.prepare_inputs_labels_for_multimodal(
126 | inputs,
127 | position_ids,
128 | attention_mask,
129 | None,
130 | None,
131 | images,
132 | image_sizes=image_sizes
133 | )
134 | else:
135 | inputs_embeds = self.get_model().embed_tokens(inputs)
136 |
137 | return super().generate(
138 | position_ids=position_ids,
139 | attention_mask=attention_mask,
140 | inputs_embeds=inputs_embeds,
141 | **kwargs
142 | )
143 |
144 | def prepare_inputs_for_generation(self, input_ids, past_key_values=None,
145 | inputs_embeds=None, **kwargs):
146 | images = kwargs.pop("images", None)
147 | image_sizes = kwargs.pop("image_sizes", None)
148 | inputs = super().prepare_inputs_for_generation(
149 | input_ids, past_key_values=past_key_values, inputs_embeds=inputs_embeds, **kwargs
150 | )
151 | if images is not None:
152 | inputs['images'] = images
153 | if image_sizes is not None:
154 | inputs['image_sizes'] = image_sizes
155 | return inputs
156 |
157 | AutoConfig.register("llava_mistral", LlavaMistralConfig)
158 | AutoModelForCausalLM.register(LlavaMistralConfig, LlavaMistralForCausalLM)
159 |
--------------------------------------------------------------------------------
/llava/model/language_model/llava_mpt.py:
--------------------------------------------------------------------------------
1 | # Copyright 2023 Haotian Liu
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 |
16 | from typing import Optional, Tuple
17 |
18 | import torch
19 |
20 | from transformers import AutoConfig, AutoModelForCausalLM, \
21 | MptConfig, MptForCausalLM, MptModel
22 | from llava.model.llava_arch import LlavaMetaModel, LlavaMetaForCausalLM
23 |
24 |
25 | class LlavaMptConfig(MptConfig):
26 | model_type = "llava_mpt"
27 |
28 |
29 | class LlavaMptModel(LlavaMetaModel, MptModel):
30 | config_class = LlavaMptConfig
31 |
32 | def __init__(self, config: MptConfig):
33 | config.hidden_size = config.d_model
34 | super(LlavaMptModel, self).__init__(config)
35 |
36 | def embed_tokens(self, x):
37 | return self.wte(x)
38 |
39 |
40 | class LlavaMptForCausalLM(MptForCausalLM, LlavaMetaForCausalLM):
41 | config_class = LlavaMptConfig
42 | supports_gradient_checkpointing = True
43 |
44 | def __init__(self, config):
45 | super(MptForCausalLM, self).__init__(config)
46 |
47 | self.transformer = LlavaMptModel(config)
48 | self.lm_head = torch.nn.Linear(config.hidden_size, config.vocab_size, bias=False)
49 |
50 | # Initialize weights and apply final processing
51 | self.post_init()
52 |
53 | def get_model(self):
54 | return self.transformer
55 |
56 | def _set_gradient_checkpointing(self, module, value=False):
57 | if isinstance(module, LlavaMptModel):
58 | module.gradient_checkpointing = value
59 |
60 | def forward(
61 | self,
62 | input_ids: Optional[torch.LongTensor] = None,
63 | past_key_values: Optional[Tuple[Tuple[torch.Tensor, torch.Tensor], ...]] = None,
64 | attention_mask: Optional[torch.Tensor] = None,
65 | inputs_embeds: Optional[torch.Tensor] = None,
66 | labels: Optional[torch.Tensor] = None,
67 | use_cache: Optional[bool] = None,
68 | output_attentions: Optional[bool] = None,
69 | output_hidden_states: Optional[bool] = None,
70 | return_dict: Optional[bool] = None,
71 | images=None):
72 |
73 | input_ids, attention_mask, past_key_values, inputs_embeds, labels = self.prepare_inputs_labels_for_multimodal(input_ids, attention_mask, past_key_values, labels, images)
74 |
75 | return super().forward(
76 | input_ids,
77 | past_key_values=past_key_values,
78 | attention_mask=attention_mask,
79 | inputs_embeds=inputs_embeds,
80 | labels=labels,
81 | use_cache=use_cache,
82 | output_attentions=output_attentions,
83 | output_hidden_states=output_hidden_states,
84 | return_dict=return_dict,
85 | )
86 |
87 | def prepare_inputs_for_generation(self, input_ids, past_key_values=None, inputs_embeds=None, **kwargs):
88 | images = kwargs.pop("images", None)
89 | _inputs = super().prepare_inputs_for_generation(
90 | input_ids, past_key_values=past_key_values, inputs_embeds=inputs_embeds, **kwargs
91 | )
92 | _inputs['images'] = images
93 | return _inputs
94 |
95 |
96 | AutoConfig.register("llava_mpt", LlavaMptConfig)
97 | AutoModelForCausalLM.register(LlavaMptConfig, LlavaMptForCausalLM)
98 |
--------------------------------------------------------------------------------
/llava/model/make_delta.py:
--------------------------------------------------------------------------------
1 | """
2 | Usage:
3 | python3 -m llava.model.make_delta --base ~/model_weights/llama-7b --target ~/model_weights/llava-7b --delta ~/model_weights/llava-7b-delta --hub-repo-id liuhaotian/llava-7b-delta
4 | """
5 | import argparse
6 |
7 | import torch
8 | from tqdm import tqdm
9 | from transformers import AutoTokenizer, AutoModelForCausalLM
10 | from llava.model.utils import auto_upgrade
11 |
12 |
13 | def make_delta(base_model_path, target_model_path, delta_path, hub_repo_id):
14 | print("Loading base model")
15 | base = AutoModelForCausalLM.from_pretrained(
16 | base_model_path, torch_dtype=torch.float16, low_cpu_mem_usage=True)
17 |
18 | print("Loading target model")
19 | auto_upgrade(target_model_path)
20 | target = AutoModelForCausalLM.from_pretrained(target_model_path, torch_dtype=torch.float16, low_cpu_mem_usage=True)
21 |
22 | print("Calculating delta")
23 | for name, param in tqdm(target.state_dict().items(), desc="Calculating delta"):
24 | if name not in base.state_dict():
25 | assert name in ['model.mm_projector.weight', 'model.mm_projector.bias'], f'{name} not in base model'
26 | continue
27 | if param.data.shape == base.state_dict()[name].shape:
28 | param.data -= base.state_dict()[name]
29 | else:
30 | assert name in ['model.embed_tokens.weight', 'lm_head.weight'], f'{name} dimension mismatch: {param.data.shape} vs {base.state_dict()[name].shape}'
31 | bparam = base.state_dict()[name]
32 | param.data[:bparam.shape[0], :bparam.shape[1]] -= bparam
33 |
34 | print("Saving delta")
35 | if hub_repo_id:
36 | kwargs = {"push_to_hub": True, "repo_id": hub_repo_id}
37 | else:
38 | kwargs = {}
39 | target.save_pretrained(delta_path, **kwargs)
40 | target_tokenizer = AutoTokenizer.from_pretrained(target_model_path)
41 | target_tokenizer.save_pretrained(delta_path, **kwargs)
42 |
43 |
44 | if __name__ == "__main__":
45 | parser = argparse.ArgumentParser()
46 | parser.add_argument("--base-model-path", type=str, required=True)
47 | parser.add_argument("--target-model-path", type=str, required=True)
48 | parser.add_argument("--delta-path", type=str, required=True)
49 | parser.add_argument("--hub-repo-id", type=str, default=None)
50 | args = parser.parse_args()
51 |
52 | make_delta(args.base_model_path, args.target_model_path, args.delta_path, args.hub_repo_id)
53 |
--------------------------------------------------------------------------------
/llava/model/multimodal_encoder/builder.py:
--------------------------------------------------------------------------------
1 | import os
2 | from .clip_encoder import CLIPVisionTower, CLIPVisionTowerS2
3 |
4 |
5 | def build_vision_tower(vision_tower_cfg, **kwargs):
6 | vision_tower = getattr(vision_tower_cfg, 'mm_vision_tower', getattr(vision_tower_cfg, 'vision_tower', None))
7 | is_absolute_path_exists = os.path.exists(vision_tower)
8 | use_s2 = getattr(vision_tower_cfg, 's2', False)
9 | if is_absolute_path_exists or vision_tower.startswith("openai") or vision_tower.startswith("laion") or "ShareGPT4V" in vision_tower:
10 | if use_s2:
11 | return CLIPVisionTowerS2(vision_tower, args=vision_tower_cfg, **kwargs)
12 | else:
13 | return CLIPVisionTower(vision_tower, args=vision_tower_cfg, **kwargs)
14 |
15 | raise ValueError(f'Unknown vision tower: {vision_tower}')
16 |
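
A minimal sketch of driving build_vision_tower with a stand-in config object; the CLIP identifier is the usual OpenAI checkpoint name, and with delay_load=True only its config is fetched from the Hugging Face hub until load_model() is called:

from types import SimpleNamespace
from llava.model.multimodal_encoder.builder import build_vision_tower

cfg = SimpleNamespace(
    mm_vision_tower="openai/clip-vit-large-patch14-336",   # assumed checkpoint id
    mm_vision_select_layer=-2,
    mm_vision_select_feature="patch",
)
tower = build_vision_tower(cfg, delay_load=True)   # weights load later via tower.load_model()
print(tower.config.hidden_size)                    # 1024 for CLIP ViT-L/14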
--------------------------------------------------------------------------------
/llava/model/multimodal_encoder/clip_encoder.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 |
4 | from transformers import CLIPVisionModel, CLIPImageProcessor, CLIPVisionConfig
5 |
6 |
7 | class CLIPVisionTower(nn.Module):
8 | def __init__(self, vision_tower, args, delay_load=False):
9 | super().__init__()
10 |
11 | self.is_loaded = False
12 |
13 | self.vision_tower_name = vision_tower
14 | self.select_layer = args.mm_vision_select_layer
15 | self.select_feature = getattr(args, 'mm_vision_select_feature', 'patch')
16 |
17 | if not delay_load:
18 | self.load_model()
19 | elif getattr(args, 'unfreeze_mm_vision_tower', False):
20 | self.load_model()
21 | else:
22 | self.cfg_only = CLIPVisionConfig.from_pretrained(self.vision_tower_name)
23 |
24 | def load_model(self, device_map=None):
25 | if self.is_loaded:
26 | print('{} is already loaded, `load_model` called again, skipping.'.format(self.vision_tower_name))
27 | return
28 |
29 | self.image_processor = CLIPImageProcessor.from_pretrained(self.vision_tower_name)
30 | self.vision_tower = CLIPVisionModel.from_pretrained(self.vision_tower_name, device_map=device_map)
31 | self.vision_tower.requires_grad_(False)
32 |
33 | self.is_loaded = True
34 |
35 | def feature_select(self, image_forward_outs):
36 | image_features = image_forward_outs.hidden_states[self.select_layer]
37 | if self.select_feature == 'patch':
38 | image_features = image_features[:, 1:]
39 | elif self.select_feature == 'cls_patch':
40 | image_features = image_features
41 | else:
42 | raise ValueError(f'Unexpected select feature: {self.select_feature}')
43 | return image_features
44 |
45 | @torch.no_grad()
46 | def forward(self, images):
47 | if type(images) is list:
48 | image_features = []
49 | for image in images:
50 | image_forward_out = self.vision_tower(image.to(device=self.device, dtype=self.dtype).unsqueeze(0), output_hidden_states=True)
51 | image_feature = self.feature_select(image_forward_out).to(image.dtype)
52 | image_features.append(image_feature)
53 | else:
54 | image_forward_outs = self.vision_tower(images.to(device=self.device, dtype=self.dtype), output_hidden_states=True)
55 | image_features = self.feature_select(image_forward_outs).to(images.dtype)
56 |
57 | return image_features
58 |
59 | @property
60 | def dummy_feature(self):
61 | return torch.zeros(1, self.hidden_size, device=self.device, dtype=self.dtype)
62 |
63 | @property
64 | def dtype(self):
65 | return self.vision_tower.dtype
66 |
67 | @property
68 | def device(self):
69 | return self.vision_tower.device
70 |
71 | @property
72 | def config(self):
73 | if self.is_loaded:
74 | return self.vision_tower.config
75 | else:
76 | return self.cfg_only
77 |
78 | @property
79 | def hidden_size(self):
80 | return self.config.hidden_size
81 |
82 | @property
83 | def num_patches_per_side(self):
84 | return self.config.image_size // self.config.patch_size
85 |
86 | @property
87 | def num_patches(self):
88 | return (self.config.image_size // self.config.patch_size) ** 2
89 |
90 |
91 |
92 | class CLIPVisionTowerS2(CLIPVisionTower):
93 | def __init__(self, vision_tower, args, delay_load=False):
94 | super().__init__(vision_tower, args, delay_load)
95 |
96 | self.s2_scales = getattr(args, 's2_scales', '336,672,1008')
97 | self.s2_scales = list(map(int, self.s2_scales.split(',')))
98 | self.s2_scales.sort()
99 | self.s2_split_size = self.s2_scales[0]
100 | self.s2_image_size = self.s2_scales[-1]
101 |
102 | try:
103 | from s2wrapper import forward as multiscale_forward
104 | except ImportError:
105 | raise ImportError('Package s2wrapper not found! Please install by running: \npip install git+https://github.com/bfshi/scaling_on_scales.git')
106 | self.multiscale_forward = multiscale_forward
107 |
108 | # change resize/crop size in preprocessing to the largest image size in s2_scale
109 | if not delay_load or getattr(args, 'unfreeze_mm_vision_tower', False):
110 | self.image_processor.size['shortest_edge'] = self.s2_image_size
111 | self.image_processor.crop_size['height'] = self.image_processor.crop_size['width'] = self.s2_image_size
112 |
113 | def load_model(self, device_map=None):
114 | if self.is_loaded:
115 | print('{} is already loaded, `load_model` called again, skipping.'.format(self.vision_tower_name))
116 | return
117 |
118 | self.image_processor = CLIPImageProcessor.from_pretrained(self.vision_tower_name)
119 | self.vision_tower = CLIPVisionModel.from_pretrained(self.vision_tower_name, device_map=device_map)
120 | self.vision_tower.requires_grad_(False)
121 |
122 | self.image_processor.size['shortest_edge'] = self.s2_image_size
123 | self.image_processor.crop_size['height'] = self.image_processor.crop_size['width'] = self.s2_image_size
124 |
125 | self.is_loaded = True
126 |
127 | @torch.no_grad()
128 | def forward_feature(self, images):
129 | image_forward_outs = self.vision_tower(images.to(device=self.device, dtype=self.dtype), output_hidden_states=True)
130 | image_features = self.feature_select(image_forward_outs).to(images.dtype)
131 | return image_features
132 |
133 | @torch.no_grad()
134 | def forward(self, images):
135 | if type(images) is list:
136 | image_features = []
137 | for image in images:
138 | image_feature = self.multiscale_forward(self.forward_feature, image.unsqueeze(0), img_sizes=self.s2_scales, max_split_size=self.s2_split_size)
139 | image_features.append(image_feature)
140 | else:
141 | image_features = self.multiscale_forward(self.forward_feature, images, img_sizes=self.s2_scales, max_split_size=self.s2_split_size)
142 |
143 | return image_features
144 |
145 | @property
146 | def hidden_size(self):
147 | return self.config.hidden_size * len(self.s2_scales)
148 |
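
A toy illustration of the 'patch' feature selection above, using a random stand-in for CLIP's hidden states: the first slot is the CLS token, and select_feature='patch' drops it.

import torch

hidden = torch.randn(1, 1 + 24 * 24, 1024)   # CLS token + 24x24 patches (336px ViT-L/14)
patch_features = hidden[:, 1:]               # what feature_select() keeps for 'patch'
print(patch_features.shape)                  # torch.Size([1, 576, 1024])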
--------------------------------------------------------------------------------
/llava/model/multimodal_projector/builder.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import re
4 |
5 |
6 | class IdentityMap(nn.Module):
7 | def __init__(self):
8 | super().__init__()
9 |
10 | def forward(self, x, *args, **kwargs):
11 | return x
12 |
13 | @property
14 | def config(self):
15 | return {"mm_projector_type": 'identity'}
16 |
17 |
18 | class SimpleResBlock(nn.Module):
19 | def __init__(self, channels):
20 | super().__init__()
21 | self.pre_norm = nn.LayerNorm(channels)
22 |
23 | self.proj = nn.Sequential(
24 | nn.Linear(channels, channels),
25 | nn.GELU(),
26 | nn.Linear(channels, channels)
27 | )
28 | def forward(self, x):
29 | x = self.pre_norm(x)
30 | return x + self.proj(x)
31 |
32 |
33 | def build_vision_projector(config, delay_load=False, **kwargs):
34 | projector_type = getattr(config, 'mm_projector_type', 'linear')
35 |
36 | if projector_type == 'linear':
37 | return nn.Linear(config.mm_hidden_size, config.hidden_size)
38 |
39 | mlp_gelu_match = re.match(r'^mlp(\d+)x_gelu$', projector_type)
40 | if mlp_gelu_match:
41 | mlp_depth = int(mlp_gelu_match.group(1))
42 | modules = [nn.Linear(config.mm_hidden_size, config.hidden_size)]
43 | for _ in range(1, mlp_depth):
44 | modules.append(nn.GELU())
45 | modules.append(nn.Linear(config.hidden_size, config.hidden_size))
46 | return nn.Sequential(*modules)
47 |
48 | if projector_type == 'identity':
49 | return IdentityMap()
50 |
51 | raise ValueError(f'Unknown projector type: {projector_type}')
52 |
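
A minimal sketch of building the 'mlp2x_gelu' projector from a stand-in config; the hidden sizes below are only illustrative:

import torch
from types import SimpleNamespace
from llava.model.multimodal_projector.builder import build_vision_projector

cfg = SimpleNamespace(mm_projector_type="mlp2x_gelu", mm_hidden_size=1024, hidden_size=4096)
projector = build_vision_projector(cfg)            # Linear -> GELU -> Linear
print(projector(torch.randn(1, 576, 1024)).shape)  # torch.Size([1, 576, 4096])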
--------------------------------------------------------------------------------
/llava/model/smplx/joint_names.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 |
3 | # Max-Planck-Gesellschaft zur Förderung der Wissenschaften e.V. (MPG) is
4 | # holder of all proprietary rights on this computer program.
5 | # You can only use this computer program if you have closed
6 | # a license agreement with MPG or you get the right to use the computer
7 | # program from someone who is authorized to grant you that right.
8 | # Any use of the computer program without a valid license is prohibited and
9 | # liable to prosecution.
10 | #
11 | # Copyright©2019 Max-Planck-Gesellschaft zur Förderung
12 | # der Wissenschaften e.V. (MPG). acting on behalf of its Max Planck Institute
13 | # for Intelligent Systems. All rights reserved.
14 | #
15 | # Contact: ps-license@tuebingen.mpg.de
16 |
17 | JOINT_NAMES = [
18 | 'pelvis', # 0
19 | 'left_hip', # 1
20 | 'right_hip', # 2
21 | 'spine1', # 3
22 | 'left_knee', # 4
23 | 'right_knee', # 5
24 | 'spine2', # 6
25 | 'left_ankle', # 7
26 | 'right_ankle', # 8
27 | 'spine3', # 9, (10 - 1)
28 | 'left_foot', # 10
29 | 'right_foot', # 11
30 | 'neck', # 12
31 | 'left_collar', # 13
32 | 'right_collar', # 14, (15 - 1)
33 | 'head', # 15
34 | 'left_shoulder', # 16
35 | 'right_shoulder', # 17
36 | 'left_elbow', # 18
37 | 'right_elbow', # 19, (20 - 1)
38 | 'left_wrist', # 20
39 | 'right_wrist', # 21
40 | 'jaw', # 22, (23 - 1)
41 | 'left_eye_smplhf', # 23
42 | 'right_eye_smplhf', # 24, (25 - 1)
43 | 'left_index1', # 25
44 | 'left_index2', # 26
45 | 'left_index3', # 27
46 | 'left_middle1', # 28
47 | 'left_middle2', # 29
48 | 'left_middle3', # 30
49 | 'left_pinky1', # 31
50 | 'left_pinky2', # 32
51 | 'left_pinky3', # 33
52 | 'left_ring1', # 34
53 | 'left_ring2', # 35
54 | 'left_ring3', # 36
55 | 'left_thumb1', # 37
56 | 'left_thumb2', # 38
57 | 'left_thumb3', # 39, (40 - 1)
58 | 'right_index1', # 40
59 | 'right_index2', # 41
60 | 'right_index3', # 42
61 | 'right_middle1', # 43
62 | 'right_middle2', # 44
63 | 'right_middle3', # 45
64 | 'right_pinky1', # 46
65 | 'right_pinky2', # 47
66 | 'right_pinky3', # 48
67 | 'right_ring1', # 49
68 | 'right_ring2', # 50
69 | 'right_ring3', # 51
70 | 'right_thumb1', # 52
71 | 'right_thumb2', # 53
72 | 'right_thumb3', # 54, (55 - 1)
73 | 'nose', # 55
74 | 'right_eye', # 56
75 | 'left_eye', # 57
76 | 'right_ear', # 58
77 | 'left_ear', # 59
78 | 'left_big_toe', # 60
79 | 'left_small_toe', # 61
80 | 'left_heel', # 62
81 | 'right_big_toe', # 63
82 | 'right_small_toe', # 64, (65 - 1)
83 | 'right_heel', # 65
84 | 'left_thumb', # 66
85 | 'left_index', # 67
86 | 'left_middle', # 68
87 | 'left_ring', # 69, (70 - 1)
88 | 'left_pinky', # 70
89 | 'right_thumb', # 71
90 | 'right_index', # 72
91 | 'right_middle', # 73
92 | 'right_ring', # 74, (75 - 1)
93 | 'right_pinky', # 75, (76 - 1)
94 | # evaluated face jts (76 - 127)
95 | 'right_eye_brow1', # 76
96 | 'right_eye_brow2',
97 | 'right_eye_brow3',
98 | 'right_eye_brow4',
99 | 'right_eye_brow5',
100 | 'left_eye_brow5',
101 | 'left_eye_brow4',
102 | 'left_eye_brow3',
103 | 'left_eye_brow2',
104 | 'left_eye_brow1',
105 | 'nose1',
106 | 'nose2',
107 | 'nose3',
108 | 'nose4',
109 | 'right_nose_2',
110 | 'right_nose_1',
111 | 'nose_middle',
112 | 'left_nose_1',
113 | 'left_nose_2',
114 | 'right_eye1',
115 | 'right_eye2',
116 | 'right_eye3',
117 | 'right_eye4',
118 | 'right_eye5',
119 | 'right_eye6',
120 | 'left_eye4',
121 | 'left_eye3',
122 | 'left_eye2',
123 | 'left_eye1',
124 | 'left_eye6',
125 | 'left_eye5',
126 | 'right_mouth_1',
127 | 'right_mouth_2',
128 | 'right_mouth_3',
129 | 'mouth_top',
130 | 'left_mouth_3',
131 | 'left_mouth_2',
132 | 'left_mouth_1',
133 | 'left_mouth_5', # 59 in OpenPose output
134 | 'left_mouth_4', # 58 in OpenPose output
135 | 'mouth_bottom', # 116 => 116 - 76 = 40 => the 40-index item in lmk_faces_idx
136 | 'right_mouth_4',
137 | 'right_mouth_5',
138 | 'right_lip_1',
139 | 'right_lip_2',
140 | 'lip_top',
141 | 'left_lip_2',
142 | 'left_lip_1',
143 | 'left_lip_3',
144 | 'lip_bottom',
145 | 'right_lip_3',
146 | # Face contour
147 | 'right_contour_1',
148 | 'right_contour_2',
149 | 'right_contour_3',
150 | 'right_contour_4',
151 | 'right_contour_5',
152 | 'right_contour_6',
153 | 'right_contour_7',
154 | 'right_contour_8',
155 | 'contour_middle',
156 | 'left_contour_8',
157 | 'left_contour_7',
158 | 'left_contour_6',
159 | 'left_contour_5',
160 | 'left_contour_4',
161 | 'left_contour_3',
162 | 'left_contour_2',
163 | 'left_contour_1',
164 | ]
165 |
166 |
167 | SMPLH_JOINT_NAMES = [
168 | 'pelvis',
169 | 'left_hip',
170 | 'right_hip',
171 | 'spine1',
172 | 'left_knee',
173 | 'right_knee',
174 | 'spine2',
175 | 'left_ankle',
176 | 'right_ankle',
177 | 'spine3',
178 | 'left_foot', # 10
179 | 'right_foot', # 11
180 | 'neck',
181 | 'left_collar',
182 | 'right_collar',
183 | 'head',
184 | 'left_shoulder',
185 | 'right_shoulder',
186 | 'left_elbow',
187 | 'right_elbow',
188 | 'left_wrist', # 20
189 | 'right_wrist', # 21
190 | 'left_index1',
191 | 'left_index2',
192 | 'left_index3',
193 | 'left_middle1', # 25
194 | 'left_middle2',
195 | 'left_middle3',
196 | 'left_pinky1',
197 | 'left_pinky2',
198 | 'left_pinky3', # 30
199 | 'left_ring1',
200 | 'left_ring2',
201 | 'left_ring3',
202 | 'left_thumb1',
203 | 'left_thumb2',
204 | 'left_thumb3',
205 | 'right_index1',
206 | 'right_index2',
207 | 'right_index3',
208 | 'right_middle1',
209 | 'right_middle2', # 41
210 | 'right_middle3',
211 | 'right_pinky1',
212 | 'right_pinky2',
213 | 'right_pinky3',
214 | 'right_ring1',
215 | 'right_ring2',
216 | 'right_ring3',
217 | 'right_thumb1',
218 | 'right_thumb2',
219 | 'right_thumb3',
220 | 'nose',
221 | 'right_eye',
222 | 'left_eye',
223 | 'right_ear',
224 | 'left_ear',
225 | 'left_big_toe',
226 | 'left_small_toe',
227 | 'left_heel',
228 | 'right_big_toe',
229 | 'right_small_toe',
230 | 'right_heel',
231 | 'left_thumb',
232 | 'left_index',
233 | 'left_middle',
234 | 'left_ring',
235 | 'left_pinky',
236 | 'right_thumb',
237 | 'right_index',
238 | 'right_middle',
239 | 'right_ring',
240 | 'right_pinky',
241 | ]
--------------------------------------------------------------------------------
/llava/model/smplx/smplx_utils.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import numpy as np
3 | from .lbs import *
4 |
5 | def calculate_A(betas, pose, v_template, shapedirs, posedirs, J_regressor, parents, pose2rot=True):
6 | batch_size = max(betas.shape[0], pose.shape[0])
7 | device, dtype = betas.device, betas.dtype
8 |
9 | # Add shape contribution
10 | v_shaped = v_template + blend_shapes(betas, shapedirs)
11 |
12 | # Get the joints
13 | # NxJx3 array
14 | J = vertices2joints(J_regressor, v_shaped)
15 |
16 | # 3. Add pose blend shapes
17 | # N x J x 3 x 3
18 | ident = torch.eye(3, dtype=dtype, device=device)
19 | if pose2rot:
20 | rot_mats = batch_rodrigues(pose.view(-1, 3)).view(
21 | [batch_size, -1, 3, 3])
22 |
23 | pose_feature = (rot_mats[:, 1:, :, :] - ident).view([batch_size, -1])
24 | # (N x P) x (P, V * 3) -> N x V x 3
25 | pose_offsets = torch.matmul(
26 | pose_feature, posedirs).view(batch_size, -1, 3)
27 | else:
28 | pose_feature = pose[:, 1:].view(batch_size, -1, 3, 3) - ident
29 | rot_mats = pose.view(batch_size, -1, 3, 3)
30 |
31 | pose_offsets = torch.matmul(pose_feature.view(batch_size, -1),
32 | posedirs).view(batch_size, -1, 3)
33 |
34 | v_posed = pose_offsets + v_shaped
35 | # 4. Get the global joint location
36 | J_transformed, A = batch_rigid_transform(rot_mats, J, parents, dtype=dtype)
37 | return A
--------------------------------------------------------------------------------
/llava/model/smplx/utils.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 |
3 | # Max-Planck-Gesellschaft zur Förderung der Wissenschaften e.V. (MPG) is
4 | # holder of all proprietary rights on this computer program.
5 | # You can only use this computer program if you have closed
6 | # a license agreement with MPG or you get the right to use the computer
7 | # program from someone who is authorized to grant you that right.
8 | # Any use of the computer program without a valid license is prohibited and
9 | # liable to prosecution.
10 | #
11 | # Copyright©2019 Max-Planck-Gesellschaft zur Förderung
12 | # der Wissenschaften e.V. (MPG). acting on behalf of its Max Planck Institute
13 | # for Intelligent Systems. All rights reserved.
14 | #
15 | # Contact: ps-license@tuebingen.mpg.de
16 |
17 | from typing import NewType, Union, Optional
18 | from dataclasses import dataclass, asdict, fields
19 | import numpy as np
20 | import torch
21 |
22 | Tensor = NewType('Tensor', torch.Tensor)
23 | Array = NewType('Array', np.ndarray)
24 |
25 |
26 | @dataclass
27 | class ModelOutput:
28 | vertices: Optional[Tensor] = None
29 | joints: Optional[Tensor] = None
30 | full_pose: Optional[Tensor] = None
31 | global_orient: Optional[Tensor] = None
32 | transl: Optional[Tensor] = None
33 | v_shaped: Optional[Tensor] = None
34 |
35 | def __getitem__(self, key):
36 | return getattr(self, key)
37 |
38 | def get(self, key, default=None):
39 | return getattr(self, key, default)
40 |
41 | def __iter__(self):
42 | return self.keys()
43 |
44 | def keys(self):
45 | keys = [t.name for t in fields(self)]
46 | return iter(keys)
47 |
48 | def values(self):
49 | values = [getattr(self, t.name) for t in fields(self)]
50 | return iter(values)
51 |
52 | def items(self):
53 | data = [(t.name, getattr(self, t.name)) for t in fields(self)]
54 | return iter(data)
55 |
56 |
57 | @dataclass
58 | class SMPLOutput(ModelOutput):
59 | betas: Optional[Tensor] = None
60 | body_pose: Optional[Tensor] = None
61 |
62 |
63 | @dataclass
64 | class SMPLHOutput(SMPLOutput):
65 | left_hand_pose: Optional[Tensor] = None
66 | right_hand_pose: Optional[Tensor] = None
67 | transl: Optional[Tensor] = None
68 |
69 |
70 | @dataclass
71 | class SMPLXOutput(SMPLHOutput):
72 | expression: Optional[Tensor] = None
73 | jaw_pose: Optional[Tensor] = None
74 |
75 |
76 | @dataclass
77 | class MANOOutput(ModelOutput):
78 | betas: Optional[Tensor] = None
79 | hand_pose: Optional[Tensor] = None
80 |
81 |
82 | @dataclass
83 | class FLAMEOutput(ModelOutput):
84 | betas: Optional[Tensor] = None
85 | expression: Optional[Tensor] = None
86 | jaw_pose: Optional[Tensor] = None
87 | neck_pose: Optional[Tensor] = None
88 |
89 |
90 | def find_joint_kin_chain(joint_id, kinematic_tree):
91 | kin_chain = []
92 | curr_idx = joint_id
93 | while curr_idx != -1:
94 | kin_chain.append(curr_idx)
95 | curr_idx = kinematic_tree[curr_idx]
96 | return kin_chain
97 |
98 |
99 | def to_tensor(
100 | array: Union[Array, Tensor], dtype=torch.float32
101 | ) -> Tensor:
102 | if torch.is_tensor(array):
103 | return array.contiguous()
104 | else:
105 | return torch.tensor(array, dtype=dtype).contiguous()
106 |
107 |
108 | class Struct(object):
109 | def __init__(self, **kwargs):
110 | for key, val in kwargs.items():
111 | setattr(self, key, val)
112 |
113 |
114 | def to_np(array, dtype=np.float32):
115 | if 'scipy.sparse' in str(type(array)):
116 | array = array.todense()
117 | return np.array(array, dtype=dtype)
118 |
119 |
120 | def rot_mat_to_euler(rot_mats):
121 | # Convert a batch of rotation matrices to Euler angles
122 | # Careful with extreme cases of Euler angles like [0.0, pi, 0.0]
123 |
124 | sy = torch.sqrt(rot_mats[:, 0, 0] * rot_mats[:, 0, 0] +
125 | rot_mats[:, 1, 0] * rot_mats[:, 1, 0])
126 | return torch.atan2(-rot_mats[:, 2, 0], sy)
127 |
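Usage sketch (added for illustration, not part of the original file): the ModelOutput
dataclasses above expose dict-style access on top of attribute access. Shapes below are
placeholders; the import path assumes this repo is on PYTHONPATH.

import torch
from llava.model.smplx.utils import SMPLXOutput

out = SMPLXOutput(vertices=torch.zeros(1, 10475, 3), joints=torch.zeros(1, 127, 3))
print(out['vertices'].shape)              # __getitem__ delegates to getattr
print(out.get('betas'))                   # None: the field exists but was never set
print(out.get('not_a_field', 'missing'))  # 'missing': getattr default kicks in
for name, value in out.items():           # keys()/values()/items() iterate dataclass fields
    if value is not None:
        print(name, tuple(value.shape))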
--------------------------------------------------------------------------------
/llava/model/smplx/vertex_ids.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 |
3 | # Max-Planck-Gesellschaft zur Förderung der Wissenschaften e.V. (MPG) is
4 | # holder of all proprietary rights on this computer program.
5 | # You can only use this computer program if you have closed
6 | # a license agreement with MPG or you get the right to use the computer
7 | # program from someone who is authorized to grant you that right.
8 | # Any use of the computer program without a valid license is prohibited and
9 | # liable to prosecution.
10 | #
11 | # Copyright©2019 Max-Planck-Gesellschaft zur Förderung
12 | # der Wissenschaften e.V. (MPG). acting on behalf of its Max Planck Institute
13 | # for Intelligent Systems. All rights reserved.
14 | #
15 | # Contact: ps-license@tuebingen.mpg.de
16 |
17 | from __future__ import print_function
18 | from __future__ import absolute_import
19 | from __future__ import division
20 |
21 | # Joint name to vertex mapping. SMPL/SMPL-H/SMPL-X vertices that correspond to
22 | # MSCOCO and OpenPose joints
23 | vertex_ids = {
24 | 'smplh': {
25 | 'nose': 332,
26 | 'reye': 6260,
27 | 'leye': 2800,
28 | 'rear': 4071,
29 | 'lear': 583,
30 | 'rthumb': 6191,
31 | 'rindex': 5782,
32 | 'rmiddle': 5905,
33 | 'rring': 6016,
34 | 'rpinky': 6133,
35 | 'lthumb': 2746,
36 | 'lindex': 2319,
37 | 'lmiddle': 2445,
38 | 'lring': 2556,
39 | 'lpinky': 2673,
40 | 'LBigToe': 3216,
41 | 'LSmallToe': 3226,
42 | 'LHeel': 3387,
43 | 'RBigToe': 6617,
44 | 'RSmallToe': 6624,
45 | 'RHeel': 6787
46 | },
47 | 'smplx': {
48 | 'nose': 9120,
49 | 'reye': 9929,
50 | 'leye': 9448,
51 | 'rear': 616,
52 | 'lear': 6,
53 | 'rthumb': 8079,
54 | 'rindex': 7669,
55 | 'rmiddle': 7794,
56 | 'rring': 7905,
57 | 'rpinky': 8022,
58 | 'lthumb': 5361,
59 | 'lindex': 4933,
60 | 'lmiddle': 5058,
61 | 'lring': 5169,
62 | 'lpinky': 5286,
63 | 'LBigToe': 5770,
64 | 'LSmallToe': 5780,
65 | 'LHeel': 8846,
66 | 'RBigToe': 8463,
67 | 'RSmallToe': 8474,
68 | 'RHeel': 8635
69 | },
70 | 'mano': {
71 | 'thumb': 744,
72 | 'index': 320,
73 | 'middle': 443,
74 | 'ring': 554,
75 | 'pinky': 671,
76 | }
77 | }
78 |
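Usage sketch (illustrative, not part of the original file): the tables above map keypoint
names to mesh vertex indices, so keypoints can be read directly off a posed mesh. The MANO
hand mesh has 778 vertices; the tensor below is a random placeholder.

import torch
from llava.model.smplx.vertex_ids import vertex_ids

mano_verts = torch.rand(778, 3)               # placeholder MANO hand mesh
tip_ids = list(vertex_ids['mano'].values())   # [744, 320, 443, 554, 671]
fingertips = mano_verts[tip_ids]              # (5, 3) fingertip positions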
--------------------------------------------------------------------------------
/llava/model/smplx/vertex_joint_selector.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 |
3 | # Max-Planck-Gesellschaft zur Förderung der Wissenschaften e.V. (MPG) is
4 | # holder of all proprietary rights on this computer program.
5 | # You can only use this computer program if you have closed
6 | # a license agreement with MPG or you get the right to use the computer
7 | # program from someone who is authorized to grant you that right.
8 | # Any use of the computer program without a valid license is prohibited and
9 | # liable to prosecution.
10 | #
11 | # Copyright©2019 Max-Planck-Gesellschaft zur Förderung
12 | # der Wissenschaften e.V. (MPG). acting on behalf of its Max Planck Institute
13 | # for Intelligent Systems. All rights reserved.
14 | #
15 | # Contact: ps-license@tuebingen.mpg.de
16 |
17 | from __future__ import absolute_import
18 | from __future__ import print_function
19 | from __future__ import division
20 |
21 | import numpy as np
22 |
23 | import torch
24 | import torch.nn as nn
25 |
26 | from .utils import to_tensor
27 |
28 |
29 | class VertexJointSelector(nn.Module):
30 |
31 | def __init__(self, vertex_ids=None,
32 | use_hands=True,
33 | use_feet_keypoints=True, **kwargs):
34 | super(VertexJointSelector, self).__init__()
35 |
36 | extra_joints_idxs = []
37 |
38 | face_keyp_idxs = np.array([
39 | vertex_ids['nose'],
40 | vertex_ids['reye'],
41 | vertex_ids['leye'],
42 | vertex_ids['rear'],
43 | vertex_ids['lear']], dtype=np.int64)
44 |
45 | extra_joints_idxs = np.concatenate([extra_joints_idxs,
46 | face_keyp_idxs])
47 |
48 | if use_feet_keypoints:
49 | feet_keyp_idxs = np.array([vertex_ids['LBigToe'],
50 | vertex_ids['LSmallToe'],
51 | vertex_ids['LHeel'],
52 | vertex_ids['RBigToe'],
53 | vertex_ids['RSmallToe'],
54 | vertex_ids['RHeel']], dtype=np.int32)
55 |
56 | extra_joints_idxs = np.concatenate(
57 | [extra_joints_idxs, feet_keyp_idxs])
58 |
59 | if use_hands:
60 | self.tip_names = ['thumb', 'index', 'middle', 'ring', 'pinky']
61 |
62 | tips_idxs = []
63 | for hand_id in ['l', 'r']:
64 | for tip_name in self.tip_names:
65 | tips_idxs.append(vertex_ids[hand_id + tip_name])
66 |
67 | extra_joints_idxs = np.concatenate(
68 | [extra_joints_idxs, tips_idxs])
69 |
70 | self.register_buffer('extra_joints_idxs',
71 | to_tensor(extra_joints_idxs, dtype=torch.long))
72 |
73 | def forward(self, vertices, joints):
74 | extra_joints = torch.index_select(vertices, 1, self.extra_joints_idxs)
75 | joints = torch.cat([joints, extra_joints], dim=1)
76 |
77 | return joints
78 |
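Usage sketch (illustrative, not part of the original file): VertexJointSelector appends
vertex-based keypoints to the joints regressed by the body model. With the 'smplx' table
it adds 5 face + 6 feet + 10 fingertip entries; shapes below are placeholders.

import torch
from llava.model.smplx.vertex_ids import vertex_ids
from llava.model.smplx.vertex_joint_selector import VertexJointSelector

selector = VertexJointSelector(vertex_ids=vertex_ids['smplx'])
vertices = torch.rand(2, 10475, 3)   # batch of SMPL-X meshes (placeholder values)
joints = torch.rand(2, 55, 3)        # joints regressed by the body model
extended = selector(vertices, joints)
print(extended.shape)                # torch.Size([2, 76, 3]) = 55 + 21 extra keypoints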
--------------------------------------------------------------------------------
/llava/model/utils.py:
--------------------------------------------------------------------------------
1 | from transformers import AutoConfig
2 |
3 |
4 | def auto_upgrade(config):
5 | cfg = AutoConfig.from_pretrained(config)
6 | if 'llava' in config and 'llava' not in cfg.model_type:
7 | assert cfg.model_type == 'llama'
8 | print("You are using newer LLaVA code base, while the checkpoint of v0 is from older code base.")
9 | print("You must upgrade the checkpoint to the new code base (this can be done automatically).")
10 | confirm = input("Please confirm that you want to upgrade the checkpoint. [Y/N]")
11 | if confirm.lower() in ["y", "yes"]:
12 | print("Upgrading checkpoint...")
13 | assert len(cfg.architectures) == 1
14 | setattr(cfg.__class__, "model_type", "llava")
15 | cfg.architectures[0] = 'LlavaLlamaForCausalLM'
16 | cfg.save_pretrained(config)
17 | print("Checkpoint upgraded.")
18 | else:
19 | print("Checkpoint upgrade aborted.")
20 | exit(1)
21 |
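Usage sketch (illustrative, not part of the original file; the checkpoint path is
hypothetical): auto_upgrade interactively patches the config.json of an old LLaVA v0
checkpoint in place so it loads with the newer code base.

from llava.model.utils import auto_upgrade

auto_upgrade("./checkpoints/llava-v0-old")  # prompts [Y/N] before rewriting config.json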
--------------------------------------------------------------------------------
/llava/prompts_utils.py:
--------------------------------------------------------------------------------
1 | import json
2 | import re
3 | from llava.json_fixer import repair_json
4 |
5 |
6 | def get_text_labels(gpt4o_results):
7 | result_dict = repair_json(gpt4o_results, return_objects=True)
8 |
9 | used_config_new = {}
10 | used_config_text = []
11 |
12 | if "upper garment" in result_dict:
13 | used_config_now = {
14 | 'garment_name': result_dict["upper garment"][0],
15 | 'geometry_styles': result_dict["upper garment"][1],
16 | }
17 | used_config_new['upperbody_garment'] = used_config_now
18 | used_config_text.append(
19 | result_dict["upper garment"][1] + ', ' + result_dict["upper garment"][0]
20 | )
21 |
22 | if "lower garment" in result_dict:
23 | used_config_now = {
24 | 'garment_name': result_dict["lower garment"][0],
25 | 'geometry_styles': result_dict["lower garment"][1],
26 | }
27 | used_config_new['lowerbody_garment'] = used_config_now
28 | used_config_text.append(
29 | result_dict["lower garment"][1] + ', ' + result_dict["lower garment"][0]
30 | )
31 |
32 | if "wholebody garment" in result_dict:
33 | used_config_now = {
34 | 'garment_name': result_dict["wholebody garment"][0],
35 | 'geometry_styles': result_dict["wholebody garment"][1],
36 | }
37 | used_config_new['wholebody_garment'] = used_config_now
38 | used_config_text.append(
39 | result_dict["wholebody garment"][1] + ', ' + result_dict["wholebody garment"][0]
40 | )
41 |
42 | return used_config_new, used_config_text
43 |
44 |
45 |
46 | def get_text_labels_detailed(gpt4o_results):
47 | gpt4o_results = gpt4o_results.strip()
48 | if "```" in gpt4o_results:
49 | gpt4o_results = gpt4o_results.split("```")[1]
50 | gpt4o_results = gpt4o_results.strip()
51 | if gpt4o_results.startswith('json') or gpt4o_results.startswith('Json') or gpt4o_results.startswith('JSON'):
52 | gpt4o_results = gpt4o_results[4:].strip()
53 |
54 | results = repair_json(gpt4o_results, return_objects=True)
55 | if len(results) < 2:
56 | return None
57 |
58 | if isinstance(results[1], str):
59 | try:
60 | results[1] = eval(results[1])
61 | except Exception:
62 | print('Failed to parse geometry_styles from the GPT-4o output:', results[1])
63 | return None
64 |
65 | if len(results) > 2:
66 | extra_styles = results[2].split(',')
67 | results[1]['extra'] = [item.strip() for item in extra_styles]
68 | else:
69 | results[1]['extra'] = []
70 |
71 | used_config_now = {
72 | 'garment_name': results[0],
73 | 'geometry_styles': results[1],
74 | }
75 |
76 | return used_config_now
77 |
78 |
79 | def get_text_labels_foredit(gpt4o_results):
80 | gpt4o_results = gpt4o_results.strip()
81 | if "```" in gpt4o_results:
82 | gpt4o_results = gpt4o_results.split("```")[1]
83 | gpt4o_results = gpt4o_results.strip()
84 | if gpt4o_results.startswith('json') or gpt4o_results.startswith('Json') or gpt4o_results.startswith('JSON'):
85 | gpt4o_results = gpt4o_results[4:].strip()
86 |
87 | results = repair_json(gpt4o_results, return_objects=True)
88 | return results
89 |
90 |
91 | def get_gpt4o_textgen_prompt(garment_name, garment_description):
92 | txtgen_prompt_path = 'docs/prompts/detailed_textbased_description.txt'
93 | with open(txtgen_prompt_path, 'r') as f:
94 | txtgen_prompt = f.read()
95 |
96 | txtgen_prompt = txtgen_prompt.replace('[TYPE]', garment_name)
97 | txtgen_prompt = txtgen_prompt.replace('[DESCRIPTION]', garment_description)
98 | return txtgen_prompt
99 |
100 |
101 | def get_gpt4o_edit_prompt(garment_name, prompt):
102 | edit_prompt_path = 'docs/prompts/prompt_garment_editing.txt'
103 | with open(edit_prompt_path, 'r') as f:
104 | edit_prompt = f.read()
105 |
106 | edit_prompt = edit_prompt.replace('[TYPE]', garment_name)
107 | edit_prompt = edit_prompt.replace('[DESCRIPTION]', prompt)
108 | return edit_prompt
109 |
110 |
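Usage sketch (illustrative, not part of the original file): get_text_labels expects a
GPT-4o answer whose JSON maps garment slots to [name, geometry-style] pairs. The response
string below is a hypothetical example of that format.

from llava.prompts_utils import get_text_labels

gpt4o_results = '''
{
  "upper garment": ["shirt", "loose fit, short sleeves"],
  "lower garment": ["trousers", "straight leg, full length"]
}
'''
used_config, used_text = get_text_labels(gpt4o_results)
print(used_config['upperbody_garment'])  # {'garment_name': 'shirt', 'geometry_styles': 'loose fit, short sleeves'}
print(used_text[0])                      # 'loose fit, short sleeves, shirt'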
--------------------------------------------------------------------------------
/llava/pytorch3d_render_utils.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import numpy as np
4 |
5 | from pytorch3d.structures import Meshes
6 | from pytorch3d.renderer import (
7 | PerspectiveCameras,
8 | OrthographicCameras,
9 | PointLights,
10 | RasterizationSettings,
11 | MeshRasterizer,
12 | HardPhongShader,
13 | MeshRenderer,
14 | SoftSilhouetteShader,
15 | TexturesUV,
16 | TexturesVertex,
17 | BlendParams)
18 |
19 |
20 | class TexturedIUVRenderer(nn.Module):
21 | def __init__(self,
22 | device='cuda',
23 | img_wh=256,
24 | blur_radius=0.0,
25 | faces_per_pixel=1,
26 | ):
27 |
28 | super().__init__()
29 | self.img_wh = img_wh
30 |
31 | raster_settings = RasterizationSettings(image_size=img_wh,
32 | blur_radius=blur_radius,
33 | faces_per_pixel=faces_per_pixel,)
34 |
35 | self.cameras = PerspectiveCameras()
36 | self.rasterizer = MeshRasterizer(cameras=self.cameras, raster_settings=raster_settings) # Specify camera in forward pass
37 | self.iuv_shader = SoftSilhouetteShader()
38 |
39 | self.to(device)
40 |
41 | def to(self, device):
42 | self.rasterizer.to(device)
43 | self.iuv_shader.to(device)
44 |
45 | def forward(self, vertices, faces, cam_t=None, cameras=None, focal_length=5000):
46 | img_wh = self.img_wh
47 | img_center=((img_wh * 0.5, img_wh * 0.5),)
48 | cameras = PerspectiveCameras(device=vertices.device,
49 | focal_length=focal_length,
50 | principal_point=img_center,
51 | image_size=((img_wh, img_wh),),
52 | in_ndc=False)
53 | device=vertices.device
54 |
55 | if cam_t is not None:
56 | vertices = vertices + cam_t[:, None, :]
57 |
58 | vertices = vertices * torch.tensor([-1., -1., 1.], device=device).float()
59 |
60 | textures_iuv = TexturesVertex(verts_features=torch.ones_like(vertices))
61 | meshes_iuv = Meshes(verts=vertices, faces=faces, textures=textures_iuv)
62 |
63 | # Rasterize
64 | fragments = self.rasterizer(meshes_iuv, cameras=cameras)
65 |
66 | # Render RGB and IUV outputs
67 | iuv_image = self.iuv_shader(fragments, meshes_iuv)
68 |
69 | return iuv_image
70 |
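Usage sketch (illustrative, not part of the original file; assumes pytorch3d and a CUDA
device): TexturedIUVRenderer rasterizes a mesh with a perspective camera and returns a
soft silhouette image of shape (batch, img_wh, img_wh, 4). The single triangle below is a
toy placeholder mesh.

import torch
from llava.pytorch3d_render_utils import TexturedIUVRenderer

device = 'cuda'
renderer = TexturedIUVRenderer(device=device, img_wh=256)

verts = torch.tensor([[[0.0, 0.0, 0.0],
                       [0.1, 0.0, 0.0],
                       [0.0, 0.1, 0.0]]], device=device)   # (1, V, 3)
faces = torch.tensor([[[0, 1, 2]]], device=device)          # (1, F, 3)
cam_t = torch.tensor([[0.0, 0.0, 5.0]], device=device)      # push the mesh in front of the camera

silhouette = renderer(verts, faces, cam_t=cam_t)
print(silhouette.shape)   # torch.Size([1, 256, 256, 4])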
--------------------------------------------------------------------------------
/llava/serve/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/biansy000/ChatGarment/6ea4192383067938f1d1c63dc0cc4134ccc91b31/llava/serve/__init__.py
--------------------------------------------------------------------------------
/llava/serve/cli.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import torch
3 |
4 | from llava.constants import IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN
5 | from llava.conversation import conv_templates, SeparatorStyle
6 | from llava.model.builder import load_pretrained_model
7 | from llava.utils import disable_torch_init
8 | from llava.mm_utils import process_images, tokenizer_image_token, get_model_name_from_path
9 |
10 | from PIL import Image
11 | from io import BytesIO
12 | 
13 | import requests
14 | 
15 | from transformers import TextStreamer
16 |
17 |
18 | def load_image(image_file):
19 | if image_file.startswith('http://') or image_file.startswith('https://'):
20 | response = requests.get(image_file)
21 | image = Image.open(BytesIO(response.content)).convert('RGB')
22 | else:
23 | image = Image.open(image_file).convert('RGB')
24 | return image
25 |
26 |
27 | def main(args):
28 | # Model
29 | disable_torch_init()
30 |
31 | model_name = get_model_name_from_path(args.model_path)
32 | tokenizer, model, image_processor, context_len = load_pretrained_model(args.model_path, args.model_base, model_name, args.load_8bit, args.load_4bit, device=args.device)
33 |
34 | if "llama-2" in model_name.lower():
35 | conv_mode = "llava_llama_2"
36 | elif "mistral" in model_name.lower():
37 | conv_mode = "mistral_instruct"
38 | elif "v1.6-34b" in model_name.lower():
39 | conv_mode = "chatml_direct"
40 | elif "v1" in model_name.lower():
41 | conv_mode = "llava_v1"
42 | elif "mpt" in model_name.lower():
43 | conv_mode = "mpt"
44 | else:
45 | conv_mode = "llava_v0"
46 |
47 | if args.conv_mode is not None and conv_mode != args.conv_mode:
48 | print('[WARNING] the auto inferred conversation mode is {}, while `--conv-mode` is {}, using {}'.format(conv_mode, args.conv_mode, args.conv_mode))
49 | else:
50 | args.conv_mode = conv_mode
51 |
52 | conv = conv_templates[args.conv_mode].copy()
53 | if "mpt" in model_name.lower():
54 | roles = ('user', 'assistant')
55 | else:
56 | roles = conv.roles
57 |
58 | image = load_image(args.image_file)
59 | image_size = image.size
60 | # Similar operation in model_worker.py
61 | image_tensor = process_images([image], image_processor, model.config)
62 | if type(image_tensor) is list:
63 | image_tensor = [image.to(model.device, dtype=torch.float16) for image in image_tensor]
64 | else:
65 | image_tensor = image_tensor.to(model.device, dtype=torch.float16)
66 |
67 | while True:
68 | try:
69 | inp = input(f"{roles[0]}: ")
70 | except EOFError:
71 | inp = ""
72 | if not inp:
73 | print("exit...")
74 | break
75 |
76 | print(f"{roles[1]}: ", end="")
77 |
78 | if image is not None:
79 | # first message
80 | if model.config.mm_use_im_start_end:
81 | inp = DEFAULT_IM_START_TOKEN + DEFAULT_IMAGE_TOKEN + DEFAULT_IM_END_TOKEN + '\n' + inp
82 | else:
83 | inp = DEFAULT_IMAGE_TOKEN + '\n' + inp
84 | image = None
85 |
86 | conv.append_message(conv.roles[0], inp)
87 | conv.append_message(conv.roles[1], None)
88 | prompt = conv.get_prompt()
89 |
90 | input_ids = tokenizer_image_token(prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors='pt').unsqueeze(0).to(model.device)
91 | stop_str = conv.sep if conv.sep_style != SeparatorStyle.TWO else conv.sep2
92 | keywords = [stop_str]
93 | streamer = TextStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
94 |
95 | with torch.inference_mode():
96 | output_ids = model.generate(
97 | input_ids,
98 | images=image_tensor,
99 | image_sizes=[image_size],
100 | do_sample=True if args.temperature > 0 else False,
101 | temperature=args.temperature,
102 | max_new_tokens=args.max_new_tokens,
103 | streamer=streamer,
104 | use_cache=True)
105 |
106 | outputs = tokenizer.decode(output_ids[0]).strip()
107 | conv.messages[-1][-1] = outputs
108 |
109 | if args.debug:
110 | print("\n", {"prompt": prompt, "outputs": outputs}, "\n")
111 |
112 |
113 | if __name__ == "__main__":
114 | parser = argparse.ArgumentParser()
115 | parser.add_argument("--model-path", type=str, default="facebook/opt-350m")
116 | parser.add_argument("--model-base", type=str, default=None)
117 | parser.add_argument("--image-file", type=str, required=True)
118 | parser.add_argument("--device", type=str, default="cuda")
119 | parser.add_argument("--conv-mode", type=str, default=None)
120 | parser.add_argument("--temperature", type=float, default=0.2)
121 | parser.add_argument("--max-new-tokens", type=int, default=512)
122 | parser.add_argument("--load-8bit", action="store_true")
123 | parser.add_argument("--load-4bit", action="store_true")
124 | parser.add_argument("--debug", action="store_true")
125 | args = parser.parse_args()
126 | main(args)
127 |
--------------------------------------------------------------------------------
/llava/serve/examples/extreme_ironing.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/biansy000/ChatGarment/6ea4192383067938f1d1c63dc0cc4134ccc91b31/llava/serve/examples/extreme_ironing.jpg
--------------------------------------------------------------------------------
/llava/serve/examples/waterview.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/biansy000/ChatGarment/6ea4192383067938f1d1c63dc0cc4134ccc91b31/llava/serve/examples/waterview.jpg
--------------------------------------------------------------------------------
/llava/serve/register_worker.py:
--------------------------------------------------------------------------------
1 | """
2 | Manually register workers.
3 |
4 | Usage:
5 | python3 -m llava.serve.register_worker --controller-address http://localhost:21001 --worker-name http://localhost:21002
6 | """
7 |
8 | import argparse
9 |
10 | import requests
11 |
12 | if __name__ == "__main__":
13 | parser = argparse.ArgumentParser()
14 | parser.add_argument("--controller-address", type=str)
15 | parser.add_argument("--worker-name", type=str)
16 | parser.add_argument("--check-heart-beat", action="store_true")
17 | args = parser.parse_args()
18 |
19 | url = args.controller_address + "/register_worker"
20 | data = {
21 | "worker_name": args.worker_name,
22 | "check_heart_beat": args.check_heart_beat,
23 | "worker_status": None,
24 | }
25 | r = requests.post(url, json=data)
26 | assert r.status_code == 200
27 |
--------------------------------------------------------------------------------
/llava/serve/test_message.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import json
3 |
4 | import requests
5 |
6 | from llava.conversation import default_conversation
7 |
8 |
9 | def main():
10 | if args.worker_address:
11 | worker_addr = args.worker_address
12 | else:
13 | controller_addr = args.controller_address
14 | ret = requests.post(controller_addr + "/refresh_all_workers")
15 | ret = requests.post(controller_addr + "/list_models")
16 | models = ret.json()["models"]
17 | models.sort()
18 | print(f"Models: {models}")
19 |
20 | ret = requests.post(controller_addr + "/get_worker_address",
21 | json={"model": args.model_name})
22 | worker_addr = ret.json()["address"]
23 | print(f"worker_addr: {worker_addr}")
24 |
25 | if worker_addr == "":
26 | return
27 |
28 | conv = default_conversation.copy()
29 | conv.append_message(conv.roles[0], args.message)
30 | prompt = conv.get_prompt()
31 |
32 | headers = {"User-Agent": "LLaVA Client"}
33 | pload = {
34 | "model": args.model_name,
35 | "prompt": prompt,
36 | "max_new_tokens": args.max_new_tokens,
37 | "temperature": 0.7,
38 | "stop": conv.sep,
39 | }
40 | response = requests.post(worker_addr + "/worker_generate_stream", headers=headers,
41 | json=pload, stream=True)
42 |
43 | print(prompt.replace(conv.sep, "\n"), end="")
44 | for chunk in response.iter_lines(chunk_size=8192, decode_unicode=False, delimiter=b"\0"):
45 | if chunk:
46 | data = json.loads(chunk.decode("utf-8"))
47 | output = data["text"].split(conv.sep)[-1]
48 | print(output, end="\r")
49 | print("")
50 |
51 |
52 | if __name__ == "__main__":
53 | parser = argparse.ArgumentParser()
54 | parser.add_argument("--controller-address", type=str, default="http://localhost:21001")
55 | parser.add_argument("--worker-address", type=str)
56 | parser.add_argument("--model-name", type=str, default="facebook/opt-350m")
57 | parser.add_argument("--max-new-tokens", type=int, default=32)
58 | parser.add_argument("--message", type=str, default=
59 | "Tell me a story with more than 1000 words.")
60 | args = parser.parse_args()
61 |
62 | main()
63 |
--------------------------------------------------------------------------------
/llava/train/train_mem_garmentcode_outfit.py:
--------------------------------------------------------------------------------
1 | from llava.train.train_garmentcode_outfit import train
2 |
3 | if __name__ == "__main__":
4 | train(attn_implementation="flash_attention_2")
5 |
--------------------------------------------------------------------------------
/llava/utils.py:
--------------------------------------------------------------------------------
1 | import datetime
2 | import logging
3 | import logging.handlers
4 | import os
5 | import sys
6 |
7 | import requests
8 |
9 | from llava.constants import LOGDIR
10 |
11 | server_error_msg = "**NETWORK ERROR DUE TO HIGH TRAFFIC. PLEASE REGENERATE OR REFRESH THIS PAGE.**"
12 | moderation_msg = "YOUR INPUT VIOLATES OUR CONTENT MODERATION GUIDELINES. PLEASE TRY AGAIN."
13 |
14 | handler = None
15 |
16 |
17 | def build_logger(logger_name, logger_filename):
18 | global handler
19 |
20 | formatter = logging.Formatter(
21 | fmt="%(asctime)s | %(levelname)s | %(name)s | %(message)s",
22 | datefmt="%Y-%m-%d %H:%M:%S",
23 | )
24 |
25 | # Set the format of root handlers
26 | if not logging.getLogger().handlers:
27 | logging.basicConfig(level=logging.INFO)
28 | logging.getLogger().handlers[0].setFormatter(formatter)
29 |
30 | # Redirect stdout and stderr to loggers
31 | stdout_logger = logging.getLogger("stdout")
32 | stdout_logger.setLevel(logging.INFO)
33 | sl = StreamToLogger(stdout_logger, logging.INFO)
34 | sys.stdout = sl
35 |
36 | stderr_logger = logging.getLogger("stderr")
37 | stderr_logger.setLevel(logging.ERROR)
38 | sl = StreamToLogger(stderr_logger, logging.ERROR)
39 | sys.stderr = sl
40 |
41 | # Get logger
42 | logger = logging.getLogger(logger_name)
43 | logger.setLevel(logging.INFO)
44 |
45 | # Add a file handler for all loggers
46 | if handler is None:
47 | os.makedirs(LOGDIR, exist_ok=True)
48 | filename = os.path.join(LOGDIR, logger_filename)
49 | handler = logging.handlers.TimedRotatingFileHandler(
50 | filename, when='D', utc=True, encoding='UTF-8')
51 | handler.setFormatter(formatter)
52 |
53 | for name, item in logging.root.manager.loggerDict.items():
54 | if isinstance(item, logging.Logger):
55 | item.addHandler(handler)
56 |
57 | return logger
58 |
59 |
60 | class StreamToLogger(object):
61 | """
62 | Fake file-like stream object that redirects writes to a logger instance.
63 | """
64 | def __init__(self, logger, log_level=logging.INFO):
65 | self.terminal = sys.stdout
66 | self.logger = logger
67 | self.log_level = log_level
68 | self.linebuf = ''
69 |
70 | def __getattr__(self, attr):
71 | return getattr(self.terminal, attr)
72 |
73 | def write(self, buf):
74 | temp_linebuf = self.linebuf + buf
75 | self.linebuf = ''
76 | for line in temp_linebuf.splitlines(True):
77 | # From the io.TextIOWrapper docs:
78 | # On output, if newline is None, any '\n' characters written
79 | # are translated to the system default line separator.
80 | # By default sys.stdout.write() expects '\n' newlines and then
81 | # translates them so this is still cross platform.
82 | if line[-1] == '\n':
83 | self.logger.log(self.log_level, line.rstrip())
84 | else:
85 | self.linebuf += line
86 |
87 | def flush(self):
88 | if self.linebuf != '':
89 | self.logger.log(self.log_level, self.linebuf.rstrip())
90 | self.linebuf = ''
91 |
92 |
93 | def disable_torch_init():
94 | """
95 | Disable the redundant torch default initialization to accelerate model creation.
96 | """
97 | import torch
98 | setattr(torch.nn.Linear, "reset_parameters", lambda self: None)
99 | setattr(torch.nn.LayerNorm, "reset_parameters", lambda self: None)
100 |
101 |
102 | def violates_moderation(text):
103 | """
104 | Check whether the text violates OpenAI moderation API.
105 | """
106 | url = "https://api.openai.com/v1/moderations"
107 | headers = {"Content-Type": "application/json",
108 | "Authorization": "Bearer " + os.environ["OPENAI_API_KEY"]}
109 | text = text.replace("\n", "")
110 | # Let requests handle JSON encoding so quotes in the text are escaped correctly
111 | data = {"input": text}
112 | try:
113 | ret = requests.post(url, headers=headers, json=data, timeout=5)
114 | flagged = ret.json()["results"][0]["flagged"]
115 | except requests.exceptions.RequestException as e:
116 | flagged = False
117 | except KeyError as e:
118 | flagged = False
119 |
120 | return flagged
121 |
122 |
123 | def pretty_print_semaphore(semaphore):
124 | if semaphore is None:
125 | return "None"
126 | return f"Semaphore(value={semaphore._value}, locked={semaphore.locked()})"
127 |
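Usage sketch (illustrative, not part of the original file; the logger name and filename
are hypothetical): build_logger attaches a timed rotating file handler under LOGDIR and
redirects stdout/stderr through StreamToLogger; disable_torch_init skips default weight
initialization before loading pretrained checkpoints.

from llava.utils import build_logger, disable_torch_init

logger = build_logger("model_worker", "model_worker.log")
logger.info("worker started")   # also written to LOGDIR/model_worker.log

disable_torch_init()            # speeds up subsequent model construction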
--------------------------------------------------------------------------------
/pyproject.toml:
--------------------------------------------------------------------------------
1 | [build-system]
2 | requires = ["setuptools>=61.0"]
3 | build-backend = "setuptools.build_meta"
4 |
5 | [project]
6 | name = "llava"
7 | version = "1.2.2.post1"
8 | description = "Towards GPT-4 like large language and visual assistant."
9 | readme = "README.md"
10 | requires-python = ">=3.8"
11 | classifiers = [
12 | "Programming Language :: Python :: 3",
13 | "License :: OSI Approved :: Apache Software License",
14 | ]
15 | dependencies = [
16 | "torch==2.1.2", "torchvision==0.16.2",
17 | "transformers==4.37.2", "tokenizers==0.15.1", "sentencepiece==0.1.99", "shortuuid",
18 | "accelerate==0.32.0", "peft", "bitsandbytes",
19 | "pydantic", "markdown2[all]", "numpy", "scikit-learn==1.2.2",
20 | "gradio==4.16.0", "gradio_client==0.8.1",
21 | "requests", "httpx", "uvicorn", "fastapi",
22 | "einops==0.6.1", "einops-exts==0.0.4", "timm==0.6.13",
23 | "opencv-python", "easydict", "tensorboard", "peft==0.10.0"
24 | ]
25 |
26 | [project.optional-dependencies]
27 | train = ["deepspeed==0.12.6", "ninja", "wandb"]
28 | build = ["build", "twine"]
29 |
30 | [project.urls]
31 | "Homepage" = "https://chatgarment.github.io/"
32 | "Bug Tracker" = "https://github.com/biansy000/ChatGarment/issues"
33 |
34 | [tool.setuptools.packages.find]
35 | exclude = ["assets*", "benchmark*", "docs", "dist*", "playground*", "scripts*", "tests*"]
36 |
37 | [tool.wheel]
38 | exclude = ["assets*", "benchmark*", "docs", "dist*", "playground*", "scripts*", "tests*"]
39 |
--------------------------------------------------------------------------------
/run_garmentcode_sim.py:
--------------------------------------------------------------------------------
1 | import os
2 | import sys
3 | import argparse
4 | import json
5 | from pathlib import Path
6 |
7 | # add the path of GarmentCode
8 | sys.path.insert(1, '/is/cluster/fast/sbian/github/GarmentCodeV2/')
9 | from assets.garment_programs.meta_garment import MetaGarment
10 | from assets.bodies.body_params import BodyParameters
11 |
12 | def run_simultion_warp(pattern_spec, sim_config, output_path, easy_texture_path):
13 | from pygarment.meshgen.boxmeshgen import BoxMesh
14 | from pygarment.meshgen.simulation import run_sim
15 | import pygarment.data_config as data_config
16 | from pygarment.meshgen.sim_config import PathCofig
17 |
18 | props = data_config.Properties(sim_config)
19 | props.set_section_stats('sim', fails={}, sim_time={}, spf={}, fin_frame={}, body_collisions={}, self_collisions={})
20 | props.set_section_stats('render', render_time={})
21 |
22 | spec_path = Path(pattern_spec)
23 | garment_name, _, _ = spec_path.stem.rpartition('_') # assuming ending in '_specification'
24 |
25 | paths = PathCofig(
26 | in_element_path=spec_path.parent,
27 | out_path=output_path,
28 | in_name=garment_name,
29 | body_name='mean_all', # 'f_smpl_average_A40'
30 | smpl_body=False, # NOTE: depends on chosen body model
31 | add_timestamp=False,
32 | system_path='/is/cluster/fast/sbian/github/GarmentCodeV2/system.json',
33 | easy_texture_path=easy_texture_path
34 | )
35 |
36 | # Generate and save garment box mesh (if not existent)
37 | print(f"Generate box mesh of {garment_name} with resolution {props['sim']['config']['resolution_scale']}...")
38 | print('\nGarment load: ', paths.in_g_spec)
39 |
40 | garment_box_mesh = BoxMesh(paths.in_g_spec, props['sim']['config']['resolution_scale'])
41 | garment_box_mesh.load()
42 | garment_box_mesh.serialize(
43 | paths, store_panels=False, uv_config=props['render']['config']['uv_texture'])
44 |
45 | props.serialize(paths.element_sim_props)
46 |
47 | run_sim(
48 | garment_box_mesh.name,
49 | props,
50 | paths,
51 | save_v_norms=False,
52 | store_usd=False, # NOTE: False for fast simulation!
53 | optimize_storage=False, # props['sim']['config']['optimize_storage'],
54 | verbose=False
55 | )
56 |
57 | props.serialize(paths.element_sim_props)
58 |
59 |
60 | parser = argparse.ArgumentParser()
61 | parser.add_argument("--all_paths_json", type=str, default='', help="path to the save resules shapenet dataset")
62 | parser.add_argument("--json_spec_file", type=str, default='', help="path to the save resules shapenet dataset")
63 | parser.add_argument("--easy_texture_path", type=str, default='', help="path to the save resules shapenet dataset")
64 | args = parser.parse_args()
65 |
66 | if len(args.all_paths_json) > 1:
67 | garment_json_path = os.path.join(args.all_paths_json, 'vis_new/all_json_spec_files.json')
68 |
69 | with open(garment_json_path) as f:
70 | garment_json_paths = json.load(f)
71 |
72 | elif args.json_spec_file:
73 | garment_json_paths = [args.json_spec_file]
74 |
75 | print(len(garment_json_paths))
76 | for json_spec_file in garment_json_paths:
77 | print(json_spec_file)
78 | json_spec_file = json_spec_file.replace('validate_garment', 'valid_garment')
79 | saved_folder = os.path.dirname(json_spec_file)
80 | run_simultion_warp(
81 | json_spec_file,
82 | 'assets/Sim_props/default_sim_props.yaml',
83 | saved_folder,
84 | easy_texture_path=args.easy_texture_path
85 | )
86 |
--------------------------------------------------------------------------------
/scripts/v1_5/evaluate_garment_v2_demo_edit.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | export LD_LIBRARY_PATH=/is/software/nvidia/cuda-12.1/lib64
4 | export PATH=$PATH:/is/software/nvidia/cuda-12.1/bin
5 | export CUDA_HOME=/is/software/nvidia/cuda-12.1
6 |
7 | export CPATH=/is/software/nvidia/cudnn-8.4.1-cu11.6/include
8 | export C_INCLUDE_PATH=/is/software/nvidia/cudnn-8.4.1-cu11.6/include
9 | export LIBRARY_PATH=/is/software/nvidia/cudnn-8.4.1-cu11.6/lib64
10 | export LD_LIBRARY_PATH=$LIBRARY_PATH:$LD_LIBRARY_PATH
11 |
12 | export EGL_DEVICE_ID=$GPU_DEVICE_ORDINAL
13 | # export TCNN_CUDA_ARCHITECTURES=80
14 |
15 | deepspeed scripts/evaluate_garment_v2_demo_edit_1float.py \
16 | --lora_enable True --lora_r 128 --lora_alpha 256 --mm_projector_lr 2e-5 \
17 | --deepspeed ./scripts/zero2.json \
18 | --model_name_or_path liuhaotian/llava-v1.5-7b \
19 | --version v1 \
20 | --data_path ./ \
21 | --data_path_eval example_data/example_jsons/example_edit_prompts.json \
22 | --image_folder ./ \
23 | --vision_tower openai/clip-vit-large-patch14-336 \
24 | --mm_projector_type mlp2x_gelu \
25 | --mm_vision_select_layer -2 \
26 | --mm_use_im_start_end False \
27 | --mm_use_im_patch_token False \
28 | --image_aspect_ratio pad \
29 | --group_by_modality_length True \
30 | --bf16 True \
31 | --output_dir ./checkpoints/llava-v1.5-7b-task-lora \
32 | --num_train_epochs 1 \
33 | --per_device_train_batch_size 16 \
34 | --per_device_eval_batch_size 4 \
35 | --gradient_accumulation_steps 1 \
36 | --evaluation_strategy "no" \
37 | --save_strategy "steps" \
38 | --save_steps 50000 \
39 | --save_total_limit 1 \
40 | --learning_rate 2e-4 \
41 | --weight_decay 0. \
42 | --warmup_ratio 0.03 \
43 | --lr_scheduler_type "cosine" \
44 | --logging_steps 1 \
45 | --tf32 True \
46 | --model_max_length 3072 \
47 | --gradient_checkpointing True \
48 | --dataloader_num_workers 4 \
49 | --lazy_preprocess True \
50 | --report_to wandb
51 |
52 |
--------------------------------------------------------------------------------
/scripts/v1_5/evaluate_garment_v2_eva_edit.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | export LD_LIBRARY_PATH=/is/software/nvidia/cuda-12.1/lib64
4 | export PATH=$PATH:/is/software/nvidia/cuda-12.1/bin
5 | export CUDA_HOME=/is/software/nvidia/cuda-12.1
6 |
7 | export CPATH=/is/software/nvidia/cudnn-8.4.1-cu11.6/include
8 | export C_INCLUDE_PATH=/is/software/nvidia/cudnn-8.4.1-cu11.6/include
9 | export LIBRARY_PATH=/is/software/nvidia/cudnn-8.4.1-cu11.6/lib64
10 | export LD_LIBRARY_PATH=$LIBRARY_PATH:$LD_LIBRARY_PATH
11 |
12 | export EGL_DEVICE_ID=$GPU_DEVICE_ORDINAL
13 | # export TCNN_CUDA_ARCHITECTURES=80
14 |
15 | deepspeed scripts/evaluate_garment_v2_eva_edit_1float.py \
16 | --lora_enable True --lora_r 128 --lora_alpha 256 --mm_projector_lr 2e-5 \
17 | --deepspeed ./scripts/zero2.json \
18 | --model_name_or_path liuhaotian/llava-v1.5-7b \
19 | --version v1 \
20 | --data_path ./ \
21 | --data_path_eval data/llava_preprocess.json \
22 | --image_folder ./ \
23 | --vision_tower openai/clip-vit-large-patch14-336 \
24 | --mm_projector_type mlp2x_gelu \
25 | --mm_vision_select_layer -2 \
26 | --mm_use_im_start_end False \
27 | --mm_use_im_patch_token False \
28 | --image_aspect_ratio pad \
29 | --group_by_modality_length True \
30 | --bf16 True \
31 | --output_dir ./checkpoints/llava-v1.5-7b-task-lora \
32 | --num_train_epochs 1 \
33 | --per_device_train_batch_size 16 \
34 | --per_device_eval_batch_size 4 \
35 | --gradient_accumulation_steps 1 \
36 | --evaluation_strategy "no" \
37 | --save_strategy "steps" \
38 | --save_steps 50000 \
39 | --save_total_limit 1 \
40 | --learning_rate 2e-4 \
41 | --weight_decay 0. \
42 | --warmup_ratio 0.03 \
43 | --lr_scheduler_type "cosine" \
44 | --logging_steps 1 \
45 | --tf32 True \
46 | --model_max_length 3072 \
47 | --gradient_checkpointing True \
48 | --dataloader_num_workers 4 \
49 | --lazy_preprocess True \
50 | --report_to wandb
51 |
52 |
--------------------------------------------------------------------------------
/scripts/v1_5/evaluate_garment_v2_imggen_2step.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | export LD_LIBRARY_PATH=/is/software/nvidia/cuda-12.1/lib64
4 | export PATH=$PATH:/is/software/nvidia/cuda-12.1/bin
5 | export CUDA_HOME=/is/software/nvidia/cuda-12.1
6 |
7 | export CPATH=/is/software/nvidia/cudnn-8.4.1-cu11.6/include
8 | export C_INCLUDE_PATH=/is/software/nvidia/cudnn-8.4.1-cu11.6/include
9 | export LIBRARY_PATH=/is/software/nvidia/cudnn-8.4.1-cu11.6/lib64
10 | export LD_LIBRARY_PATH=$LIBRARY_PATH:$LD_LIBRARY_PATH
11 |
12 | export EGL_DEVICE_ID=$GPU_DEVICE_ORDINAL
13 | # export TCNN_CUDA_ARCHITECTURES=80
14 |
15 | deepspeed scripts/evaluate_garment_v2_imggen_1float.py \
16 | --lora_enable True --lora_r 128 --lora_alpha 256 --mm_projector_lr 2e-5 \
17 | --deepspeed ./scripts/zero2.json \
18 | --model_name_or_path liuhaotian/llava-v1.5-7b \
19 | --version v1 \
20 | --data_path ./ \
21 | --data_path_eval $1 \
22 | --image_folder ./ \
23 | --vision_tower openai/clip-vit-large-patch14-336 \
24 | --mm_projector_type mlp2x_gelu \
25 | --mm_vision_select_layer -2 \
26 | --mm_use_im_start_end False \
27 | --mm_use_im_patch_token False \
28 | --image_aspect_ratio pad \
29 | --group_by_modality_length True \
30 | --bf16 True \
31 | --output_dir ./checkpoints/llava-v1.5-7b-task-lora \
32 | --num_train_epochs 1 \
33 | --per_device_train_batch_size 16 \
34 | --per_device_eval_batch_size 4 \
35 | --gradient_accumulation_steps 1 \
36 | --evaluation_strategy "no" \
37 | --save_strategy "steps" \
38 | --save_steps 50000 \
39 | --save_total_limit 1 \
40 | --learning_rate 2e-4 \
41 | --weight_decay 0. \
42 | --warmup_ratio 0.03 \
43 | --lr_scheduler_type "cosine" \
44 | --logging_steps 1 \
45 | --tf32 True \
46 | --model_max_length 3072 \
47 | --gradient_checkpointing True \
48 | --dataloader_num_workers 4 \
49 | --lazy_preprocess True \
50 | --report_to wandb
51 |
52 |
--------------------------------------------------------------------------------
/scripts/v1_5/evaluate_garment_v2_textgen.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | export LD_LIBRARY_PATH=/is/software/nvidia/cuda-12.1/lib64
4 | export PATH=$PATH:/is/software/nvidia/cuda-12.1/bin
5 | export CUDA_HOME=/is/software/nvidia/cuda-12.1
6 |
7 | export CPATH=/is/software/nvidia/cudnn-8.4.1-cu11.6/include
8 | export C_INCLUDE_PATH=/is/software/nvidia/cudnn-8.4.1-cu11.6/include
9 | export LIBRARY_PATH=/is/software/nvidia/cudnn-8.4.1-cu11.6/lib64
10 | export LD_LIBRARY_PATH=$LIBRARY_PATH:$LD_LIBRARY_PATH
11 |
12 | export EGL_DEVICE_ID=$GPU_DEVICE_ORDINAL
13 | # export TCNN_CUDA_ARCHITECTURES=80
14 |
15 | deepspeed scripts/evaluate_garment_v2_textgen_1float.py \
16 | --lora_enable True --lora_r 128 --lora_alpha 256 --mm_projector_lr 2e-5 \
17 | --deepspeed ./scripts/zero2.json \
18 | --model_name_or_path liuhaotian/llava-v1.5-7b \
19 | --version v1 \
20 | --data_path ./ \
21 | --data_path_eval $1 \
22 | --image_folder ./ \
23 | --vision_tower openai/clip-vit-large-patch14-336 \
24 | --mm_projector_type mlp2x_gelu \
25 | --mm_vision_select_layer -2 \
26 | --mm_use_im_start_end False \
27 | --mm_use_im_patch_token False \
28 | --image_aspect_ratio pad \
29 | --group_by_modality_length True \
30 | --bf16 True \
31 | --output_dir ./checkpoints/llava-v1.5-7b-task-lora \
32 | --num_train_epochs 1 \
33 | --per_device_train_batch_size 16 \
34 | --per_device_eval_batch_size 4 \
35 | --gradient_accumulation_steps 1 \
36 | --evaluation_strategy "no" \
37 | --save_strategy "steps" \
38 | --save_steps 50000 \
39 | --save_total_limit 1 \
40 | --learning_rate 2e-4 \
41 | --weight_decay 0. \
42 | --warmup_ratio 0.03 \
43 | --lr_scheduler_type "cosine" \
44 | --logging_steps 1 \
45 | --tf32 True \
46 | --model_max_length 3072 \
47 | --gradient_checkpointing True \
48 | --dataloader_num_workers 4 \
49 | --lazy_preprocess True \
50 | --report_to wandb
51 |
52 |
--------------------------------------------------------------------------------
/scripts/v1_5/evaluate_garment_v2_textgen_fromimg.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | export LD_LIBRARY_PATH=/is/software/nvidia/cuda-12.1/lib64
4 | export PATH=$PATH:/is/software/nvidia/cuda-12.1/bin
5 | export CUDA_HOME=/is/software/nvidia/cuda-12.1
6 |
7 | export CPATH=/is/software/nvidia/cudnn-8.4.1-cu11.6/include
8 | export C_INCLUDE_PATH=/is/software/nvidia/cudnn-8.4.1-cu11.6/include
9 | export LIBRARY_PATH=/is/software/nvidia/cudnn-8.4.1-cu11.6/lib64
10 | export LD_LIBRARY_PATH=$LIBRARY_PATH:$LD_LIBRARY_PATH
11 |
12 | export EGL_DEVICE_ID=$GPU_DEVICE_ORDINAL
13 | # export TCNN_CUDA_ARCHITECTURES=80
14 |
15 | deepspeed scripts/evaluate_garment_v2_textgen_fromimg_1float.py \
16 | --lora_enable True --lora_r 128 --lora_alpha 256 --mm_projector_lr 2e-5 \
17 | --deepspeed ./scripts/zero2.json \
18 | --model_name_or_path liuhaotian/llava-v1.5-7b \
19 | --version v1 \
20 | --data_path ./ \
21 | --data_path_eval $1 \
22 | --image_folder ./ \
23 | --vision_tower openai/clip-vit-large-patch14-336 \
24 | --mm_projector_type mlp2x_gelu \
25 | --mm_vision_select_layer -2 \
26 | --mm_use_im_start_end False \
27 | --mm_use_im_patch_token False \
28 | --image_aspect_ratio pad \
29 | --group_by_modality_length True \
30 | --bf16 True \
31 | --output_dir ./checkpoints/llava-v1.5-7b-task-lora \
32 | --num_train_epochs 1 \
33 | --per_device_train_batch_size 16 \
34 | --per_device_eval_batch_size 4 \
35 | --gradient_accumulation_steps 1 \
36 | --evaluation_strategy "no" \
37 | --save_strategy "steps" \
38 | --save_steps 50000 \
39 | --save_total_limit 1 \
40 | --learning_rate 2e-4 \
41 | --weight_decay 0. \
42 | --warmup_ratio 0.03 \
43 | --lr_scheduler_type "cosine" \
44 | --logging_steps 1 \
45 | --tf32 True \
46 | --model_max_length 3072 \
47 | --gradient_checkpointing True \
48 | --dataloader_num_workers 4 \
49 | --lazy_preprocess True \
50 | --report_to wandb
51 |
52 |
--------------------------------------------------------------------------------
/scripts/v1_5/finetune_task_lora_garmentcode_outfit.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | export LD_LIBRARY_PATH=/is/software/nvidia/cuda-12.1/lib64
4 | export PATH=$PATH:/is/software/nvidia/cuda-12.1/bin
5 | export CUDA_HOME=/is/software/nvidia/cuda-12.1
6 |
7 | export CPATH=/is/software/nvidia/cudnn-8.4.1-cu11.6/include
8 | export C_INCLUDE_PATH=/is/software/nvidia/cudnn-8.4.1-cu11.6/include
9 | export LIBRARY_PATH=/is/software/nvidia/cudnn-8.4.1-cu11.6/lib64
10 | export LD_LIBRARY_PATH=$LIBRARY_PATH:$LD_LIBRARY_PATH
11 |
12 | export EGL_DEVICE_ID=$GPU_DEVICE_ORDINAL
13 | # export TCNN_CUDA_ARCHITECTURES=80
14 |
15 | deepspeed llava/train/train_mem_garmentcode_outfit.py \
16 | --lora_enable True --lora_r 128 --lora_alpha 256 --mm_projector_lr 2e-5 \
17 | --deepspeed ./scripts/zero2.json \
18 | --model_name_or_path liuhaotian/llava-v1.5-7b \
19 | --version v1 \
20 | --data_path ./ \
21 | --data_path_eval ./ \
22 | --image_folder ./ \
23 | --vision_tower openai/clip-vit-large-patch14-336 \
24 | --mm_projector_type mlp2x_gelu \
25 | --mm_vision_select_layer -2 \
26 | --mm_use_im_start_end False \
27 | --mm_use_im_patch_token False \
28 | --image_aspect_ratio pad \
29 | --group_by_modality_length True \
30 | --bf16 True \
31 | --output_dir ./checkpoints/llava-v1.5-7b-task-lora \
32 | --num_train_epochs 1 \
33 | --per_device_train_batch_size 16 \
34 | --per_device_eval_batch_size 4 \
35 | --gradient_accumulation_steps 1 \
36 | --evaluation_strategy "no" \
37 | --save_strategy "steps" \
38 | --save_steps 50000 \
39 | --save_total_limit 1 \
40 | --learning_rate 2e-4 \
41 | --weight_decay 0. \
42 | --warmup_ratio 0.03 \
43 | --lr_scheduler_type "cosine" \
44 | --logging_steps 1 \
45 | --tf32 True \
46 | --model_max_length 3072 \
47 | --gradient_checkpointing True \
48 | --dataloader_num_workers 4 \
49 | --lazy_preprocess True \
50 | --report_to wandb
51 |
52 |
--------------------------------------------------------------------------------
/scripts/zero2.json:
--------------------------------------------------------------------------------
1 | {
2 | "fp16": {
3 | "enabled": "auto",
4 | "loss_scale": 0,
5 | "loss_scale_window": 1000,
6 | "initial_scale_power": 16,
7 | "hysteresis": 2,
8 | "min_loss_scale": 1
9 | },
10 | "bf16": {
11 | "enabled": "auto"
12 | },
13 | "train_micro_batch_size_per_gpu": "auto",
14 | "train_batch_size": "auto",
15 | "gradient_accumulation_steps": "auto",
16 | "zero_optimization": {
17 | "stage": 2,
18 | "overlap_comm": true,
19 | "contiguous_gradients": true,
20 | "sub_group_size": 1e9,
21 | "reduce_bucket_size": "auto"
22 | }
23 | }
--------------------------------------------------------------------------------
/scripts/zero3.json:
--------------------------------------------------------------------------------
1 | {
2 | "fp16": {
3 | "enabled": "auto",
4 | "loss_scale": 0,
5 | "loss_scale_window": 1000,
6 | "initial_scale_power": 16,
7 | "hysteresis": 2,
8 | "min_loss_scale": 1
9 | },
10 | "bf16": {
11 | "enabled": "auto"
12 | },
13 | "train_micro_batch_size_per_gpu": "auto",
14 | "train_batch_size": "auto",
15 | "gradient_accumulation_steps": "auto",
16 | "zero_optimization": {
17 | "stage": 3,
18 | "overlap_comm": true,
19 | "contiguous_gradients": true,
20 | "sub_group_size": 1e9,
21 | "reduce_bucket_size": "auto",
22 | "stage3_prefetch_bucket_size": "auto",
23 | "stage3_param_persistence_threshold": "auto",
24 | "stage3_max_live_parameters": 1e9,
25 | "stage3_max_reuse_distance": 1e9,
26 | "stage3_gather_16bit_weights_on_model_save": true
27 | }
28 | }
--------------------------------------------------------------------------------
/scripts/zero3_offload.json:
--------------------------------------------------------------------------------
1 | {
2 | "fp16": {
3 | "enabled": "auto",
4 | "loss_scale": 0,
5 | "loss_scale_window": 1000,
6 | "initial_scale_power": 16,
7 | "hysteresis": 2,
8 | "min_loss_scale": 1
9 | },
10 | "bf16": {
11 | "enabled": "auto"
12 | },
13 | "optimizer": {
14 | "type": "AdamW",
15 | "params": {
16 | "lr": "auto",
17 | "betas": "auto",
18 | "eps": "auto",
19 | "weight_decay": "auto"
20 | }
21 | },
22 | "scheduler": {
23 | "type": "WarmupLR",
24 | "params": {
25 | "warmup_min_lr": "auto",
26 | "warmup_max_lr": "auto",
27 | "warmup_num_steps": "auto"
28 | }
29 | },
30 | "zero_optimization": {
31 | "stage": 3,
32 | "offload_optimizer": {
33 | "device": "cpu",
34 | "pin_memory": true
35 | },
36 | "offload_param": {
37 | "device": "cpu",
38 | "pin_memory": true
39 | },
40 | "overlap_comm": true,
41 | "contiguous_gradients": true,
42 | "sub_group_size": 1e9,
43 | "reduce_bucket_size": "auto",
44 | "stage3_prefetch_bucket_size": "auto",
45 | "stage3_param_persistence_threshold": "auto",
46 | "stage3_max_live_parameters": 1e9,
47 | "stage3_max_reuse_distance": 1e9,
48 | "gather_16bit_weights_on_model_save": true
49 | },
50 | "gradient_accumulation_steps": "auto",
51 | "gradient_clipping": "auto",
52 | "train_batch_size": "auto",
53 | "train_micro_batch_size_per_gpu": "auto",
54 | "steps_per_print": 1e5,
55 | "wall_clock_breakdown": false
56 | }
--------------------------------------------------------------------------------