├── .dockerignore ├── .editorconfig ├── .gitattributes ├── .github └── ISSUE_TEMPLATE │ ├── 1-usage.yaml │ ├── 2-feature-request.yaml │ ├── 3-question.yaml │ └── 4-discussion.yaml ├── .gitignore ├── LICENSE ├── README.md ├── cog.yaml ├── docs ├── Installation.md ├── all_float_paths.json ├── images │ ├── black_img.jpg │ ├── img_recon.gif │ ├── teaser.png │ ├── text_generation.png │ ├── video_edit.gif │ └── video_edit_2.gif ├── postprocess.md └── prompts │ ├── detailed_textbased_description.txt │ ├── gpt4v_prompt_garment_sam.txt │ ├── prompt_garment_editing.txt │ ├── prompt_garment_part_inference.txt │ └── smplified_image_description.txt ├── example_data ├── example_imgs │ ├── 1aee14a8c7b4d56b4e8b6ddd575d1f561a72fdc75c43a4b6926f1655152193c6.png │ ├── 1dde6afed43187fe927089a615e3f744724ef3defdf3f2ae4a6cede5ad71dcea.png │ ├── 62bb809fc2dcd50409cb36163a0eb222f9aa1af0f256a3233b67b3ed4081dc71.png │ ├── 6fe14e1f646513ee93714fbe8026a84c6a2897be4df2f3c936cb2be8dd2d1762.png │ ├── 72b086429d2dfe2a8de6f4403a024b2bb17446021c9e8f9ebacfc7a990ac8434.png │ ├── 80141ce740f489f1d2f57a03f32c7577a28b62a6ac790a0d9ed8a18d961c2918.png │ ├── 8e3c458da20c290c216813ec07f1a2e8f9cfb4ee7e412a783a238ec353b346a0.png │ ├── c2b582eb318455abaf8ed8e3126c1b423ade2704d810f7cd24428febda5632fa.png │ ├── d77c6f5d4856831878eadb7fe3c8b180bfa9e9ad4a14936ac10a1697bb3c054f.png │ └── e918651cc154a7570e47d8b8f6c0f0f93cfbb7d5129103a1bacd8299ba945f91.png ├── example_jsons │ ├── example_edit_prompts.json │ └── example_textgen_prompts.json └── example_sewing_patterns │ └── example_shirt │ ├── design.yaml │ └── valid_garment_upper_render_front.png ├── llava ├── __init__.py ├── close_utils.py ├── constants.py ├── conversation.py ├── eval │ ├── eval_gpt_review.py │ ├── eval_gpt_review_bench.py │ ├── eval_gpt_review_visual.py │ ├── eval_pope.py │ ├── eval_science_qa.py │ ├── eval_science_qa_gpt4.py │ ├── eval_science_qa_gpt4_requery.py │ ├── eval_textvqa.py │ ├── generate_webpage_data_from_table.py │ ├── m4c_evaluator.py │ ├── model_qa.py │ ├── model_vqa.py │ ├── model_vqa_loader.py │ ├── model_vqa_mmbench.py │ ├── model_vqa_science.py │ ├── qa_baseline_gpt35.py │ ├── run_llava.py │ ├── summarize_gpt_review.py │ ├── table │ │ ├── answer │ │ │ ├── answer_alpaca-13b.jsonl │ │ │ ├── answer_bard.jsonl │ │ │ ├── answer_gpt35.jsonl │ │ │ ├── answer_llama-13b.jsonl │ │ │ └── answer_vicuna-13b.jsonl │ │ ├── caps_boxes_coco2014_val_80.jsonl │ │ ├── model.jsonl │ │ ├── prompt.jsonl │ │ ├── question.jsonl │ │ ├── results │ │ │ ├── test_sqa_llava_13b_v0.json │ │ │ └── test_sqa_llava_lcs_558k_sqa_12e_vicuna_v1_3_13b.json │ │ ├── review │ │ │ ├── review_alpaca-13b_vicuna-13b.jsonl │ │ │ ├── review_bard_vicuna-13b.jsonl │ │ │ ├── review_gpt35_vicuna-13b.jsonl │ │ │ └── review_llama-13b_vicuna-13b.jsonl │ │ ├── reviewer.jsonl │ │ └── rule.json │ └── webpage │ │ ├── figures │ │ ├── alpaca.png │ │ ├── bard.jpg │ │ ├── chatgpt.svg │ │ ├── llama.jpg │ │ ├── swords_FILL0_wght300_GRAD0_opsz48.svg │ │ └── vicuna.jpeg │ │ ├── index.html │ │ ├── script.js │ │ └── styles.css ├── garment_inquire_utils.py ├── garment_lbs_utils.py ├── garment_utils_v2.py ├── garmentcodeRC_utils.py ├── garmentcode_utils.py ├── json_fixer.py ├── lisa_utils.py ├── mm_utils.py ├── model │ ├── __init__.py │ ├── apply_delta.py │ ├── builder.py │ ├── consolidate.py │ ├── language_model │ │ ├── llava_garment_float50.py │ │ ├── llava_llama.py │ │ ├── llava_mistral.py │ │ └── llava_mpt.py │ ├── llava_arch.py │ ├── make_delta.py │ ├── multimodal_encoder │ │ ├── builder.py │ │ └── clip_encoder.py │ ├── 
multimodal_projector │ │ └── builder.py │ ├── smplx │ │ ├── body_models.py │ │ ├── joint_names.py │ │ ├── lbs.py │ │ ├── smplx_utils.py │ │ ├── utils.py │ │ ├── vertex_ids.py │ │ └── vertex_joint_selector.py │ └── utils.py ├── prompts_utils.py ├── pytorch3d_render_utils.py ├── serve │ ├── __init__.py │ ├── cli.py │ ├── controller.py │ ├── examples │ │ ├── extreme_ironing.jpg │ │ └── waterview.jpg │ ├── gradio_web_server.py │ ├── model_worker.py │ ├── register_worker.py │ ├── sglang_worker.py │ └── test_message.py ├── train │ ├── train_garmentcode_outfit.py │ └── train_mem_garmentcode_outfit.py └── utils.py ├── pyproject.toml ├── run_garmentcode_sim.py └── scripts ├── evaluate_garment_v2_demo_edit_1float.py ├── evaluate_garment_v2_eva_edit_1float.py ├── evaluate_garment_v2_imggen_1float.py ├── evaluate_garment_v2_textgen_1float.py ├── evaluate_garment_v2_textgen_fromimg_1float.py ├── postprocess ├── grounding_sam.py └── postprocess.py ├── v1_5 ├── evaluate_garment_v2_demo_edit.sh ├── evaluate_garment_v2_eva_edit.sh ├── evaluate_garment_v2_imggen_2step.sh ├── evaluate_garment_v2_textgen.sh ├── evaluate_garment_v2_textgen_fromimg.sh └── finetune_task_lora_garmentcode_outfit.sh ├── zero2.json ├── zero3.json └── zero3_offload.json /.dockerignore: -------------------------------------------------------------------------------- 1 | # The .dockerignore file excludes files from the container build process. 2 | # 3 | # https://docs.docker.com/engine/reference/builder/#dockerignore-file 4 | 5 | # Exclude Git files 6 | .git 7 | .github 8 | .gitignore 9 | 10 | # Exclude Python cache files 11 | __pycache__ 12 | .mypy_cache 13 | .pytest_cache 14 | .ruff_cache 15 | 16 | # Exclude Python virtual environment 17 | /venv 18 | 19 | # Exclude some weights 20 | /openai 21 | /liuhaotian 22 | -------------------------------------------------------------------------------- /.editorconfig: -------------------------------------------------------------------------------- 1 | root = true 2 | 3 | # Unix-style newlines with a newline ending every file 4 | [*] 5 | end_of_line = lf 6 | insert_final_newline = true 7 | trim_trailing_whitespace = true 8 | charset = utf-8 9 | 10 | # 4 space indentation 11 | [*.{py,json}] 12 | indent_style = space 13 | indent_size = 4 14 | 15 | # 2 space indentation 16 | [*.{md,sh,yaml,yml}] 17 | indent_style = space 18 | indent_size = 2 -------------------------------------------------------------------------------- /.gitattributes: -------------------------------------------------------------------------------- 1 | # https://git-scm.com/docs/gitattributes 2 | 3 | # Set the default behavior, in case people don't have core.autocrlf set. 
4 | # https://git-scm.com/docs/gitattributes#_end_of_line_conversion 5 | * text=auto 6 | 7 | # common python attributes, taken from https://github.com/alexkaratarakis/gitattributes/blob/710900479a2bedeec7003d381719521ffbb18bf8/Python.gitattributes 8 | # Source files 9 | # ============ 10 | *.pxd text diff=python 11 | *.py text diff=python 12 | *.py3 text diff=python 13 | *.pyw text diff=python 14 | *.pyx text diff=python 15 | *.pyz text diff=python 16 | *.pyi text diff=python 17 | 18 | # Binary files 19 | # ============ 20 | *.db binary 21 | *.p binary 22 | *.pkl binary 23 | *.pickle binary 24 | *.pyc binary export-ignore 25 | *.pyo binary export-ignore 26 | *.pyd binary 27 | 28 | # Jupyter notebook 29 | *.ipynb text eol=lf 30 | -------------------------------------------------------------------------------- /.github/ISSUE_TEMPLATE/1-usage.yaml: -------------------------------------------------------------------------------- 1 | name: Usage issues 2 | description: Report issues in usage. 3 | title: "[Usage] " 4 | body: 5 | - type: markdown 6 | attributes: 7 | value: | 8 | Thanks for taking the time to fill out this form. Please give as detailed description as possible for us to better assist with the issue :) 9 | - type: textarea 10 | id: what-happened 11 | attributes: 12 | label: Describe the issue 13 | description: Please give as detailed description as possible for us to better assist with the issue. Please paste the **FULL** error log here, so that we can better understand the issue. Wrap the log with ``` for better readability in GitHub. 14 | placeholder: Issue 15 | value: | 16 | Issue: 17 | 18 | Command: 19 | ``` 20 | PASTE THE COMMANDS HERE. 21 | ``` 22 | 23 | Log: 24 | ``` 25 | PASTE THE LOGS HERE. 26 | ``` 27 | 28 | Screenshots: 29 | You may attach screenshots if it better explains the issue. 30 | validations: 31 | required: true 32 | -------------------------------------------------------------------------------- /.github/ISSUE_TEMPLATE/2-feature-request.yaml: -------------------------------------------------------------------------------- 1 | name: Feature Request 2 | description: Request for a new feature 3 | title: "[Feature request] " 4 | body: 5 | - type: markdown 6 | attributes: 7 | value: | 8 | Thanks for your interest in our work. Please share your thoughts of the new features below. 9 | - type: textarea 10 | id: feature 11 | attributes: 12 | label: feature 13 | placeholder: Start your thoughts here... -------------------------------------------------------------------------------- /.github/ISSUE_TEMPLATE/3-question.yaml: -------------------------------------------------------------------------------- 1 | name: Questions 2 | description: General questions about the work 3 | title: "[Question] " 4 | body: 5 | - type: markdown 6 | attributes: 7 | value: | 8 | Thanks for your interest in our work. For this type of question, it may be more suitable to go to [discussion](https://github.com/haotian-liu/LLaVA/discussions) sections. If you believe an issue would be better for your request, please continue your post below :) 9 | - type: textarea 10 | id: question 11 | attributes: 12 | label: Question 13 | placeholder: Start question here... 
-------------------------------------------------------------------------------- /.github/ISSUE_TEMPLATE/4-discussion.yaml: -------------------------------------------------------------------------------- 1 | name: Discussions 2 | description: General discussions about the work 3 | title: "[Discussion] " 4 | body: 5 | - type: markdown 6 | attributes: 7 | value: | 8 | Thanks for your interest in our work. For this type of question, it may be more suitable to go to [discussion](https://github.com/haotian-liu/LLaVA/discussions) sections. If you believe an issue would be better for your request, please continue your post below :) 9 | - type: textarea 10 | id: discussion 11 | attributes: 12 | label: Discussion 13 | placeholder: Start discussion here... -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Python 2 | __pycache__ 3 | *.pyc 4 | *.egg-info 5 | dist 6 | 7 | # Log 8 | *.log 9 | *.log.* 10 | *.jsonl 11 | 12 | # Data 13 | !**/alpaca-data-conversation.json 14 | 15 | # Editor 16 | .idea 17 | *.swp 18 | 19 | # Other 20 | .DS_Store 21 | wandb 22 | output 23 | 24 | checkpoints 25 | ckpts* 26 | 27 | .ipynb_checkpoints 28 | *.ipynb 29 | 30 | # DevContainer 31 | !.devcontainer/* 32 | 33 | # Demo 34 | serve_images/ 35 | 36 | runs 37 | assets 38 | playground -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | 2 | 3 |

4 | # ChatGarment: Garment Estimation, Generation and Editing via Large Language Models
5 | 
6 | 
7 | ![teaser](docs/images/teaser.png)
8 | 
9 | 
10 | 11 | 12 | This is the implementation of ChatGarment. For more details, please check our 13 | [[Project Page](https://chatgarment.github.io/)]. 14 | 15 | ChatGarment utilizes large vision-language models (VLMs) to automate the estimation, generation, and editing of 3D garments from images or text descriptions. 16 | 17 | 18 | ## Applications 19 | 20 | | ![](docs/images/img_recon.gif) | ![](docs/images/text_generation.png) | 21 | | :--------------------: | :----------: | 22 | | Image-based Reconstruction | Text-based Generation | 23 | | ![](docs/images/video_edit.gif) | ![](docs/images/video_edit_2.gif) | 24 | | Text-based Editing | Text-based Editing | 25 | 26 | 27 | ## Relevant Repositories 28 | 1. [**GarmentCodeRC**](https://github.com/biansy000/GarmentCodeRC): A refined version of the original [GarmentCode](https://github.com/maria-korosteleva/GarmentCode), used by ChatGarment for garment generation. 29 | 30 | 2. [**ContourCraft-CG**](https://github.com/biansy000/ContourCraft-CG): A refined version of the original [ContourCraft](https://github.com/Dolorousrtur/ContourCraft), used by ChatGarment for garment simulation. 31 | 32 | 3. [**ChatGarmentDataset**](https://huggingface.co/datasets/sy000/ChatGarmentDataset): A Hugging Face dataset with training and inference data used in our paper. 33 | 34 | 35 | ## Installation 36 | The installation instructions are provided in ``docs/Installation.md``. 37 | 38 | ## Model Training 39 | The training data is available in [ChatGarmentDataset](https://huggingface.co/datasets/sy000/ChatGarmentDataset). 40 | ```Shell 41 | ./scripts/v1_5/finetune_task_lora_garmentcode_outfit.sh 42 | ``` 43 | 44 | ## Model Inference 45 | 46 | #### 1. Image-based Reconstruction (CoT) 47 | ```Shell 48 | # Run image-based reconstruction with CoT for images in example_data/example_imgs/ 49 | # Detailed steps of the script: 50 | # 1. Accepts an input image. 51 | # 2. Utilizes the ChatGarment model to generate text prompts based on the image. 52 | # 3. Sends the ChatGarment-generated text & input image to the ChatGarment model again. 53 | # 4. Outputs the final GarmentCode sewing patterns. 54 | ./scripts/v1_5/evaluate_garment_v2_imggen_2step.sh example_data/example_imgs/ 55 | ``` 56 | 57 | 58 | #### 2. Text-based Generation 59 | ```Shell 60 | # Run text-based generation for prompts given in the input JSON file 61 | # Detailed steps of the script: 62 | # 1. Accepts an input JSON file. 63 | # 2. Utilizes GPT-4o to generate well-formed text descriptions based on the original prompts. 64 | # 3. Sends the GPT-generated text to the ChatGarment model. 65 | # 4. Outputs the final GarmentCode sewing patterns. 66 | ./scripts/v1_5/evaluate_garment_v2_textgen.sh example_data/example_jsons/example_textgen_prompts.json 67 | ``` 68 | 69 | 70 | #### 3. Garment Editing 71 | ```Shell 72 | # Run text-based garment editing for prompts given in the input JSON file 73 | # Detailed steps of the script: 74 | # 1. Accepts an input JSON file. 75 | # 2. Utilizes GPT-4o to generate well-formed editing prompts based on the original prompts. 76 | # 3. Sends the GPT-generated text to the ChatGarment model. 77 | # 4. Outputs the final GarmentCode sewing patterns. 78 | ./scripts/v1_5/evaluate_garment_v2_demo_edit.sh example_data/example_jsons/example_edit_prompts.json 79 | ``` 80 | 81 | #### 4. Multi-turn Conversations 82 | (Coming Soon) 83 | 84 | 85 | ## After Inference 86 | 87 | #### 1. Generate 3D Garments Based on ChatGarment Output 88 | After inference, ChatGarment outputs 2D sewing patterns and JSON configurations in the specified ``$(OUTPUT_DIR)``.
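Before stitching, it can be useful to sanity-check what was written to ``$(OUTPUT_DIR)``. The snippet below is only a minimal sketch: it assumes each generated garment comes with its own ``design.yaml`` (as in the example sewing pattern under ``example_data/example_sewing_patterns/``); the actual folder layout may differ depending on the evaluation script.

```python
import sys
from pathlib import Path

# Minimal sketch: list the generated garment configs before running the stitching step.
# Assumes each predicted garment has a `design.yaml`, mirroring the example sewing
# pattern in example_data/; adjust the glob pattern if your $(OUTPUT_DIR) differs.
output_dir = Path(sys.argv[1] if len(sys.argv) > 1 else "runs/my_output_dir")
configs = sorted(output_dir.rglob("design.yaml"))
print(f"Found {len(configs)} garment config(s) under {output_dir}")
for cfg in configs:
    print(" -", cfg.relative_to(output_dir))
```
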
The 2D patterns can then be stitched together to generate the corresponding 3D garments using the following code: 89 | 90 | ```Shell 91 | # Run garment stitching to get draped 3D garments 92 | # For example, $(OUTPUT_DIR) = runs/try_7b_lr1e_4_v3_garmentcontrol_4h100_v4_final_textgen_exampleimg 93 | python run_garmentcode_sim.py --all_paths_json $(OUTPUT_DIR) 94 | ``` 95 | 96 | #### 2. (Optional) Postprocessing for More Accurate Sizes 97 | ChatGarment may occasionally produce garments with incorrect lengths or widths from input images. To alleviate this, we provide a postprocessing method that refines garment sizes. Detailed instructions are available in ``docs/postprocess.md``. 98 | 99 | 100 | 101 | ## Citation 102 | ```bibtex 103 | @article{bian2024chatgarment, 104 | title={ChatGarment: Garment Estimation, Generation and Editing via Large Language Models}, 105 | author={Bian, Siyuan and Xu, Chenghao and Xiu, Yuliang and Grigorev, Artur and Liu, Zhen and Lu, Cewu and Black, Michael J and Feng, Yao}, 106 | journal={arXiv preprint arXiv:2412.17811}, 107 | year={2024} 108 | } 109 | ``` 110 | 111 | ## Acknowledgments 112 | This repository is built extensively on top of [LLaVA](https://github.com/haotian-liu/LLaVA) and [LISA](https://github.com/dvlab-research/LISA). 113 | -------------------------------------------------------------------------------- /cog.yaml: -------------------------------------------------------------------------------- 1 | # Configuration for Cog ⚙️ 2 | # Reference: https://github.com/replicate/cog/blob/main/docs/yaml.md 3 | 4 | build: 5 | gpu: true 6 | 7 | python_version: "3.11" 8 | 9 | python_packages: 10 | - "torch==2.0.1" 11 | - "accelerate==0.21.0" 12 | - "bitsandbytes==0.41.0" 13 | - "deepspeed==0.9.5" 14 | - "einops-exts==0.0.4" 15 | - "einops==0.6.1" 16 | - "gradio==3.35.2" 17 | - "gradio_client==0.2.9" 18 | - "httpx==0.24.0" 19 | - "markdown2==2.4.10" 20 | - "numpy==1.26.0" 21 | - "peft==0.4.0" 22 | - "scikit-learn==1.2.2" 23 | - "sentencepiece==0.1.99" 24 | - "shortuuid==1.0.11" 25 | - "timm==0.6.13" 26 | - "tokenizers==0.13.3" 27 | - "torch==2.0.1" 28 | - "torchvision==0.15.2" 29 | - "transformers==4.31.0" 30 | - "wandb==0.15.12" 31 | - "wavedrom==2.0.3.post3" 32 | - "Pygments==2.16.1" 33 | run: 34 | - curl -o /usr/local/bin/pget -L "https://github.com/replicate/pget/releases/download/v0.0.3/pget" && chmod +x /usr/local/bin/pget 35 | 36 | # predict.py defines how predictions are run on your model 37 | predict: "predict.py:Predictor" 38 | -------------------------------------------------------------------------------- /docs/Installation.md: -------------------------------------------------------------------------------- 1 | ## Installation 2 | 3 | #### 1. Clone this repository 4 | ```bash 5 | git clone git@github.com:biansy000/ChatGarment.git 6 | cd ChatGarment 7 | ``` 8 | 9 | #### 2. Install Dependencies 10 | If you are not using Linux, see instructions for [macOS](https://github.com/haotian-liu/LLaVA/blob/main/docs/macOS.md) and [Windows](https://github.com/haotian-liu/LLaVA/blob/main/docs/Windows.md). 11 | 12 | ```Shell 13 | conda create -n chatgarment python=3.10 -y 14 | conda activate chatgarment 15 | pip install --upgrade pip # enable PEP 660 support 16 | pip install -e ".[train]" 17 | pip install flash-attn --no-build-isolation 18 | ``` 19 | 20 | #### 3. Install [GarmentCodeRC](https://github.com/biansy000/GarmentCodeRC) 21 | Follow installation instructions in its repository. 22 | 23 | 24 | #### 4. 
Download Pretrained Weights 25 | Put the [Pretrained weights](https://sjtueducn-my.sharepoint.com/:u:/g/personal/biansiyuan_sjtu_edu_cn/EQayoB8ie7ZIsFrjLWdBASQBFexZHXcGjrS6ghgGCjIMzw?e=o60Y65) to ``checkpoints/try_7b_lr1e_4_v3_garmentcontrol_4h100_v4_final/pytorch_model.bin``. 26 | 27 | #### 5. Update Paths in Code 28 | Modify the following lines in relevant Python files: 29 | ```Python 30 | sys.path.insert(1, '/is/cluster/fast/sbian/github/chatgarment_private') # path of the current ChatGarment repo 31 | sys.path.insert(1, '/is/cluster/fast/sbian/github/GarmentCodeV2/') # path of GarmentCodeRC repo 32 | ``` 33 | Replace with their actual local paths. 34 | 35 | #### 6. Add Soft Link 36 | Add the softlink of ``assets`` folder in ``GarmentCodeRC`` repo: 37 | ```Shell 38 | ln -s path_to_garmentcode_assets assets 39 | ``` 40 | -------------------------------------------------------------------------------- /docs/all_float_paths.json: -------------------------------------------------------------------------------- 1 | ["design.waistband.waist", "design.waistband.width", "design.shirt.length", "design.shirt.width", "design.shirt.flare", "design.collar.width", "design.collar.fc_depth", "design.collar.bc_depth", "design.collar.f_bezier_x", "design.collar.f_bezier_y", "design.collar.b_bezier_x", "design.collar.b_bezier_y", "design.collar.component.hood_depth", "design.collar.component.hood_length", "design.sleeve.length", "design.sleeve.connecting_width", "design.sleeve.end_width", "design.sleeve.opening_dir_mix", "design.sleeve.standing_shoulder_len", "design.sleeve.connect_ruffle", "design.sleeve.smoothing_coeff", "design.sleeve.cuff.top_ruffle", "design.sleeve.cuff.cuff_len", "design.sleeve.cuff.skirt_fraction", "design.sleeve.cuff.skirt_flare", "design.sleeve.cuff.skirt_ruffle", "design.left.shirt.width", "design.left.shirt.flare", "design.left.collar.width", "design.left.collar.f_bezier_x", "design.left.collar.f_bezier_y", "design.left.collar.b_bezier_x", "design.left.collar.b_bezier_y", "design.left.sleeve.length", "design.left.sleeve.connecting_width", "design.left.sleeve.end_width", "design.left.sleeve.opening_dir_mix", "design.left.sleeve.standing_shoulder_len", "design.left.sleeve.connect_ruffle", "design.left.sleeve.smoothing_coeff", "design.left.sleeve.cuff.top_ruffle", "design.left.sleeve.cuff.cuff_len", "design.left.sleeve.cuff.skirt_fraction", "design.left.sleeve.cuff.skirt_flare", "design.left.sleeve.cuff.skirt_ruffle", "design.skirt.length", "design.skirt.rise", "design.skirt.ruffle", "design.skirt.bottom_cut", "design.flare-skirt.length", "design.flare-skirt.rise", "design.flare-skirt.suns", "design.flare-skirt.asymm.front_length", "design.flare-skirt.cut.depth", "design.flare-skirt.cut.width", "design.flare-skirt.cut.place", "design.pencil-skirt.length", "design.pencil-skirt.rise", "design.pencil-skirt.flare", "design.pencil-skirt.front_slit", "design.pencil-skirt.back_slit", "design.pencil-skirt.left_slit", "design.pencil-skirt.right_slit", "design.levels-skirt.level_ruffle", "design.levels-skirt.length", "design.levels-skirt.rise", "design.levels-skirt.base_length_frac", "design.pants.length", "design.pants.width", "design.pants.flare", "design.pants.rise", "design.pants.cuff.top_ruffle", "design.pants.cuff.cuff_len", "design.pants.cuff.skirt_fraction", "design.pants.cuff.skirt_flare", "design.pants.cuff.skirt_ruffle"] -------------------------------------------------------------------------------- /docs/images/black_img.jpg: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/biansy000/ChatGarment/6ea4192383067938f1d1c63dc0cc4134ccc91b31/docs/images/black_img.jpg -------------------------------------------------------------------------------- /docs/images/img_recon.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/biansy000/ChatGarment/6ea4192383067938f1d1c63dc0cc4134ccc91b31/docs/images/img_recon.gif -------------------------------------------------------------------------------- /docs/images/teaser.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/biansy000/ChatGarment/6ea4192383067938f1d1c63dc0cc4134ccc91b31/docs/images/teaser.png -------------------------------------------------------------------------------- /docs/images/text_generation.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/biansy000/ChatGarment/6ea4192383067938f1d1c63dc0cc4134ccc91b31/docs/images/text_generation.png -------------------------------------------------------------------------------- /docs/images/video_edit.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/biansy000/ChatGarment/6ea4192383067938f1d1c63dc0cc4134ccc91b31/docs/images/video_edit.gif -------------------------------------------------------------------------------- /docs/images/video_edit_2.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/biansy000/ChatGarment/6ea4192383067938f1d1c63dc0cc4134ccc91b31/docs/images/video_edit_2.gif -------------------------------------------------------------------------------- /docs/postprocess.md: -------------------------------------------------------------------------------- 1 | # Postprocessing after ChatGarment Inference 2 | 3 | ChatGarment may occasionally produce garments with incorrect lengths or widths from input images. To alleviate this, we provide a postprocessing method that refines garment sizes using a finite-difference-based approach. This process adjusts the garment length and width to better match the segmentation mask predicted by SAM (Segment Anything Model). 4 | 5 | Assume that the input images are placed in the folder ``example_data/example_imgs``. 6 | 7 | ### Step 1. Garment Segmentation with Grounding-SAM 8 | Install [GroundingDINO](https://github.com/IDEA-Research/GroundingDINO) and [segment-anything](https://github.com/facebookresearch/segment-anything) for segmentation. You can follow the installation instructions provided in [PuzzleAvatar](https://github.com/YuliangXiu/PuzzleAvatar/blob/main/scripts/install_dino_sam.sh) 9 | 10 | Run the segmentation script: 11 | ```bash 12 | python scripts/postprocess/grounding_sam.py --in_dir example_data/example_imgs --out_dir runs/example_eva_SAM 13 | ``` 14 | 15 | ### Step 2. Human Pose and Shape Estimation with TokenHMR 16 | Install [TokenHMR](https://github.com/saidwivedi/TokenHMR) for human pose estimation. 
Navigate to the TokenHMR directory: 17 | ```bash 18 | cd PATH_TO_TOKENHMR 19 | ``` 20 | 21 | Next, modify ``demo.py`` by inserting the following code after [this line](https://github.com/saidwivedi/TokenHMR/blob/198645f7784a27a4df0eac32478b1e7bc3e13574/tokenhmr/demo.py#L116): 22 | ```python 23 | out_saved = out.copy() 24 | out_saved['pred_cam_t_full'] = pred_cam_t_full[n] 25 | out_saved['scaled_focal_length'] = scaled_focal_length 26 | for k, v in out_saved['pred_smpl_params'].items(): 27 | if isinstance(v, torch.Tensor): 28 | out_saved['pred_smpl_params'][k] = v.detach().cpu().numpy() 29 | with open(os.path.join(args.out_folder, f'{img_fn}_{person_id}.pkl'), 'wb') as f: 30 | pickle.dump(out_saved, f) 31 | ``` 32 | 33 | Then, run TokenHMR with the following command: 34 | ```bash 35 | python tokenhmr/demo.py \ 36 | --img_folder {PATH_TO_CHATGARMENT}/runs/example_eva_SAM/imgs_upsampled \ 37 | --batch_size=1 \ 38 | --full_frame \ 39 | --checkpoint data/checkpoints/tokenhmr_model_latest.ckpt \ 40 | --model_config data/checkpoints/model_config.yaml \ 41 | --out_folder {PATH_TO_CHATGARMENT}/runs/example_eva_SAM/tokenhmr_output 42 | ``` 43 | 44 | ### Step 3. Install Extra Packages 45 | * Pytorch3D: Follow the official [installation guide](https://github.com/facebookresearch/pytorch3d/blob/main/INSTALL.md). 46 | * Chumpy: Install with pip: ``pip install chumpy``. 47 | Then, comment out the following line in ``chumpy/__init__.py``: 48 | 49 | ```python 50 | from numpy import bool, int, float, complex, object, unicode, str, nan, inf 51 | ``` 52 | 53 | 54 | ### Step 4. Run the Postprocessing Script 55 | Assume your ChatGarment inference results are in ``runs/try_7b_lr1e_4_v3_garmentcontrol_4h100_v4_final_eva/vis_new/``. Download the required [extra-data](https://drive.google.com/file/d/1QXezA3J6uXqWHGATmcw3jaYxRXY2Ctte/view?usp=sharing) and extract it to ``checkpoints/extra_data``. Now run the postprocessing script. For example, to process the image ``1aee14a8c7b4d56b4e8b6ddd575d1f561a72fdc75c43a4b6926f1655152193c6.png``, use: 56 | ```bash 57 | python scripts/postprocess/postprocess.py --imgname 1aee14a8c7b4d56b4e8b6ddd575d1f561a72fdc75c43a4b6926f1655152193c6 \ 58 | --img_dir runs/example_eva_SAM/imgs_upsampled \ 59 | --inp_pose_params_dir runs/example_eva_SAM/tokenhmr_output \ 60 | --garmentcode_dir runs/try_7b_lr1e_4_v3_garmentcontrol_4h100_v4_final/example_imgs_img_recon/vis_new/ \ 61 | --saved_dir runs/example_eva_SAM/postprocess \ 62 | --garment_seg_dir runs/example_eva_SAM/mask/ 63 | ``` 64 | -------------------------------------------------------------------------------- /docs/prompts/detailed_textbased_description.txt: -------------------------------------------------------------------------------- 1 | I will provide some text descriptions of a [TYPE]. Describe the garment based on these texts. 2 | 3 | You should generate a LIST of THREE strings. 4 | 5 | In the first string, describe the garment type (If THE SUBJECT HAS A NAME, INCLUDE ITS NAME FIRST!); 6 | 7 | Example phrases for the first string: "hood", "T-shirt", "jacket", "tuxedo", etc. 8 | 9 | 10 | In the second string, describe the structures of the garment (DO NOT INCLUDE ANY INFO ABOUT THE HUMAN MODEL AND THE COLOR OF THE GARMENT) in the format of a dict. You should include the most common structures of a [TYPE] even if they are not specified in the text descriptions. 11 | 12 | Select the keys from the following list: 13 | ['width', 'length', 'sleeves', 'pant legs', 'waist', 'dress', 'skirt hems', 'collar', 'hood', 'waist', ...
] 14 | 15 | In the value of the dict, please use several different short phrases in a list with the following tips: 16 | 17 | Describe the width of the garment: wide, normal, narrow, etc. 18 | Describe the length of the garment: long, normal, short, etc. 19 | Describe the length and width of the sleeves: long, normal, short, tight, loose sleeveless, etc. 20 | Describe the detailed struture of the sleeves. Example: "asymmetrical sleeves", "straight sleeves", "puff sleeves", "three-quater sleeves", "accordion sleeves", etc. 21 | Describe the length and width of the legs of trousers: long, normal, short, tight, loose legs, etc. 22 | Describe the detailed struture of the pant legs. Example: "asymmetrical legs", "straight legs", "flared legs", "cropped legs", "cuffed legs", etc. 23 | Describe the length and width of the dress: long dress, normal dress, short dress, tight dress, loose dress, etc. 24 | Describe the detailed struture of the skirt hems. Example: "straight hem", "A-line hem", "pleated hem", "pencil hem", "slit hem", etc. 25 | Describe the detailed struture of the neck or collar. Example: "crew neck", "V-neck", "turtle neck", "collarless", etc. 26 | Describe the detailed struture of the hood. Example: "normal hood", "cape hood", "cowl hood", etc. 27 | 28 | An example of the dict description for a T-shirt is: 29 | { 30 | 'width': ['wide'], 31 | 'length': ['normal'], 32 | 'sleeves': ['elbow-length sleeves', 'tight sleeves', 'accordion sleeves'], 33 | 'collar': ['crew neck'], 34 | 'hood': ['no hood'] 35 | } 36 | 37 | An example of the dict description for a skirt is: 38 | { 39 | 'width': ['wide'], 40 | 'length': ['knee-length'], 41 | 'waist': ['high waist'], 42 | 'skirt hems': ['pencil hem', 'pleated hem'] 43 | } 44 | 45 | In the third string, describe the extra detailed structures of the garment (DO NOT INCLUDE ANY INFO ABOUT THE HUMAN MODEL AND THE COLOR OR PATTERN OF THE GARMENT) that are missing in the second string using several different short phrases split by ','. Example phrases for the third string: "pleated skirt", "high-waist", "zipper closure", "frayed hem", "mid-rise waist", etc. If there is no extra structures, return an empty string. 46 | 47 | Please strictly avoid mentioning color, texture, and material. 48 | 49 | Return the results in the following format: [garment type, garment geometric features, extra features]. Only return the JSON List in the above format without adding explanations. 50 | 51 | The text description is: [DESCRIPTION] -------------------------------------------------------------------------------- /docs/prompts/gpt4v_prompt_garment_sam.txt: -------------------------------------------------------------------------------- 1 | Analyze the provided images, each featuring an individual. Identify and describe the individual's garments like shirts, outer coats, hats, pants, shoes, dresses, skirts, scarves, etc. Return the results in a dictionary format as follows: {"shirt": shirt description, "dress": dress description, "skirt": skirt description, "pants": pants description, "shoes": shoes description, "outer coat": outer coat description...}. The "description" should be one or two noun/adj words that describe the topological or geometric features, such as length (short/long), shape or style, without referencing color or texture pattern. Exclude accessories like belts, watch, badges, and etc. 
Remove the key if the garment does not appear, or the value string is empty (""), only keep the visible garments, do not describe colors, and ensure no garment is described within the description of another (e.g., {"pants": "long dress"}). All strings should be enclosed in double quotes. The response should only contain the dictionary, without additional sentences, explanations, or markdowns. -------------------------------------------------------------------------------- /docs/prompts/prompt_garment_editing.txt: -------------------------------------------------------------------------------- 1 | I will provide text prompts to edit some specific garment parts of the [TYPE]. Based on the prompt and the image of the original garment, generate a structured garment part description in a Python dict format. 2 | 3 | The possible editable parts are: ['waistband', 'shirt main body panel', 'collar and neckline', 'sleeves', 'sleeve_cuff', 'skirt', 'pants', 'pant_cuff', ...] 4 | 5 | Text Prompt: 6 | [DESCRIPTION] 7 | 8 | Output Format: 9 | Only return a JSON dict in the format: ``{part-name-1: [geometry feature 1, geometry feature 2, geometry feature 3, ...], part-name-2: [...], ...}``, where ``part-name-1`` and ``part-name-2`` are names of the edited garment parts, and ``[geometry feature 1, geometry feature 2, geometry feature 3, ...]`` are features of the garment part After editing. Please ONLY focus on the geometric feature. Strictly avoid mentioning color, texture, seams, and material. Exclude garment parts that remain unchanged. -------------------------------------------------------------------------------- /docs/prompts/prompt_garment_part_inference.txt: -------------------------------------------------------------------------------- 1 | I will provide an image of human models wearing the [GARMENT], and please focus on the [PART] on their [GARMENT]. 2 | 3 | Please describe the geometry features of all the [PART] on the [GARMENT]. Please only describe geometries and structures of [PART]. Strictly avoid mentioning [DONOT] or other garment parts. Do not describe features not shared by all garment, and strictly avoid mentioning color, texture, seams, and material. 4 | 5 | Return a Json LIST of several phrases, each describing a geometric feature of the [PART], in the Json list format: [geometry feature 1, geometry feature 2, geometry feature 3, ...]. -------------------------------------------------------------------------------- /docs/prompts/smplified_image_description.txt: -------------------------------------------------------------------------------- 1 | I will provide an image of a human model wearing several garments. Describe the outer layer garments the model is wearing. In the image, the model may wear one upper garment and one lower garment, or the model may wear a single wholebody garment. Avoid describing extra accessories such as the scarves, socks, watch, badges, and etc. 2 | 3 | For each garment, you should generate TWO strings. 4 | 5 | In the first string, describe the garment type (If THE SUBJECT HAS A NAME, INCLUDE ITS NAME FIRST!); 6 | 7 | Example phrases for the first string: "hood", "T-shirt", "jacket", "tuxedo", etc. 8 | 9 | 10 | In the second string, describe the overall global geometric features of the garment (DO NOT INCLUDE ANY INFO ABOUT THE HUMAN MODEL AND THE COLOR INFO OF THE GARMENT) using several different short phrases split by ',' with the following tips: 11 | 12 | Example rules: 13 | Describe the length of the sleeves: long, normal, short, sleeveless, etc. 
14 | Describe if it has a hood: with a hood, etc. 15 | Describe the length of the dress: long, normal, short, etc. 16 | Describe the width of the garment: wide, normal, narrow, etc. 17 | Describe the length of the garment: long, normal, short, etc. 18 | Describe the length of the legs of trousers: long, normal, short, etc. 19 | 20 | Please follow the example rules above (not limited to these examples) to describe the geometric features of the garment. 21 | 22 | Example phrases for the second string: "long sleeves", "wide garment", "with a hood", "deep collar", "sleeveless"... 23 | 24 | 25 | Please strictly avoid mentioning color, texture, and material. 26 | 27 | In the image, if the model is wearing one upper garment and one lower garment, return the results in the following format: {"upper garment": [upper garment type, upper garment geometric features], "lower garment": [lower garment type, lower garment geometric features]}. Otherwise, the model is wearing a single wholebody garment , return the results in the following format: {"wholebody garment": [wholebody garment type, wholebody garment geometric features]}. Return only the JSON dictionary in the above format with a length of 1 or 2. -------------------------------------------------------------------------------- /example_data/example_imgs/1aee14a8c7b4d56b4e8b6ddd575d1f561a72fdc75c43a4b6926f1655152193c6.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/biansy000/ChatGarment/6ea4192383067938f1d1c63dc0cc4134ccc91b31/example_data/example_imgs/1aee14a8c7b4d56b4e8b6ddd575d1f561a72fdc75c43a4b6926f1655152193c6.png -------------------------------------------------------------------------------- /example_data/example_imgs/1dde6afed43187fe927089a615e3f744724ef3defdf3f2ae4a6cede5ad71dcea.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/biansy000/ChatGarment/6ea4192383067938f1d1c63dc0cc4134ccc91b31/example_data/example_imgs/1dde6afed43187fe927089a615e3f744724ef3defdf3f2ae4a6cede5ad71dcea.png -------------------------------------------------------------------------------- /example_data/example_imgs/62bb809fc2dcd50409cb36163a0eb222f9aa1af0f256a3233b67b3ed4081dc71.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/biansy000/ChatGarment/6ea4192383067938f1d1c63dc0cc4134ccc91b31/example_data/example_imgs/62bb809fc2dcd50409cb36163a0eb222f9aa1af0f256a3233b67b3ed4081dc71.png -------------------------------------------------------------------------------- /example_data/example_imgs/6fe14e1f646513ee93714fbe8026a84c6a2897be4df2f3c936cb2be8dd2d1762.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/biansy000/ChatGarment/6ea4192383067938f1d1c63dc0cc4134ccc91b31/example_data/example_imgs/6fe14e1f646513ee93714fbe8026a84c6a2897be4df2f3c936cb2be8dd2d1762.png -------------------------------------------------------------------------------- /example_data/example_imgs/72b086429d2dfe2a8de6f4403a024b2bb17446021c9e8f9ebacfc7a990ac8434.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/biansy000/ChatGarment/6ea4192383067938f1d1c63dc0cc4134ccc91b31/example_data/example_imgs/72b086429d2dfe2a8de6f4403a024b2bb17446021c9e8f9ebacfc7a990ac8434.png -------------------------------------------------------------------------------- 
/example_data/example_imgs/80141ce740f489f1d2f57a03f32c7577a28b62a6ac790a0d9ed8a18d961c2918.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/biansy000/ChatGarment/6ea4192383067938f1d1c63dc0cc4134ccc91b31/example_data/example_imgs/80141ce740f489f1d2f57a03f32c7577a28b62a6ac790a0d9ed8a18d961c2918.png -------------------------------------------------------------------------------- /example_data/example_imgs/8e3c458da20c290c216813ec07f1a2e8f9cfb4ee7e412a783a238ec353b346a0.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/biansy000/ChatGarment/6ea4192383067938f1d1c63dc0cc4134ccc91b31/example_data/example_imgs/8e3c458da20c290c216813ec07f1a2e8f9cfb4ee7e412a783a238ec353b346a0.png -------------------------------------------------------------------------------- /example_data/example_imgs/c2b582eb318455abaf8ed8e3126c1b423ade2704d810f7cd24428febda5632fa.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/biansy000/ChatGarment/6ea4192383067938f1d1c63dc0cc4134ccc91b31/example_data/example_imgs/c2b582eb318455abaf8ed8e3126c1b423ade2704d810f7cd24428febda5632fa.png -------------------------------------------------------------------------------- /example_data/example_imgs/d77c6f5d4856831878eadb7fe3c8b180bfa9e9ad4a14936ac10a1697bb3c054f.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/biansy000/ChatGarment/6ea4192383067938f1d1c63dc0cc4134ccc91b31/example_data/example_imgs/d77c6f5d4856831878eadb7fe3c8b180bfa9e9ad4a14936ac10a1697bb3c054f.png -------------------------------------------------------------------------------- /example_data/example_imgs/e918651cc154a7570e47d8b8f6c0f0f93cfbb7d5129103a1bacd8299ba945f91.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/biansy000/ChatGarment/6ea4192383067938f1d1c63dc0cc4134ccc91b31/example_data/example_imgs/e918651cc154a7570e47d8b8f6c0f0f93cfbb7d5129103a1bacd8299ba945f91.png -------------------------------------------------------------------------------- /example_data/example_jsons/example_edit_prompts.json: -------------------------------------------------------------------------------- 1 | [ 2 | { 3 | "id": "001", 4 | "garmenttype": "upperbody garment", 5 | "image": "example_data/example_sewing_patterns/example_shirt/valid_garment_upper_render_front.png", 6 | "prompt": "Adjust the neckline to a classic crew neck and change to a sleeveless shirt", 7 | "json_path": "example_data/example_sewing_patterns/example_shirt/design.yaml" 8 | } 9 | ] -------------------------------------------------------------------------------- /example_data/example_jsons/example_textgen_prompts.json: -------------------------------------------------------------------------------- 1 | [ 2 | { 3 | "id": "001", 4 | "upperbody garment": { 5 | "name": "shirt", 6 | "text": "A crew-neck short-sleeve shirt" 7 | }, 8 | "lowerbody garment": { 9 | "name": "pants", 10 | "text": "A pair of long, loose pants" 11 | } 12 | }, 13 | { 14 | "id": "002", 15 | "upperbody garment": { 16 | "name": "shirt", 17 | "text": "A V-neck sleeveless shirt" 18 | }, 19 | "lowerbody garment": { 20 | "name": "pants", 21 | "text": "A pair of shorts" 22 | } 23 | }, 24 | { 25 | "id": "003", 26 | "wholebody garment": { 27 | "name": "dress", 28 | "text": "A long-sleeve dress" 29 | } 30 | } 31 
| ] -------------------------------------------------------------------------------- /example_data/example_sewing_patterns/example_shirt/valid_garment_upper_render_front.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/biansy000/ChatGarment/6ea4192383067938f1d1c63dc0cc4134ccc91b31/example_data/example_sewing_patterns/example_shirt/valid_garment_upper_render_front.png -------------------------------------------------------------------------------- /llava/__init__.py: -------------------------------------------------------------------------------- 1 | from .model import LlavaLlamaForCausalLM 2 | -------------------------------------------------------------------------------- /llava/constants.py: -------------------------------------------------------------------------------- 1 | CONTROLLER_HEART_BEAT_EXPIRATION = 30 2 | WORKER_HEART_BEAT_INTERVAL = 15 3 | 4 | LOGDIR = "." 5 | 6 | # Model Constants 7 | IGNORE_INDEX = -100 8 | IMAGE_TOKEN_INDEX = -200 9 | DEFAULT_IMAGE_TOKEN = "" 10 | DEFAULT_IMAGE_PATCH_TOKEN = "" 11 | DEFAULT_IM_START_TOKEN = "" 12 | DEFAULT_IM_END_TOKEN = "" 13 | IMAGE_PLACEHOLDER = "" 14 | -------------------------------------------------------------------------------- /llava/eval/eval_gpt_review.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import json 3 | import os 4 | 5 | import openai 6 | import tqdm 7 | import ray 8 | import time 9 | 10 | NUM_SECONDS_TO_SLEEP = 3 11 | 12 | @ray.remote(num_cpus=4) 13 | def get_eval(content: str, max_tokens: int): 14 | while True: 15 | try: 16 | response = openai.ChatCompletion.create( 17 | model='gpt-4', 18 | messages=[{ 19 | 'role': 'system', 20 | 'content': 'You are a helpful and precise assistant for checking the quality of the answer.' 
21 | }, { 22 | 'role': 'user', 23 | 'content': content, 24 | }], 25 | temperature=0.2, # TODO: figure out which temperature is best for evaluation 26 | max_tokens=max_tokens, 27 | ) 28 | break 29 | except openai.error.RateLimitError: 30 | pass 31 | except Exception as e: 32 | print(e) 33 | time.sleep(NUM_SECONDS_TO_SLEEP) 34 | 35 | print('success!') 36 | return response['choices'][0]['message']['content'] 37 | 38 | 39 | def parse_score(review): 40 | try: 41 | score_pair = review.split('\n')[0] 42 | score_pair = score_pair.replace(',', ' ') 43 | sp = score_pair.split(' ') 44 | if len(sp) == 2: 45 | return [float(sp[0]), float(sp[1])] 46 | else: 47 | print('error', review) 48 | return [-1, -1] 49 | except Exception as e: 50 | print(e) 51 | print('error', review) 52 | return [-1, -1] 53 | 54 | 55 | if __name__ == '__main__': 56 | parser = argparse.ArgumentParser(description='ChatGPT-based QA evaluation.') 57 | parser.add_argument('-q', '--question') 58 | # parser.add_argument('-a', '--answer') 59 | parser.add_argument('-a', '--answer-list', nargs='+', default=[]) 60 | parser.add_argument('-r', '--rule') 61 | parser.add_argument('-o', '--output') 62 | parser.add_argument('--max-tokens', type=int, default=1024, help='maximum number of tokens produced in the output') 63 | args = parser.parse_args() 64 | 65 | ray.init() 66 | 67 | f_q = open(os.path.expanduser(args.question)) 68 | f_ans1 = open(os.path.expanduser(args.answer_list[0])) 69 | f_ans2 = open(os.path.expanduser(args.answer_list[1])) 70 | rule_dict = json.load(open(os.path.expanduser(args.rule), 'r')) 71 | 72 | review_file = open(f'{args.output}', 'w') 73 | 74 | js_list = [] 75 | handles = [] 76 | idx = 0 77 | for ques_js, ans1_js, ans2_js in zip(f_q, f_ans1, f_ans2): 78 | # if idx == 1: 79 | # break 80 | 81 | ques = json.loads(ques_js) 82 | ans1 = json.loads(ans1_js) 83 | ans2 = json.loads(ans2_js) 84 | 85 | category = json.loads(ques_js)['category'] 86 | if category in rule_dict: 87 | rule = rule_dict[category] 88 | else: 89 | rule = rule_dict['default'] 90 | prompt = rule['prompt'] 91 | role = rule['role'] 92 | content = (f'[Question]\n{ques["text"]}\n\n' 93 | f'[{role} 1]\n{ans1["text"]}\n\n[End of {role} 1]\n\n' 94 | f'[{role} 2]\n{ans2["text"]}\n\n[End of {role} 2]\n\n' 95 | f'[System]\n{prompt}\n\n') 96 | js_list.append({ 97 | 'id': idx+1, 98 | 'question_id': ques['question_id'], 99 | 'answer1_id': ans1['answer_id'], 100 | 'answer2_id': ans2['answer_id'], 101 | 'category': category}) 102 | idx += 1 103 | handles.append(get_eval.remote(content, args.max_tokens)) 104 | # To avoid the rate limit set by OpenAI 105 | time.sleep(NUM_SECONDS_TO_SLEEP) 106 | 107 | reviews = ray.get(handles) 108 | for idx, review in enumerate(reviews): 109 | scores = parse_score(review) 110 | js_list[idx]['content'] = review 111 | js_list[idx]['tuple'] = scores 112 | review_file.write(json.dumps(js_list[idx]) + '\n') 113 | review_file.close() 114 | -------------------------------------------------------------------------------- /llava/eval/eval_gpt_review_bench.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import json 3 | import os 4 | 5 | import openai 6 | import time 7 | 8 | NUM_SECONDS_TO_SLEEP = 0.5 9 | 10 | 11 | def get_eval(content: str, max_tokens: int): 12 | while True: 13 | try: 14 | response = openai.ChatCompletion.create( 15 | model='gpt-4-0314', 16 | messages=[{ 17 | 'role': 'system', 18 | 'content': 'You are a helpful and precise assistant for checking the quality of the answer.' 
19 | }, { 20 | 'role': 'user', 21 | 'content': content, 22 | }], 23 | temperature=0.2, # TODO: figure out which temperature is best for evaluation 24 | max_tokens=max_tokens, 25 | ) 26 | break 27 | except openai.error.RateLimitError: 28 | pass 29 | except Exception as e: 30 | print(e) 31 | time.sleep(NUM_SECONDS_TO_SLEEP) 32 | 33 | return response['choices'][0]['message']['content'] 34 | 35 | 36 | def parse_score(review): 37 | try: 38 | score_pair = review.split('\n')[0] 39 | score_pair = score_pair.replace(',', ' ') 40 | sp = score_pair.split(' ') 41 | if len(sp) == 2: 42 | return [float(sp[0]), float(sp[1])] 43 | else: 44 | print('error', review) 45 | return [-1, -1] 46 | except Exception as e: 47 | print(e) 48 | print('error', review) 49 | return [-1, -1] 50 | 51 | 52 | if __name__ == '__main__': 53 | parser = argparse.ArgumentParser(description='ChatGPT-based QA evaluation.') 54 | parser.add_argument('-q', '--question') 55 | parser.add_argument('-c', '--context') 56 | parser.add_argument('-a', '--answer-list', nargs='+', default=[]) 57 | parser.add_argument('-r', '--rule') 58 | parser.add_argument('-o', '--output') 59 | parser.add_argument('--max-tokens', type=int, default=1024, help='maximum number of tokens produced in the output') 60 | args = parser.parse_args() 61 | 62 | f_q = open(os.path.expanduser(args.question)) 63 | f_ans1 = open(os.path.expanduser(args.answer_list[0])) 64 | f_ans2 = open(os.path.expanduser(args.answer_list[1])) 65 | rule_dict = json.load(open(os.path.expanduser(args.rule), 'r')) 66 | 67 | if os.path.isfile(os.path.expanduser(args.output)): 68 | cur_reviews = [json.loads(line) for line in open(os.path.expanduser(args.output))] 69 | else: 70 | cur_reviews = [] 71 | 72 | review_file = open(f'{args.output}', 'a') 73 | 74 | context_list = [json.loads(line) for line in open(os.path.expanduser(args.context))] 75 | image_to_context = {context['image']: context for context in context_list} 76 | 77 | handles = [] 78 | idx = 0 79 | for ques_js, ans1_js, ans2_js in zip(f_q, f_ans1, f_ans2): 80 | ques = json.loads(ques_js) 81 | ans1 = json.loads(ans1_js) 82 | ans2 = json.loads(ans2_js) 83 | 84 | inst = image_to_context[ques['image']] 85 | 86 | if isinstance(inst['caption'], list): 87 | cap_str = '\n'.join(inst['caption']) 88 | else: 89 | cap_str = inst['caption'] 90 | 91 | category = 'llava_bench_' + json.loads(ques_js)['category'] 92 | if category in rule_dict: 93 | rule = rule_dict[category] 94 | else: 95 | assert False, f"Visual QA category not found in rule file: {category}." 
96 | prompt = rule['prompt'] 97 | role = rule['role'] 98 | content = (f'[Context]\n{cap_str}\n\n' 99 | f'[Question]\n{ques["text"]}\n\n' 100 | f'[{role} 1]\n{ans1["text"]}\n\n[End of {role} 1]\n\n' 101 | f'[{role} 2]\n{ans2["text"]}\n\n[End of {role} 2]\n\n' 102 | f'[System]\n{prompt}\n\n') 103 | cur_js = { 104 | 'id': idx+1, 105 | 'question_id': ques['question_id'], 106 | 'answer1_id': ans1.get('answer_id', ans1['question_id']), 107 | 'answer2_id': ans2.get('answer_id', ans2['answer_id']), 108 | 'category': category 109 | } 110 | if idx >= len(cur_reviews): 111 | review = get_eval(content, args.max_tokens) 112 | scores = parse_score(review) 113 | cur_js['content'] = review 114 | cur_js['tuple'] = scores 115 | review_file.write(json.dumps(cur_js) + '\n') 116 | review_file.flush() 117 | else: 118 | print(f'Skipping {idx} as we already have it.') 119 | idx += 1 120 | print(idx) 121 | review_file.close() 122 | -------------------------------------------------------------------------------- /llava/eval/eval_gpt_review_visual.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import json 3 | import os 4 | 5 | import openai 6 | import time 7 | 8 | NUM_SECONDS_TO_SLEEP = 0.5 9 | 10 | 11 | def get_eval(content: str, max_tokens: int): 12 | while True: 13 | try: 14 | response = openai.ChatCompletion.create( 15 | model='gpt-4-0314', 16 | messages=[{ 17 | 'role': 'system', 18 | 'content': 'You are a helpful and precise assistant for checking the quality of the answer.' 19 | }, { 20 | 'role': 'user', 21 | 'content': content, 22 | }], 23 | temperature=0.2, # TODO: figure out which temperature is best for evaluation 24 | max_tokens=max_tokens, 25 | ) 26 | break 27 | except openai.error.RateLimitError: 28 | pass 29 | except Exception as e: 30 | print(e) 31 | time.sleep(NUM_SECONDS_TO_SLEEP) 32 | 33 | return response['choices'][0]['message']['content'] 34 | 35 | 36 | def parse_score(review): 37 | try: 38 | score_pair = review.split('\n')[0] 39 | score_pair = score_pair.replace(',', ' ') 40 | sp = score_pair.split(' ') 41 | if len(sp) == 2: 42 | return [float(sp[0]), float(sp[1])] 43 | else: 44 | print('error', review) 45 | return [-1, -1] 46 | except Exception as e: 47 | print(e) 48 | print('error', review) 49 | return [-1, -1] 50 | 51 | 52 | if __name__ == '__main__': 53 | parser = argparse.ArgumentParser(description='ChatGPT-based QA evaluation.') 54 | parser.add_argument('-q', '--question') 55 | parser.add_argument('-c', '--context') 56 | parser.add_argument('-a', '--answer-list', nargs='+', default=[]) 57 | parser.add_argument('-r', '--rule') 58 | parser.add_argument('-o', '--output') 59 | parser.add_argument('--max-tokens', type=int, default=1024, help='maximum number of tokens produced in the output') 60 | args = parser.parse_args() 61 | 62 | f_q = open(os.path.expanduser(args.question)) 63 | f_ans1 = open(os.path.expanduser(args.answer_list[0])) 64 | f_ans2 = open(os.path.expanduser(args.answer_list[1])) 65 | rule_dict = json.load(open(os.path.expanduser(args.rule), 'r')) 66 | 67 | if os.path.isfile(os.path.expanduser(args.output)): 68 | cur_reviews = [json.loads(line) for line in open(os.path.expanduser(args.output))] 69 | else: 70 | cur_reviews = [] 71 | 72 | review_file = open(f'{args.output}', 'a') 73 | 74 | context_list = [json.loads(line) for line in open(os.path.expanduser(args.context))] 75 | image_to_context = {context['image']: context for context in context_list} 76 | 77 | handles = [] 78 | idx = 0 79 | for ques_js, ans1_js, 
ans2_js in zip(f_q, f_ans1, f_ans2): 80 | ques = json.loads(ques_js) 81 | ans1 = json.loads(ans1_js) 82 | ans2 = json.loads(ans2_js) 83 | 84 | inst = image_to_context[ques['image']] 85 | cap_str = '\n'.join(inst['captions']) 86 | box_str = '\n'.join([f'{instance["category"]}: {instance["bbox"]}' for instance in inst['instances']]) 87 | 88 | category = json.loads(ques_js)['category'] 89 | if category in rule_dict: 90 | rule = rule_dict[category] 91 | else: 92 | assert False, f"Visual QA category not found in rule file: {category}." 93 | prompt = rule['prompt'] 94 | role = rule['role'] 95 | content = (f'[Context]\n{cap_str}\n\n{box_str}\n\n' 96 | f'[Question]\n{ques["text"]}\n\n' 97 | f'[{role} 1]\n{ans1["text"]}\n\n[End of {role} 1]\n\n' 98 | f'[{role} 2]\n{ans2["text"]}\n\n[End of {role} 2]\n\n' 99 | f'[System]\n{prompt}\n\n') 100 | cur_js = { 101 | 'id': idx+1, 102 | 'question_id': ques['question_id'], 103 | 'answer1_id': ans1.get('answer_id', ans1['question_id']), 104 | 'answer2_id': ans2.get('answer_id', ans2['answer_id']), 105 | 'category': category 106 | } 107 | if idx >= len(cur_reviews): 108 | review = get_eval(content, args.max_tokens) 109 | scores = parse_score(review) 110 | cur_js['content'] = review 111 | cur_js['tuple'] = scores 112 | review_file.write(json.dumps(cur_js) + '\n') 113 | review_file.flush() 114 | else: 115 | print(f'Skipping {idx} as we already have it.') 116 | idx += 1 117 | print(idx) 118 | review_file.close() 119 | -------------------------------------------------------------------------------- /llava/eval/eval_pope.py: -------------------------------------------------------------------------------- 1 | import os 2 | import json 3 | import argparse 4 | 5 | def eval_pope(answers, label_file): 6 | label_list = [json.loads(q)['label'] for q in open(label_file, 'r')] 7 | 8 | for answer in answers: 9 | text = answer['text'] 10 | 11 | # Only keep the first sentence 12 | if text.find('.') != -1: 13 | text = text.split('.')[0] 14 | 15 | text = text.replace(',', '') 16 | words = text.split(' ') 17 | if 'No' in words or 'not' in words or 'no' in words: 18 | answer['text'] = 'no' 19 | else: 20 | answer['text'] = 'yes' 21 | 22 | for i in range(len(label_list)): 23 | if label_list[i] == 'no': 24 | label_list[i] = 0 25 | else: 26 | label_list[i] = 1 27 | 28 | pred_list = [] 29 | for answer in answers: 30 | if answer['text'] == 'no': 31 | pred_list.append(0) 32 | else: 33 | pred_list.append(1) 34 | 35 | pos = 1 36 | neg = 0 37 | yes_ratio = pred_list.count(1) / len(pred_list) 38 | 39 | TP, TN, FP, FN = 0, 0, 0, 0 40 | for pred, label in zip(pred_list, label_list): 41 | if pred == pos and label == pos: 42 | TP += 1 43 | elif pred == pos and label == neg: 44 | FP += 1 45 | elif pred == neg and label == neg: 46 | TN += 1 47 | elif pred == neg and label == pos: 48 | FN += 1 49 | 50 | print('TP\tFP\tTN\tFN\t') 51 | print('{}\t{}\t{}\t{}'.format(TP, FP, TN, FN)) 52 | 53 | precision = float(TP) / float(TP + FP) 54 | recall = float(TP) / float(TP + FN) 55 | f1 = 2*precision*recall / (precision + recall) 56 | acc = (TP + TN) / (TP + TN + FP + FN) 57 | print('Accuracy: {}'.format(acc)) 58 | print('Precision: {}'.format(precision)) 59 | print('Recall: {}'.format(recall)) 60 | print('F1 score: {}'.format(f1)) 61 | print('Yes ratio: {}'.format(yes_ratio)) 62 | print('%.3f, %.3f, %.3f, %.3f, %.3f' % (f1, acc, precision, recall, yes_ratio) ) 63 | 64 | if __name__ == "__main__": 65 | parser = argparse.ArgumentParser() 66 | parser.add_argument("--annotation-dir", type=str) 67 | 
parser.add_argument("--question-file", type=str) 68 | parser.add_argument("--result-file", type=str) 69 | args = parser.parse_args() 70 | 71 | questions = [json.loads(line) for line in open(args.question_file)] 72 | questions = {question['question_id']: question for question in questions} 73 | answers = [json.loads(q) for q in open(args.result_file)] 74 | for file in os.listdir(args.annotation_dir): 75 | assert file.startswith('coco_pope_') 76 | assert file.endswith('.json') 77 | category = file[10:-5] 78 | cur_answers = [x for x in answers if questions[x['question_id']]['category'] == category] 79 | print('Category: {}, # samples: {}'.format(category, len(cur_answers))) 80 | eval_pope(cur_answers, os.path.join(args.annotation_dir, file)) 81 | print("====================================") 82 | -------------------------------------------------------------------------------- /llava/eval/eval_science_qa.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import json 3 | import os 4 | import re 5 | import random 6 | 7 | 8 | def get_args(): 9 | parser = argparse.ArgumentParser() 10 | parser.add_argument('--base-dir', type=str) 11 | parser.add_argument('--result-file', type=str) 12 | parser.add_argument('--output-file', type=str) 13 | parser.add_argument('--output-result', type=str) 14 | parser.add_argument('--split', type=str, default='test') 15 | parser.add_argument('--options', type=list, default=["A", "B", "C", "D", "E"]) 16 | return parser.parse_args() 17 | 18 | 19 | def convert_caps(results): 20 | fakecaps = [] 21 | for result in results: 22 | image_id = result['question_id'] 23 | caption = result['text'] 24 | fakecaps.append({"image_id": int(image_id), "caption": caption}) 25 | return fakecaps 26 | 27 | 28 | def get_pred_idx(prediction, choices, options): 29 | """ 30 | Get the index (e.g. 2) from the prediction (e.g. 'C') 31 | """ 32 | if prediction in options[:len(choices)]: 33 | return options.index(prediction) 34 | else: 35 | return -1 36 | return random.choice(range(len(choices))) 37 | 38 | 39 | if __name__ == "__main__": 40 | args = get_args() 41 | 42 | base_dir = args.base_dir 43 | split_indices = json.load(open(os.path.join(base_dir, "pid_splits.json")))[args.split] 44 | problems = json.load(open(os.path.join(base_dir, "problems.json"))) 45 | predictions = [json.loads(line) for line in open(args.result_file)] 46 | predictions = {pred['question_id']: pred for pred in predictions} 47 | split_problems = {idx: problems[idx] for idx in split_indices} 48 | 49 | results = {'correct': [], 'incorrect': []} 50 | sqa_results = {} 51 | sqa_results['acc'] = None 52 | sqa_results['correct'] = None 53 | sqa_results['count'] = None 54 | sqa_results['results'] = {} 55 | sqa_results['outputs'] = {} 56 | 57 | for prob_id, prob in split_problems.items(): 58 | if prob_id not in predictions: 59 | pred = {'text': 'FAILED', 'prompt': 'Unknown'} 60 | pred_text = 'FAILED' 61 | else: 62 | pred = predictions[prob_id] 63 | pred_text = pred['text'] 64 | 65 | if pred_text in args.options: 66 | answer = pred_text 67 | elif len(pred_text) >= 3 and pred_text[0] in args.options and pred_text[1:3] == ". ": 68 | answer = pred_text[0] 69 | else: 70 | pattern = re.compile(r'The answer is ([A-Z]).') 71 | res = pattern.findall(pred_text) 72 | if len(res) == 1: 73 | answer = res[0] # 'A', 'B', ... 
74 | else: 75 | answer = "FAILED" 76 | 77 | pred_idx = get_pred_idx(answer, prob['choices'], args.options) 78 | 79 | analysis = { 80 | 'question_id': prob_id, 81 | 'parsed_ans': answer, 82 | 'ground_truth': args.options[prob['answer']], 83 | 'question': pred['prompt'], 84 | 'pred': pred_text, 85 | 'is_multimodal': '' in pred['prompt'], 86 | } 87 | 88 | sqa_results['results'][prob_id] = get_pred_idx(answer, prob['choices'], args.options) 89 | sqa_results['outputs'][prob_id] = pred_text 90 | 91 | if pred_idx == prob['answer']: 92 | results['correct'].append(analysis) 93 | else: 94 | results['incorrect'].append(analysis) 95 | 96 | correct = len(results['correct']) 97 | total = len(results['correct']) + len(results['incorrect']) 98 | 99 | ###### IMG ###### 100 | multimodal_correct = len([x for x in results['correct'] if x['is_multimodal']]) 101 | multimodal_incorrect = len([x for x in results['incorrect'] if x['is_multimodal']]) 102 | multimodal_total = multimodal_correct + multimodal_incorrect 103 | ###### IMG ###### 104 | 105 | print(f'Total: {total}, Correct: {correct}, Accuracy: {correct / total * 100:.2f}%, IMG-Accuracy: {multimodal_correct / multimodal_total * 100:.2f}%') 106 | 107 | sqa_results['acc'] = correct / total * 100 108 | sqa_results['correct'] = correct 109 | sqa_results['count'] = total 110 | 111 | with open(args.output_file, 'w') as f: 112 | json.dump(results, f, indent=2) 113 | with open(args.output_result, 'w') as f: 114 | json.dump(sqa_results, f, indent=2) 115 | -------------------------------------------------------------------------------- /llava/eval/eval_science_qa_gpt4.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import json 3 | import os 4 | import re 5 | import random 6 | from collections import defaultdict 7 | 8 | 9 | def get_args(): 10 | parser = argparse.ArgumentParser() 11 | parser.add_argument('--base-dir', type=str) 12 | parser.add_argument('--gpt4-result', type=str) 13 | parser.add_argument('--our-result', type=str) 14 | parser.add_argument('--split', type=str, default='test') 15 | parser.add_argument('--options', type=list, default=["A", "B", "C", "D", "E"]) 16 | return parser.parse_args() 17 | 18 | 19 | def convert_caps(results): 20 | fakecaps = [] 21 | for result in results: 22 | image_id = result['question_id'] 23 | caption = result['text'] 24 | fakecaps.append({"image_id": int(image_id), "caption": caption}) 25 | return fakecaps 26 | 27 | 28 | def get_pred_idx(prediction, choices, options): 29 | """ 30 | Get the index (e.g. 2) from the prediction (e.g. 
'C') 31 | """ 32 | if prediction in options[:len(choices)]: 33 | return options.index(prediction) 34 | else: 35 | return random.choice(range(len(choices))) 36 | 37 | 38 | if __name__ == "__main__": 39 | args = get_args() 40 | 41 | base_dir = args.base_dir 42 | split_indices = json.load(open(os.path.join(base_dir, "pid_splits.json")))[args.split] 43 | problems = json.load(open(os.path.join(base_dir, "problems.json"))) 44 | our_predictions = [json.loads(line) for line in open(args.our_result)] 45 | our_predictions = {pred['question_id']: pred for pred in our_predictions} 46 | split_problems = {idx: problems[idx] for idx in split_indices} 47 | 48 | gpt4_predictions = json.load(open(args.gpt4_result))['outputs'] 49 | 50 | results = defaultdict(lambda: 0) 51 | 52 | for prob_id, prob in split_problems.items(): 53 | if prob_id not in our_predictions: 54 | continue 55 | if prob_id not in gpt4_predictions: 56 | continue 57 | our_pred = our_predictions[prob_id]['text'] 58 | gpt4_pred = gpt4_predictions[prob_id] 59 | 60 | pattern = re.compile(r'The answer is ([A-Z]).') 61 | our_res = pattern.findall(our_pred) 62 | if len(our_res) == 1: 63 | our_answer = our_res[0] # 'A', 'B', ... 64 | else: 65 | our_answer = "FAILED" 66 | gpt4_res = pattern.findall(gpt4_pred) 67 | if len(gpt4_res) == 1: 68 | gpt4_answer = gpt4_res[0] # 'A', 'B', ... 69 | else: 70 | gpt4_answer = "FAILED" 71 | 72 | our_pred_idx = get_pred_idx(our_answer, prob['choices'], args.options) 73 | gpt4_pred_idx = get_pred_idx(gpt4_answer, prob['choices'], args.options) 74 | 75 | if gpt4_answer == 'FAILED': 76 | results['gpt4_failed'] += 1 77 | # continue 78 | gpt4_pred_idx = our_pred_idx 79 | # if our_pred_idx != prob['answer']: 80 | # print(our_predictions[prob_id]['prompt']) 81 | # print('-----------------') 82 | # print(f'LECTURE: {prob["lecture"]}') 83 | # print(f'SOLUTION: {prob["solution"]}') 84 | # print('=====================') 85 | else: 86 | # continue 87 | pass 88 | # gpt4_pred_idx = our_pred_idx 89 | 90 | if gpt4_pred_idx == prob['answer']: 91 | results['correct'] += 1 92 | else: 93 | results['incorrect'] += 1 94 | 95 | 96 | if gpt4_pred_idx == prob['answer'] or our_pred_idx == prob['answer']: 97 | results['correct_upperbound'] += 1 98 | 99 | correct = results['correct'] 100 | total = results['correct'] + results['incorrect'] 101 | print(f'Total: {total}, Correct: {correct}, Accuracy: {correct / total * 100:.2f}%') 102 | print(f'Total: {total}, Correct (upper): {results["correct_upperbound"]}, Accuracy: {results["correct_upperbound"] / total * 100:.2f}%') 103 | print(f'Total: {total}, GPT-4 NO-ANS (RANDOM): {results["gpt4_failed"]}, Percentage: {results["gpt4_failed"] / total * 100:.2f}%') 104 | 105 | -------------------------------------------------------------------------------- /llava/eval/eval_science_qa_gpt4_requery.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import json 3 | import os 4 | import re 5 | import random 6 | from collections import defaultdict 7 | 8 | 9 | def get_args(): 10 | parser = argparse.ArgumentParser() 11 | parser.add_argument('--base-dir', type=str) 12 | parser.add_argument('--gpt4-result', type=str) 13 | parser.add_argument('--requery-result', type=str) 14 | parser.add_argument('--our-result', type=str) 15 | parser.add_argument('--output-result', type=str) 16 | parser.add_argument('--split', type=str, default='test') 17 | parser.add_argument('--options', type=list, default=["A", "B", "C", "D", "E"]) 18 | return parser.parse_args() 19 | 20 | 
21 | def convert_caps(results): 22 | fakecaps = [] 23 | for result in results: 24 | image_id = result['question_id'] 25 | caption = result['text'] 26 | fakecaps.append({"image_id": int(image_id), "caption": caption}) 27 | return fakecaps 28 | 29 | 30 | def get_pred_idx(prediction, choices, options): 31 | """ 32 | Get the index (e.g. 2) from the prediction (e.g. 'C') 33 | """ 34 | if prediction in options[:len(choices)]: 35 | return options.index(prediction) 36 | else: 37 | return random.choice(range(len(choices))) 38 | 39 | 40 | if __name__ == "__main__": 41 | args = get_args() 42 | 43 | base_dir = args.base_dir 44 | split_indices = json.load(open(os.path.join(base_dir, "pid_splits.json")))[args.split] 45 | problems = json.load(open(os.path.join(base_dir, "problems.json"))) 46 | our_predictions = [json.loads(line) for line in open(args.our_result)] 47 | our_predictions = {pred['question_id']: pred for pred in our_predictions} 48 | split_problems = {idx: problems[idx] for idx in split_indices} 49 | 50 | requery_predictions = [json.loads(line) for line in open(args.requery_result)] 51 | requery_predictions = {pred['question_id']: pred for pred in requery_predictions} 52 | 53 | gpt4_predictions = json.load(open(args.gpt4_result))['outputs'] 54 | 55 | results = defaultdict(lambda: 0) 56 | 57 | sqa_results = {} 58 | sqa_results['acc'] = None 59 | sqa_results['correct'] = None 60 | sqa_results['count'] = None 61 | sqa_results['results'] = {} 62 | sqa_results['outputs'] = {} 63 | 64 | for prob_id, prob in split_problems.items(): 65 | if prob_id not in our_predictions: 66 | assert False 67 | if prob_id not in gpt4_predictions: 68 | assert False 69 | our_pred = our_predictions[prob_id]['text'] 70 | gpt4_pred = gpt4_predictions[prob_id] 71 | if prob_id not in requery_predictions: 72 | results['missing_requery'] += 1 73 | requery_pred = "MISSING" 74 | else: 75 | requery_pred = requery_predictions[prob_id]['text'] 76 | 77 | pattern = re.compile(r'The answer is ([A-Z]).') 78 | our_res = pattern.findall(our_pred) 79 | if len(our_res) == 1: 80 | our_answer = our_res[0] # 'A', 'B', ... 81 | else: 82 | our_answer = "FAILED" 83 | 84 | requery_res = pattern.findall(requery_pred) 85 | if len(requery_res) == 1: 86 | requery_answer = requery_res[0] # 'A', 'B', ... 87 | else: 88 | requery_answer = "FAILED" 89 | 90 | gpt4_res = pattern.findall(gpt4_pred) 91 | if len(gpt4_res) == 1: 92 | gpt4_answer = gpt4_res[0] # 'A', 'B', ... 
93 | else: 94 | gpt4_answer = "FAILED" 95 | 96 | our_pred_idx = get_pred_idx(our_answer, prob['choices'], args.options) 97 | gpt4_pred_idx = get_pred_idx(gpt4_answer, prob['choices'], args.options) 98 | requery_pred_idx = get_pred_idx(requery_answer, prob['choices'], args.options) 99 | 100 | results['total'] += 1 101 | 102 | if gpt4_answer == 'FAILED': 103 | results['gpt4_failed'] += 1 104 | if gpt4_pred_idx == prob['answer']: 105 | results['gpt4_correct'] += 1 106 | if our_pred_idx == prob['answer']: 107 | results['gpt4_ourvisual_correct'] += 1 108 | elif gpt4_pred_idx == prob['answer']: 109 | results['gpt4_correct'] += 1 110 | results['gpt4_ourvisual_correct'] += 1 111 | 112 | if our_pred_idx == prob['answer']: 113 | results['our_correct'] += 1 114 | 115 | if requery_answer == 'FAILED': 116 | sqa_results['results'][prob_id] = our_pred_idx 117 | if our_pred_idx == prob['answer']: 118 | results['requery_correct'] += 1 119 | else: 120 | sqa_results['results'][prob_id] = requery_pred_idx 121 | if requery_pred_idx == prob['answer']: 122 | results['requery_correct'] += 1 123 | else: 124 | print(f""" 125 | Question ({args.options[prob['answer']]}): {our_predictions[prob_id]['prompt']} 126 | Our ({our_answer}): {our_pred} 127 | GPT-4 ({gpt4_answer}): {gpt4_pred} 128 | Requery ({requery_answer}): {requery_pred} 129 | print("=====================================") 130 | """) 131 | 132 | if gpt4_pred_idx == prob['answer'] or our_pred_idx == prob['answer']: 133 | results['correct_upperbound'] += 1 134 | 135 | total = results['total'] 136 | print(f'Total: {total}, Our-Correct: {results["our_correct"]}, Accuracy: {results["our_correct"] / total * 100:.2f}%') 137 | print(f'Total: {total}, GPT-4-Correct: {results["gpt4_correct"]}, Accuracy: {results["gpt4_correct"] / total * 100:.2f}%') 138 | print(f'Total: {total}, GPT-4 NO-ANS (RANDOM): {results["gpt4_failed"]}, Percentage: {results["gpt4_failed"] / total * 100:.2f}%') 139 | print(f'Total: {total}, GPT-4-OursVisual-Correct: {results["gpt4_ourvisual_correct"]}, Accuracy: {results["gpt4_ourvisual_correct"] / total * 100:.2f}%') 140 | print(f'Total: {total}, Requery-Correct: {results["requery_correct"]}, Accuracy: {results["requery_correct"] / total * 100:.2f}%') 141 | print(f'Total: {total}, Correct upper: {results["correct_upperbound"]}, Accuracy: {results["correct_upperbound"] / total * 100:.2f}%') 142 | 143 | sqa_results['acc'] = results["requery_correct"] / total * 100 144 | sqa_results['correct'] = results["requery_correct"] 145 | sqa_results['count'] = total 146 | 147 | with open(args.output_result, 'w') as f: 148 | json.dump(sqa_results, f, indent=2) 149 | 150 | -------------------------------------------------------------------------------- /llava/eval/eval_textvqa.py: -------------------------------------------------------------------------------- 1 | import os 2 | import argparse 3 | import json 4 | import re 5 | 6 | from llava.eval.m4c_evaluator import TextVQAAccuracyEvaluator 7 | 8 | 9 | def get_args(): 10 | parser = argparse.ArgumentParser() 11 | parser.add_argument('--annotation-file', type=str) 12 | parser.add_argument('--result-file', type=str) 13 | parser.add_argument('--result-dir', type=str) 14 | return parser.parse_args() 15 | 16 | 17 | def prompt_processor(prompt): 18 | if prompt.startswith('OCR tokens: '): 19 | pattern = r"Question: (.*?) 
Short answer:" 20 | match = re.search(pattern, prompt, re.DOTALL) 21 | question = match.group(1) 22 | elif 'Reference OCR token: ' in prompt and len(prompt.split('\n')) == 3: 23 | if prompt.startswith('Reference OCR token:'): 24 | question = prompt.split('\n')[1] 25 | else: 26 | question = prompt.split('\n')[0] 27 | elif len(prompt.split('\n')) == 2: 28 | question = prompt.split('\n')[0] 29 | else: 30 | assert False 31 | 32 | return question.lower() 33 | 34 | 35 | def eval_single(annotation_file, result_file): 36 | experiment_name = os.path.splitext(os.path.basename(result_file))[0] 37 | print(experiment_name) 38 | annotations = json.load(open(annotation_file))['data'] 39 | annotations = {(annotation['image_id'], annotation['question'].lower()): annotation for annotation in annotations} 40 | results = [json.loads(line) for line in open(result_file)] 41 | 42 | pred_list = [] 43 | for result in results: 44 | annotation = annotations[(result['question_id'], prompt_processor(result['prompt']))] 45 | pred_list.append({ 46 | "pred_answer": result['text'], 47 | "gt_answers": annotation['answers'], 48 | }) 49 | 50 | evaluator = TextVQAAccuracyEvaluator() 51 | print('Samples: {}\nAccuracy: {:.2f}%\n'.format(len(pred_list), 100. * evaluator.eval_pred_list(pred_list))) 52 | 53 | 54 | if __name__ == "__main__": 55 | args = get_args() 56 | 57 | if args.result_file is not None: 58 | eval_single(args.annotation_file, args.result_file) 59 | 60 | if args.result_dir is not None: 61 | for result_file in sorted(os.listdir(args.result_dir)): 62 | if not result_file.endswith('.jsonl'): 63 | print(f'Skipping {result_file}') 64 | continue 65 | eval_single(args.annotation_file, os.path.join(args.result_dir, result_file)) 66 | -------------------------------------------------------------------------------- /llava/eval/generate_webpage_data_from_table.py: -------------------------------------------------------------------------------- 1 | """Generate json file for webpage.""" 2 | import json 3 | import os 4 | import re 5 | 6 | # models = ['llama', 'alpaca', 'gpt35', 'bard'] 7 | models = ['vicuna'] 8 | 9 | 10 | def read_jsonl(path: str, key: str=None): 11 | data = [] 12 | with open(os.path.expanduser(path)) as f: 13 | for line in f: 14 | if not line: 15 | continue 16 | data.append(json.loads(line)) 17 | if key is not None: 18 | data.sort(key=lambda x: x[key]) 19 | data = {item[key]: item for item in data} 20 | return data 21 | 22 | 23 | def trim_hanging_lines(s: str, n: int) -> str: 24 | s = s.strip() 25 | for _ in range(n): 26 | s = s.split('\n', 1)[1].strip() 27 | return s 28 | 29 | 30 | if __name__ == '__main__': 31 | questions = read_jsonl('table/question.jsonl', key='question_id') 32 | 33 | # alpaca_answers = read_jsonl('table/answer/answer_alpaca-13b.jsonl', key='question_id') 34 | # bard_answers = read_jsonl('table/answer/answer_bard.jsonl', key='question_id') 35 | # gpt35_answers = read_jsonl('table/answer/answer_gpt35.jsonl', key='question_id') 36 | # llama_answers = read_jsonl('table/answer/answer_llama-13b.jsonl', key='question_id') 37 | vicuna_answers = read_jsonl('table/answer/answer_vicuna-13b.jsonl', key='question_id') 38 | ours_answers = read_jsonl('table/results/llama-13b-hf-alpaca.jsonl', key='question_id') 39 | 40 | review_vicuna = read_jsonl('table/review/review_vicuna-13b_llama-13b-hf-alpaca.jsonl', key='question_id') 41 | # review_alpaca = read_jsonl('table/review/review_alpaca-13b_vicuna-13b.jsonl', key='question_id') 42 | # review_bard = read_jsonl('table/review/review_bard_vicuna-13b.jsonl', 
key='question_id') 43 | # review_gpt35 = read_jsonl('table/review/review_gpt35_vicuna-13b.jsonl', key='question_id') 44 | # review_llama = read_jsonl('table/review/review_llama-13b_vicuna-13b.jsonl', key='question_id') 45 | 46 | records = [] 47 | for qid in questions.keys(): 48 | r = { 49 | 'id': qid, 50 | 'category': questions[qid]['category'], 51 | 'question': questions[qid]['text'], 52 | 'answers': { 53 | # 'alpaca': alpaca_answers[qid]['text'], 54 | # 'llama': llama_answers[qid]['text'], 55 | # 'bard': bard_answers[qid]['text'], 56 | # 'gpt35': gpt35_answers[qid]['text'], 57 | 'vicuna': vicuna_answers[qid]['text'], 58 | 'ours': ours_answers[qid]['text'], 59 | }, 60 | 'evaluations': { 61 | # 'alpaca': review_alpaca[qid]['text'], 62 | # 'llama': review_llama[qid]['text'], 63 | # 'bard': review_bard[qid]['text'], 64 | 'vicuna': review_vicuna[qid]['content'], 65 | # 'gpt35': review_gpt35[qid]['text'], 66 | }, 67 | 'scores': { 68 | 'vicuna': review_vicuna[qid]['tuple'], 69 | # 'alpaca': review_alpaca[qid]['score'], 70 | # 'llama': review_llama[qid]['score'], 71 | # 'bard': review_bard[qid]['score'], 72 | # 'gpt35': review_gpt35[qid]['score'], 73 | }, 74 | } 75 | 76 | # cleanup data 77 | cleaned_evals = {} 78 | for k, v in r['evaluations'].items(): 79 | v = v.strip() 80 | lines = v.split('\n') 81 | # trim the first line if it's a pair of numbers 82 | if re.match(r'\d+[, ]+\d+', lines[0]): 83 | lines = lines[1:] 84 | v = '\n'.join(lines) 85 | cleaned_evals[k] = v.replace('Assistant 1', "**Assistant 1**").replace('Assistant 2', '**Assistant 2**') 86 | 87 | r['evaluations'] = cleaned_evals 88 | records.append(r) 89 | 90 | # Reorder the records, this is optional 91 | for r in records: 92 | if r['id'] <= 20: 93 | r['id'] += 60 94 | else: 95 | r['id'] -= 20 96 | for r in records: 97 | if r['id'] <= 50: 98 | r['id'] += 10 99 | elif 50 < r['id'] <= 60: 100 | r['id'] -= 50 101 | for r in records: 102 | if r['id'] == 7: 103 | r['id'] = 1 104 | elif r['id'] < 7: 105 | r['id'] += 1 106 | 107 | records.sort(key=lambda x: x['id']) 108 | 109 | # Write to file 110 | with open('webpage/data.json', 'w') as f: 111 | json.dump({'questions': records, 'models': models}, f, indent=2) 112 | -------------------------------------------------------------------------------- /llava/eval/model_qa.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | from transformers import AutoTokenizer, AutoModelForCausalLM, StoppingCriteria 3 | import torch 4 | import os 5 | import json 6 | from tqdm import tqdm 7 | import shortuuid 8 | 9 | from llava.conversation import default_conversation 10 | from llava.utils import disable_torch_init 11 | 12 | 13 | @torch.inference_mode() 14 | def eval_model(model_name, questions_file, answers_file): 15 | # Model 16 | disable_torch_init() 17 | model_name = os.path.expanduser(model_name) 18 | tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=False) 19 | model = AutoModelForCausalLM.from_pretrained(model_name, 20 | torch_dtype=torch.float16).cuda() 21 | 22 | 23 | ques_file = open(os.path.expanduser(questions_file), "r") 24 | ans_file = open(os.path.expanduser(answers_file), "w") 25 | for i, line in enumerate(tqdm(ques_file)): 26 | idx = json.loads(line)["question_id"] 27 | qs = json.loads(line)["text"] 28 | cat = json.loads(line)["category"] 29 | conv = default_conversation.copy() 30 | conv.append_message(conv.roles[0], qs) 31 | prompt = conv.get_prompt() 32 | inputs = tokenizer([prompt]) 33 | input_ids = 
torch.as_tensor(inputs.input_ids).cuda() 34 | output_ids = model.generate( 35 | input_ids, 36 | do_sample=True, 37 | use_cache=True, 38 | temperature=0.7, 39 | max_new_tokens=1024,) 40 | outputs = tokenizer.batch_decode(output_ids, skip_special_tokens=True)[0] 41 | try: 42 | index = outputs.index(conv.sep, len(prompt)) 43 | except ValueError: 44 | outputs += conv.sep 45 | index = outputs.index(conv.sep, len(prompt)) 46 | 47 | outputs = outputs[len(prompt) + len(conv.roles[1]) + 2:index].strip() 48 | ans_id = shortuuid.uuid() 49 | ans_file.write(json.dumps({"question_id": idx, 50 | "text": outputs, 51 | "answer_id": ans_id, 52 | "model_id": model_name, 53 | "metadata": {}}) + "\n") 54 | ans_file.flush() 55 | ans_file.close() 56 | 57 | if __name__ == "__main__": 58 | parser = argparse.ArgumentParser() 59 | parser.add_argument("--model-name", type=str, default="facebook/opt-350m") 60 | parser.add_argument("--question-file", type=str, default="tables/question.jsonl") 61 | parser.add_argument("--answers-file", type=str, default="answer.jsonl") 62 | args = parser.parse_args() 63 | 64 | eval_model(args.model_name, args.question_file, args.answers_file) 65 | -------------------------------------------------------------------------------- /llava/eval/model_vqa.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import torch 3 | import os 4 | import json 5 | from tqdm import tqdm 6 | import shortuuid 7 | 8 | from llava.constants import IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN 9 | from llava.conversation import conv_templates, SeparatorStyle 10 | from llava.model.builder import load_pretrained_model 11 | from llava.utils import disable_torch_init 12 | from llava.mm_utils import tokenizer_image_token, process_images, get_model_name_from_path 13 | 14 | from PIL import Image 15 | import math 16 | 17 | 18 | def split_list(lst, n): 19 | """Split a list into n (roughly) equal-sized chunks""" 20 | chunk_size = math.ceil(len(lst) / n) # integer division 21 | return [lst[i:i+chunk_size] for i in range(0, len(lst), chunk_size)] 22 | 23 | 24 | def get_chunk(lst, n, k): 25 | chunks = split_list(lst, n) 26 | return chunks[k] 27 | 28 | 29 | def eval_model(args): 30 | # Model 31 | disable_torch_init() 32 | model_path = os.path.expanduser(args.model_path) 33 | model_name = get_model_name_from_path(model_path) 34 | tokenizer, model, image_processor, context_len = load_pretrained_model(model_path, args.model_base, model_name) 35 | 36 | questions = [json.loads(q) for q in open(os.path.expanduser(args.question_file), "r")] 37 | questions = get_chunk(questions, args.num_chunks, args.chunk_idx) 38 | answers_file = os.path.expanduser(args.answers_file) 39 | os.makedirs(os.path.dirname(answers_file), exist_ok=True) 40 | ans_file = open(answers_file, "w") 41 | for line in tqdm(questions): 42 | idx = line["question_id"] 43 | image_file = line["image"] 44 | qs = line["text"] 45 | cur_prompt = qs 46 | if model.config.mm_use_im_start_end: 47 | qs = DEFAULT_IM_START_TOKEN + DEFAULT_IMAGE_TOKEN + DEFAULT_IM_END_TOKEN + '\n' + qs 48 | else: 49 | qs = DEFAULT_IMAGE_TOKEN + '\n' + qs 50 | 51 | conv = conv_templates[args.conv_mode].copy() 52 | conv.append_message(conv.roles[0], qs) 53 | conv.append_message(conv.roles[1], None) 54 | prompt = conv.get_prompt() 55 | 56 | input_ids = tokenizer_image_token(prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors='pt').unsqueeze(0).cuda() 57 | 58 | image = 
Image.open(os.path.join(args.image_folder, image_file)).convert('RGB') 59 | image_tensor = process_images([image], image_processor, model.config)[0] 60 | 61 | with torch.inference_mode(): 62 | output_ids = model.generate( 63 | input_ids, 64 | images=image_tensor.unsqueeze(0).half().cuda(), 65 | image_sizes=[image.size], 66 | do_sample=True if args.temperature > 0 else False, 67 | temperature=args.temperature, 68 | top_p=args.top_p, 69 | num_beams=args.num_beams, 70 | # no_repeat_ngram_size=3, 71 | max_new_tokens=1024, 72 | use_cache=True) 73 | 74 | outputs = tokenizer.batch_decode(output_ids, skip_special_tokens=True)[0].strip() 75 | 76 | ans_id = shortuuid.uuid() 77 | ans_file.write(json.dumps({"question_id": idx, 78 | "prompt": cur_prompt, 79 | "text": outputs, 80 | "answer_id": ans_id, 81 | "model_id": model_name, 82 | "metadata": {}}) + "\n") 83 | ans_file.flush() 84 | ans_file.close() 85 | 86 | if __name__ == "__main__": 87 | parser = argparse.ArgumentParser() 88 | parser.add_argument("--model-path", type=str, default="facebook/opt-350m") 89 | parser.add_argument("--model-base", type=str, default=None) 90 | parser.add_argument("--image-folder", type=str, default="") 91 | parser.add_argument("--question-file", type=str, default="tables/question.jsonl") 92 | parser.add_argument("--answers-file", type=str, default="answer.jsonl") 93 | parser.add_argument("--conv-mode", type=str, default="llava_v1") 94 | parser.add_argument("--num-chunks", type=int, default=1) 95 | parser.add_argument("--chunk-idx", type=int, default=0) 96 | parser.add_argument("--temperature", type=float, default=0.2) 97 | parser.add_argument("--top_p", type=float, default=None) 98 | parser.add_argument("--num_beams", type=int, default=1) 99 | args = parser.parse_args() 100 | 101 | eval_model(args) 102 | -------------------------------------------------------------------------------- /llava/eval/model_vqa_loader.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import torch 3 | import os 4 | import json 5 | from tqdm import tqdm 6 | import shortuuid 7 | 8 | from llava.constants import IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN 9 | from llava.conversation import conv_templates, SeparatorStyle 10 | from llava.model.builder import load_pretrained_model 11 | from llava.utils import disable_torch_init 12 | from llava.mm_utils import tokenizer_image_token, process_images, get_model_name_from_path 13 | from torch.utils.data import Dataset, DataLoader 14 | 15 | from PIL import Image 16 | import math 17 | 18 | 19 | def split_list(lst, n): 20 | """Split a list into n (roughly) equal-sized chunks""" 21 | chunk_size = math.ceil(len(lst) / n) # integer division 22 | return [lst[i:i+chunk_size] for i in range(0, len(lst), chunk_size)] 23 | 24 | 25 | def get_chunk(lst, n, k): 26 | chunks = split_list(lst, n) 27 | return chunks[k] 28 | 29 | 30 | # Custom dataset class 31 | class CustomDataset(Dataset): 32 | def __init__(self, questions, image_folder, tokenizer, image_processor, model_config): 33 | self.questions = questions 34 | self.image_folder = image_folder 35 | self.tokenizer = tokenizer 36 | self.image_processor = image_processor 37 | self.model_config = model_config 38 | 39 | def __getitem__(self, index): 40 | line = self.questions[index] 41 | image_file = line["image"] 42 | qs = line["text"] 43 | if self.model_config.mm_use_im_start_end: 44 | qs = DEFAULT_IM_START_TOKEN + DEFAULT_IMAGE_TOKEN + DEFAULT_IM_END_TOKEN + '\n' + qs 45 
| else: 46 | qs = DEFAULT_IMAGE_TOKEN + '\n' + qs 47 | 48 | conv = conv_templates[args.conv_mode].copy() 49 | conv.append_message(conv.roles[0], qs) 50 | conv.append_message(conv.roles[1], None) 51 | prompt = conv.get_prompt() 52 | 53 | image = Image.open(os.path.join(self.image_folder, image_file)).convert('RGB') 54 | image_tensor = process_images([image], self.image_processor, self.model_config)[0] 55 | 56 | input_ids = tokenizer_image_token(prompt, self.tokenizer, IMAGE_TOKEN_INDEX, return_tensors='pt') 57 | 58 | return input_ids, image_tensor, image.size 59 | 60 | def __len__(self): 61 | return len(self.questions) 62 | 63 | 64 | def collate_fn(batch): 65 | input_ids, image_tensors, image_sizes = zip(*batch) 66 | input_ids = torch.stack(input_ids, dim=0) 67 | image_tensors = torch.stack(image_tensors, dim=0) 68 | return input_ids, image_tensors, image_sizes 69 | 70 | 71 | # DataLoader 72 | def create_data_loader(questions, image_folder, tokenizer, image_processor, model_config, batch_size=1, num_workers=4): 73 | assert batch_size == 1, "batch_size must be 1" 74 | dataset = CustomDataset(questions, image_folder, tokenizer, image_processor, model_config) 75 | data_loader = DataLoader(dataset, batch_size=batch_size, num_workers=num_workers, shuffle=False, collate_fn=collate_fn) 76 | return data_loader 77 | 78 | 79 | def eval_model(args): 80 | # Model 81 | disable_torch_init() 82 | model_path = os.path.expanduser(args.model_path) 83 | model_name = get_model_name_from_path(model_path) 84 | tokenizer, model, image_processor, context_len = load_pretrained_model(model_path, args.model_base, model_name) 85 | 86 | questions = [json.loads(q) for q in open(os.path.expanduser(args.question_file), "r")] 87 | questions = get_chunk(questions, args.num_chunks, args.chunk_idx) 88 | answers_file = os.path.expanduser(args.answers_file) 89 | os.makedirs(os.path.dirname(answers_file), exist_ok=True) 90 | ans_file = open(answers_file, "w") 91 | 92 | if 'plain' in model_name and 'finetune' not in model_name.lower() and 'mmtag' not in args.conv_mode: 93 | args.conv_mode = args.conv_mode + '_mmtag' 94 | print(f'It seems that this is a plain model, but it is not using a mmtag prompt, auto switching to {args.conv_mode}.') 95 | 96 | data_loader = create_data_loader(questions, args.image_folder, tokenizer, image_processor, model.config) 97 | 98 | for (input_ids, image_tensor, image_sizes), line in tqdm(zip(data_loader, questions), total=len(questions)): 99 | idx = line["question_id"] 100 | cur_prompt = line["text"] 101 | 102 | input_ids = input_ids.to(device='cuda', non_blocking=True) 103 | 104 | with torch.inference_mode(): 105 | output_ids = model.generate( 106 | input_ids, 107 | images=image_tensor.to(dtype=torch.float16, device='cuda', non_blocking=True), 108 | image_sizes=image_sizes, 109 | do_sample=True if args.temperature > 0 else False, 110 | temperature=args.temperature, 111 | top_p=args.top_p, 112 | num_beams=args.num_beams, 113 | max_new_tokens=args.max_new_tokens, 114 | use_cache=True) 115 | 116 | outputs = tokenizer.batch_decode(output_ids, skip_special_tokens=True)[0].strip() 117 | 118 | ans_id = shortuuid.uuid() 119 | ans_file.write(json.dumps({"question_id": idx, 120 | "prompt": cur_prompt, 121 | "text": outputs, 122 | "answer_id": ans_id, 123 | "model_id": model_name, 124 | "metadata": {}}) + "\n") 125 | # ans_file.flush() 126 | ans_file.close() 127 | 128 | if __name__ == "__main__": 129 | parser = argparse.ArgumentParser() 130 | parser.add_argument("--model-path", type=str, 
default="facebook/opt-350m") 131 | parser.add_argument("--model-base", type=str, default=None) 132 | parser.add_argument("--image-folder", type=str, default="") 133 | parser.add_argument("--question-file", type=str, default="tables/question.jsonl") 134 | parser.add_argument("--answers-file", type=str, default="answer.jsonl") 135 | parser.add_argument("--conv-mode", type=str, default="llava_v1") 136 | parser.add_argument("--num-chunks", type=int, default=1) 137 | parser.add_argument("--chunk-idx", type=int, default=0) 138 | parser.add_argument("--temperature", type=float, default=0.2) 139 | parser.add_argument("--top_p", type=float, default=None) 140 | parser.add_argument("--num_beams", type=int, default=1) 141 | parser.add_argument("--max_new_tokens", type=int, default=128) 142 | args = parser.parse_args() 143 | 144 | eval_model(args) 145 | -------------------------------------------------------------------------------- /llava/eval/model_vqa_mmbench.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import torch 3 | import os 4 | import json 5 | import pandas as pd 6 | from tqdm import tqdm 7 | import shortuuid 8 | 9 | from llava.constants import IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN 10 | from llava.conversation import conv_templates, SeparatorStyle 11 | from llava.model.builder import load_pretrained_model 12 | from llava.utils import disable_torch_init 13 | from llava.mm_utils import tokenizer_image_token, process_images, load_image_from_base64, get_model_name_from_path 14 | 15 | from PIL import Image 16 | import math 17 | 18 | 19 | all_options = ['A', 'B', 'C', 'D'] 20 | 21 | 22 | def split_list(lst, n): 23 | """Split a list into n (roughly) equal-sized chunks""" 24 | chunk_size = math.ceil(len(lst) / n) # integer division 25 | return [lst[i:i+chunk_size] for i in range(0, len(lst), chunk_size)] 26 | 27 | 28 | def get_chunk(lst, n, k): 29 | chunks = split_list(lst, n) 30 | return chunks[k] 31 | 32 | 33 | def is_none(value): 34 | if value is None: 35 | return True 36 | if type(value) is float and math.isnan(value): 37 | return True 38 | if type(value) is str and value.lower() == 'nan': 39 | return True 40 | if type(value) is str and value.lower() == 'none': 41 | return True 42 | return False 43 | 44 | def get_options(row, options): 45 | parsed_options = [] 46 | for option in options: 47 | option_value = row[option] 48 | if is_none(option_value): 49 | break 50 | parsed_options.append(option_value) 51 | return parsed_options 52 | 53 | 54 | def eval_model(args): 55 | # Model 56 | disable_torch_init() 57 | model_path = os.path.expanduser(args.model_path) 58 | model_name = get_model_name_from_path(model_path) 59 | tokenizer, model, image_processor, context_len = load_pretrained_model(model_path, args.model_base, model_name) 60 | 61 | questions = pd.read_table(os.path.expanduser(args.question_file)) 62 | questions = get_chunk(questions, args.num_chunks, args.chunk_idx) 63 | answers_file = os.path.expanduser(args.answers_file) 64 | os.makedirs(os.path.dirname(answers_file), exist_ok=True) 65 | ans_file = open(answers_file, "w") 66 | 67 | if 'plain' in model_name and 'finetune' not in model_name.lower() and 'mmtag' not in args.conv_mode: 68 | args.conv_mode = args.conv_mode + '_mmtag' 69 | print(f'It seems that this is a plain model, but it is not using a mmtag prompt, auto switching to {args.conv_mode}.') 70 | 71 | for index, row in tqdm(questions.iterrows(), total=len(questions)): 72 | 
options = get_options(row, all_options) 73 | cur_option_char = all_options[:len(options)] 74 | 75 | if args.all_rounds: 76 | num_rounds = len(options) 77 | else: 78 | num_rounds = 1 79 | 80 | for round_idx in range(num_rounds): 81 | idx = row['index'] 82 | question = row['question'] 83 | hint = row['hint'] 84 | image = load_image_from_base64(row['image']) 85 | if not is_none(hint): 86 | question = hint + '\n' + question 87 | for option_char, option in zip(all_options[:len(options)], options): 88 | question = question + '\n' + option_char + '. ' + option 89 | qs = cur_prompt = question 90 | if model.config.mm_use_im_start_end: 91 | qs = DEFAULT_IM_START_TOKEN + DEFAULT_IMAGE_TOKEN + DEFAULT_IM_END_TOKEN + '\n' + qs 92 | else: 93 | qs = DEFAULT_IMAGE_TOKEN + '\n' + qs 94 | 95 | if args.single_pred_prompt: 96 | if args.lang == 'cn': 97 | qs = qs + '\n' + "请直接回答选项字母。" 98 | else: 99 | qs = qs + '\n' + "Answer with the option's letter from the given choices directly." 100 | 101 | conv = conv_templates[args.conv_mode].copy() 102 | conv.append_message(conv.roles[0], qs) 103 | conv.append_message(conv.roles[1], None) 104 | prompt = conv.get_prompt() 105 | 106 | input_ids = tokenizer_image_token(prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors='pt').unsqueeze(0).cuda() 107 | 108 | image_tensor = process_images([image], image_processor, model.config)[0] 109 | 110 | with torch.inference_mode(): 111 | output_ids = model.generate( 112 | input_ids, 113 | images=image_tensor.unsqueeze(0).half().cuda(), 114 | image_sizes=[image.size], 115 | do_sample=True if args.temperature > 0 else False, 116 | temperature=args.temperature, 117 | top_p=args.top_p, 118 | num_beams=args.num_beams, 119 | # no_repeat_ngram_size=3, 120 | max_new_tokens=1024, 121 | use_cache=True) 122 | 123 | outputs = tokenizer.batch_decode(output_ids, skip_special_tokens=True)[0].strip() 124 | 125 | ans_id = shortuuid.uuid() 126 | ans_file.write(json.dumps({"question_id": idx, 127 | "round_id": round_idx, 128 | "prompt": cur_prompt, 129 | "text": outputs, 130 | "options": options, 131 | "option_char": cur_option_char, 132 | "answer_id": ans_id, 133 | "model_id": model_name, 134 | "metadata": {}}) + "\n") 135 | ans_file.flush() 136 | 137 | # rotate options 138 | options = options[1:] + options[:1] 139 | cur_option_char = cur_option_char[1:] + cur_option_char[:1] 140 | ans_file.close() 141 | 142 | if __name__ == "__main__": 143 | parser = argparse.ArgumentParser() 144 | parser.add_argument("--model-path", type=str, default="facebook/opt-350m") 145 | parser.add_argument("--model-base", type=str, default=None) 146 | parser.add_argument("--image-folder", type=str, default="") 147 | parser.add_argument("--question-file", type=str, default="tables/question.jsonl") 148 | parser.add_argument("--answers-file", type=str, default="answer.jsonl") 149 | parser.add_argument("--conv-mode", type=str, default="llava_v1") 150 | parser.add_argument("--num-chunks", type=int, default=1) 151 | parser.add_argument("--chunk-idx", type=int, default=0) 152 | parser.add_argument("--temperature", type=float, default=0.2) 153 | parser.add_argument("--top_p", type=float, default=None) 154 | parser.add_argument("--num_beams", type=int, default=1) 155 | parser.add_argument("--all-rounds", action="store_true") 156 | parser.add_argument("--single-pred-prompt", action="store_true") 157 | parser.add_argument("--lang", type=str, default="en") 158 | args = parser.parse_args() 159 | 160 | eval_model(args) 161 | 
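# A minimal usage sketch for the MMBench evaluation script above, assuming the repository is
# installed as the `llava` package; the model path and file names below are hypothetical
# placeholders, and every flag used comes from the argparse block defined in this script.
#
#   python -m llava.eval.model_vqa_mmbench \
#       --model-path liuhaotian/llava-v1.5-13b \
#       --question-file mmbench_dev.tsv \
#       --answers-file answers/mmbench_dev.jsonl \
#       --single-pred-prompt \
#       --all-rounds \
#       --temperature 0
#
# With --all-rounds, the loop re-asks each question once per option, rotating the option texts
# by one position each round (options[1:] + options[:1]) and rotating cur_option_char in step,
# so each recorded option text stays paired with its original letter in the output JSONL,
# presumably so downstream scoring can map answers back to the unrotated ordering.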
-------------------------------------------------------------------------------- /llava/eval/model_vqa_science.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import torch 3 | import os 4 | import json 5 | from tqdm import tqdm 6 | import shortuuid 7 | 8 | from llava.constants import IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN 9 | from llava.conversation import conv_templates, SeparatorStyle 10 | from llava.model.builder import load_pretrained_model 11 | from llava.utils import disable_torch_init 12 | from llava.mm_utils import tokenizer_image_token, process_images, get_model_name_from_path 13 | 14 | from PIL import Image 15 | import math 16 | 17 | 18 | def split_list(lst, n): 19 | """Split a list into n (roughly) equal-sized chunks""" 20 | chunk_size = math.ceil(len(lst) / n) # integer division 21 | return [lst[i:i+chunk_size] for i in range(0, len(lst), chunk_size)] 22 | 23 | 24 | def get_chunk(lst, n, k): 25 | chunks = split_list(lst, n) 26 | return chunks[k] 27 | 28 | 29 | def eval_model(args): 30 | # Model 31 | disable_torch_init() 32 | model_path = os.path.expanduser(args.model_path) 33 | model_name = get_model_name_from_path(model_path) 34 | tokenizer, model, image_processor, context_len = load_pretrained_model(model_path, args.model_base, model_name) 35 | 36 | questions = json.load(open(os.path.expanduser(args.question_file), "r")) 37 | questions = get_chunk(questions, args.num_chunks, args.chunk_idx) 38 | answers_file = os.path.expanduser(args.answers_file) 39 | os.makedirs(os.path.dirname(answers_file), exist_ok=True) 40 | ans_file = open(answers_file, "w") 41 | for i, line in enumerate(tqdm(questions)): 42 | idx = line["id"] 43 | question = line['conversations'][0] 44 | qs = question['value'].replace('', '').strip() 45 | cur_prompt = qs 46 | 47 | if 'image' in line: 48 | image_file = line["image"] 49 | image = Image.open(os.path.join(args.image_folder, image_file)) 50 | image_tensor = process_images([image], image_processor, model.config)[0] 51 | images = image_tensor.unsqueeze(0).half().cuda() 52 | image_sizes = [image.size] 53 | if getattr(model.config, 'mm_use_im_start_end', False): 54 | qs = DEFAULT_IM_START_TOKEN + DEFAULT_IMAGE_TOKEN + DEFAULT_IM_END_TOKEN + '\n' + qs 55 | else: 56 | qs = DEFAULT_IMAGE_TOKEN + '\n' + qs 57 | cur_prompt = '' + '\n' + cur_prompt 58 | else: 59 | images = None 60 | image_sizes = None 61 | 62 | if args.single_pred_prompt: 63 | qs = qs + '\n' + "Answer with the option's letter from the given choices directly." 64 | cur_prompt = cur_prompt + '\n' + "Answer with the option's letter from the given choices directly." 
65 | 66 | conv = conv_templates[args.conv_mode].copy() 67 | conv.append_message(conv.roles[0], qs) 68 | conv.append_message(conv.roles[1], None) 69 | prompt = conv.get_prompt() 70 | 71 | input_ids = tokenizer_image_token(prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors='pt').unsqueeze(0).cuda() 72 | 73 | with torch.inference_mode(): 74 | output_ids = model.generate( 75 | input_ids, 76 | images=images, 77 | image_sizes=image_sizes, 78 | do_sample=True if args.temperature > 0 else False, 79 | temperature=args.temperature, 80 | max_new_tokens=1024, 81 | use_cache=True, 82 | ) 83 | 84 | outputs = tokenizer.batch_decode(output_ids, skip_special_tokens=True)[0].strip() 85 | 86 | ans_id = shortuuid.uuid() 87 | ans_file.write(json.dumps({"question_id": idx, 88 | "prompt": cur_prompt, 89 | "text": outputs, 90 | "answer_id": ans_id, 91 | "model_id": model_name, 92 | "metadata": {}}) + "\n") 93 | ans_file.flush() 94 | ans_file.close() 95 | 96 | if __name__ == "__main__": 97 | parser = argparse.ArgumentParser() 98 | parser.add_argument("--model-path", type=str, default="facebook/opt-350m") 99 | parser.add_argument("--model-base", type=str, default=None) 100 | parser.add_argument("--image-folder", type=str, default="") 101 | parser.add_argument("--question-file", type=str, default="tables/question.json") 102 | parser.add_argument("--answers-file", type=str, default="answer.jsonl") 103 | parser.add_argument("--conv-mode", type=str, default="llava_v0") 104 | parser.add_argument("--num-chunks", type=int, default=1) 105 | parser.add_argument("--chunk-idx", type=int, default=0) 106 | parser.add_argument("--temperature", type=float, default=0.2) 107 | parser.add_argument("--answer-prompter", action="store_true") 108 | parser.add_argument("--single-pred-prompt", action="store_true") 109 | args = parser.parse_args() 110 | 111 | eval_model(args) 112 | -------------------------------------------------------------------------------- /llava/eval/qa_baseline_gpt35.py: -------------------------------------------------------------------------------- 1 | """Generate answers with GPT-3.5""" 2 | # Note: you need to be using OpenAI Python v0.27.0 for the code below to work 3 | import argparse 4 | import json 5 | import os 6 | import time 7 | import concurrent.futures 8 | 9 | import openai 10 | import tqdm 11 | import shortuuid 12 | 13 | MODEL = 'gpt-3.5-turbo' 14 | MODEL_ID = 'gpt-3.5-turbo:20230327' 15 | 16 | def get_answer(question_id: int, question: str, max_tokens: int): 17 | ans = { 18 | 'answer_id': shortuuid.uuid(), 19 | 'question_id': question_id, 20 | 'model_id': MODEL_ID, 21 | } 22 | for _ in range(3): 23 | try: 24 | response = openai.ChatCompletion.create( 25 | model=MODEL, 26 | messages=[{ 27 | 'role': 'system', 28 | 'content': 'You are a helpful assistant.' 
29 | }, { 30 | 'role': 'user', 31 | 'content': question, 32 | }], 33 | max_tokens=max_tokens, 34 | ) 35 | ans['text'] = response['choices'][0]['message']['content'] 36 | return ans 37 | except Exception as e: 38 | print('[ERROR]', e) 39 | ans['text'] = '#ERROR#' 40 | time.sleep(1) 41 | return ans 42 | 43 | 44 | if __name__ == '__main__': 45 | parser = argparse.ArgumentParser(description='ChatGPT answer generation.') 46 | parser.add_argument('-q', '--question') 47 | parser.add_argument('-o', '--output') 48 | parser.add_argument('--max-tokens', type=int, default=1024, help='maximum number of tokens produced in the output') 49 | args = parser.parse_args() 50 | 51 | questions_dict = {} 52 | with open(os.path.expanduser(args.question)) as f: 53 | for line in f: 54 | if not line: 55 | continue 56 | q = json.loads(line) 57 | questions_dict[q['question_id']] = q['text'] 58 | 59 | answers = [] 60 | 61 | with concurrent.futures.ThreadPoolExecutor(max_workers=32) as executor: 62 | futures = [] 63 | for qid, question in questions_dict.items(): 64 | future = executor.submit(get_answer, qid, question, args.max_tokens) 65 | futures.append(future) 66 | 67 | for future in tqdm.tqdm(concurrent.futures.as_completed(futures), total=len(futures)): 68 | answers.append(future.result()) 69 | 70 | answers.sort(key=lambda x: x['question_id']) 71 | 72 | with open(os.path.expanduser(args.output), 'w') as f: 73 | table = [json.dumps(ans) for ans in answers] 74 | f.write('\n'.join(table)) 75 | -------------------------------------------------------------------------------- /llava/eval/run_llava.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import torch 3 | 4 | from llava.constants import ( 5 | IMAGE_TOKEN_INDEX, 6 | DEFAULT_IMAGE_TOKEN, 7 | DEFAULT_IM_START_TOKEN, 8 | DEFAULT_IM_END_TOKEN, 9 | IMAGE_PLACEHOLDER, 10 | ) 11 | from llava.conversation import conv_templates, SeparatorStyle 12 | from llava.model.builder import load_pretrained_model 13 | from llava.utils import disable_torch_init 14 | from llava.mm_utils import ( 15 | process_images, 16 | tokenizer_image_token, 17 | get_model_name_from_path, 18 | ) 19 | 20 | from PIL import Image 21 | 22 | import requests 23 | from PIL import Image 24 | from io import BytesIO 25 | import re 26 | 27 | 28 | def image_parser(args): 29 | out = args.image_file.split(args.sep) 30 | return out 31 | 32 | 33 | def load_image(image_file): 34 | if image_file.startswith("http") or image_file.startswith("https"): 35 | response = requests.get(image_file) 36 | image = Image.open(BytesIO(response.content)).convert("RGB") 37 | else: 38 | image = Image.open(image_file).convert("RGB") 39 | return image 40 | 41 | 42 | def load_images(image_files): 43 | out = [] 44 | for image_file in image_files: 45 | image = load_image(image_file) 46 | out.append(image) 47 | return out 48 | 49 | 50 | def eval_model(args): 51 | # Model 52 | disable_torch_init() 53 | 54 | model_name = get_model_name_from_path(args.model_path) 55 | tokenizer, model, image_processor, context_len = load_pretrained_model( 56 | args.model_path, args.model_base, model_name 57 | ) 58 | 59 | qs = args.query 60 | image_token_se = DEFAULT_IM_START_TOKEN + DEFAULT_IMAGE_TOKEN + DEFAULT_IM_END_TOKEN 61 | if IMAGE_PLACEHOLDER in qs: 62 | if model.config.mm_use_im_start_end: 63 | qs = re.sub(IMAGE_PLACEHOLDER, image_token_se, qs) 64 | else: 65 | qs = re.sub(IMAGE_PLACEHOLDER, DEFAULT_IMAGE_TOKEN, qs) 66 | else: 67 | if model.config.mm_use_im_start_end: 68 | qs = image_token_se + 
"\n" + qs 69 | else: 70 | qs = DEFAULT_IMAGE_TOKEN + "\n" + qs 71 | 72 | if "llama-2" in model_name.lower(): 73 | conv_mode = "llava_llama_2" 74 | elif "mistral" in model_name.lower(): 75 | conv_mode = "mistral_instruct" 76 | elif "v1.6-34b" in model_name.lower(): 77 | conv_mode = "chatml_direct" 78 | elif "v1" in model_name.lower(): 79 | conv_mode = "llava_v1" 80 | elif "mpt" in model_name.lower(): 81 | conv_mode = "mpt" 82 | else: 83 | conv_mode = "llava_v0" 84 | 85 | if args.conv_mode is not None and conv_mode != args.conv_mode: 86 | print( 87 | "[WARNING] the auto inferred conversation mode is {}, while `--conv-mode` is {}, using {}".format( 88 | conv_mode, args.conv_mode, args.conv_mode 89 | ) 90 | ) 91 | else: 92 | args.conv_mode = conv_mode 93 | 94 | conv = conv_templates[args.conv_mode].copy() 95 | conv.append_message(conv.roles[0], qs) 96 | conv.append_message(conv.roles[1], None) 97 | prompt = conv.get_prompt() 98 | 99 | image_files = image_parser(args) 100 | images = load_images(image_files) 101 | image_sizes = [x.size for x in images] 102 | images_tensor = process_images( 103 | images, 104 | image_processor, 105 | model.config 106 | ).to(model.device, dtype=torch.float16) 107 | 108 | input_ids = ( 109 | tokenizer_image_token(prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors="pt") 110 | .unsqueeze(0) 111 | .cuda() 112 | ) 113 | 114 | with torch.inference_mode(): 115 | output_ids = model.generate( 116 | input_ids, 117 | images=images_tensor, 118 | image_sizes=image_sizes, 119 | do_sample=True if args.temperature > 0 else False, 120 | temperature=args.temperature, 121 | top_p=args.top_p, 122 | num_beams=args.num_beams, 123 | max_new_tokens=args.max_new_tokens, 124 | use_cache=True, 125 | ) 126 | 127 | outputs = tokenizer.batch_decode(output_ids, skip_special_tokens=True)[0].strip() 128 | print(outputs) 129 | 130 | 131 | if __name__ == "__main__": 132 | parser = argparse.ArgumentParser() 133 | parser.add_argument("--model-path", type=str, default="facebook/opt-350m") 134 | parser.add_argument("--model-base", type=str, default=None) 135 | parser.add_argument("--image-file", type=str, required=True) 136 | parser.add_argument("--query", type=str, required=True) 137 | parser.add_argument("--conv-mode", type=str, default=None) 138 | parser.add_argument("--sep", type=str, default=",") 139 | parser.add_argument("--temperature", type=float, default=0.2) 140 | parser.add_argument("--top_p", type=float, default=None) 141 | parser.add_argument("--num_beams", type=int, default=1) 142 | parser.add_argument("--max_new_tokens", type=int, default=512) 143 | args = parser.parse_args() 144 | 145 | eval_model(args) 146 | -------------------------------------------------------------------------------- /llava/eval/summarize_gpt_review.py: -------------------------------------------------------------------------------- 1 | import json 2 | import os 3 | from collections import defaultdict 4 | 5 | import numpy as np 6 | 7 | import argparse 8 | 9 | def parse_args(): 10 | parser = argparse.ArgumentParser(description='ChatGPT-based QA evaluation.') 11 | parser.add_argument('-d', '--dir', default=None) 12 | parser.add_argument('-v', '--version', default=None) 13 | parser.add_argument('-s', '--select', nargs='*', default=None) 14 | parser.add_argument('-f', '--files', nargs='*', default=[]) 15 | parser.add_argument('-i', '--ignore', nargs='*', default=[]) 16 | return parser.parse_args() 17 | 18 | 19 | if __name__ == '__main__': 20 | args = parse_args() 21 | 22 | if args.ignore is not None: 23 | args.ignore 
= [int(x) for x in args.ignore] 24 | 25 | if len(args.files) > 0: 26 | review_files = args.files 27 | else: 28 | review_files = [x for x in os.listdir(args.dir) if x.endswith('.jsonl') and (x.startswith('gpt4_text') or x.startswith('reviews_') or x.startswith('review_') or 'review' in args.dir)] 29 | 30 | for review_file in sorted(review_files): 31 | config = os.path.basename(review_file).replace('gpt4_text_', '').replace('.jsonl', '') 32 | if args.select is not None and any(x not in config for x in args.select): 33 | continue 34 | if '0613' in config: 35 | version = '0613' 36 | else: 37 | version = '0314' 38 | if args.version is not None and args.version != version: 39 | continue 40 | scores = defaultdict(list) 41 | print(config) 42 | with open(os.path.join(args.dir, review_file) if args.dir is not None else review_file) as f: 43 | for review_str in f: 44 | review = json.loads(review_str) 45 | if review['question_id'] in args.ignore: 46 | continue 47 | if 'category' in review: 48 | scores[review['category']].append(review['tuple']) 49 | scores['all'].append(review['tuple']) 50 | else: 51 | if 'tuple' in review: 52 | scores['all'].append(review['tuple']) 53 | else: 54 | scores['all'].append(review['score']) 55 | for k, v in sorted(scores.items()): 56 | stats = np.asarray(v).mean(0).tolist() 57 | stats = [round(x, 3) for x in stats] 58 | # print(k, stats, round(stats[1]/stats[0]*100, 1)) 59 | print(k, round(stats[1]/stats[0]*100, 1), round(stats[0] * 10, 1), round(stats[1] * 10, 1)) 60 | print('=================================') 61 | -------------------------------------------------------------------------------- /llava/eval/table/model.jsonl: -------------------------------------------------------------------------------- 1 | {"model_id": "vicuna-13b:20230322-clean-lang", "model_name": "vicuna-13b", "model_version": "20230322-clean-lang", "model_metadata": "vicuna-13b-20230322-clean-lang"} 2 | {"model_id": "alpaca-13b:v1", "model_name": "alpaca-13b", "model_version": "v1", "model_metadata": "alpaca-13b"} 3 | {"model_id": "llama-13b:v1", "model_name": "llama-13b", "model_version": "v1", "model_metadata": "hf-llama-13b"} 4 | {"model_id": "bard:20230327", "model_name": "bard", "model_version": "20230327", "model_metadata": "Google Bard 20230327"} 5 | {"model_id": "gpt-3.5-turbo:20230327", "model_name": "gpt-3.5-turbo", "model_version": "20230327", "model_metadata": "OpenAI ChatGPT gpt-3.5-turbo Chat Completion"} 6 | -------------------------------------------------------------------------------- /llava/eval/table/prompt.jsonl: -------------------------------------------------------------------------------- 1 | {"prompt_id": 1, "system_prompt": "You are a helpful and precise assistant for checking the quality of the answer.", "prompt_template": "[Question]\n{question}\n\n[Assistant 1]\n{answer_1}\n\n[End of Assistant 1]\n\n[Assistant 2]\n{answer_2}\n\n[End of Assistant 2]\n\n[System]\n{prompt}\n\n", "defaults": {"prompt": "We would like to request your feedback on the performance of two AI assistants in response to the user question displayed above.\nPlease rate the helpfulness, relevance, accuracy, level of details of their responses. Each assistant receives an overall score on a scale of 1 to 10, where a higher score indicates better overall performance.\nPlease first output a single line containing only two values indicating the scores for Assistant 1 and 2, respectively. 
The two scores are separated by a space.\nIn the subsequent line, please provide a comprehensive explanation of your evaluation, avoiding any potential bias and ensuring that the order in which the responses were presented does not affect your judgment."}, "description": "Prompt for general questions"} 2 | {"prompt_id": 2, "system_prompt": "You are a helpful and precise assistant for checking the quality of the answer.", "prompt_template": "[Question]\n{question}\n\n[Assistant 1]\n{answer_1}\n\n[End of Assistant 1]\n\n[Assistant 2]\n{answer_2}\n\n[End of Assistant 2]\n\n[System]\n{prompt}\n\n", "defaults": {"prompt": "Your task is to evaluate the coding abilities of the above two assistants. They have been asked to implement a program to solve a given problem. Please review their code submissions, paying close attention to their problem-solving approach, code structure, readability, and the inclusion of helpful comments.\n\nPlease ensure that the assistants' submissions:\n\n1. Correctly implement the given problem statement.\n2. Contain accurate and efficient code.\n3. Include clear and concise comments that explain the code's logic and functionality.\n4. Adhere to proper coding standards and best practices.\n\nOnce you have carefully reviewed both submissions, provide detailed feedback on their strengths and weaknesses, along with any suggestions for improvement. You should first output a single line containing two scores on the scale of 1-10 (1: no code/no sense; 10: perfect) for Assistant 1 and 2, respectively. Then give extra comments starting from the next line."}, "description": "Prompt for coding questions"} 3 | {"prompt_id": 3, "system_prompt": "You are a helpful and precise assistant for checking the quality of the answer.", "prompt_template": "[Question]\n{question}\n\n[Assistant 1]\n{answer_1}\n\n[End of Assistant 1]\n\n[Assistant 2]\n{answer_2}\n\n[End of Assistant 2]\n\n[System]\n{prompt}\n\n", "defaults": {"prompt": "We would like to request your feedback on the mathematical proficiency of two AI assistants regarding the given user question.\nFirstly, please solve the problem independently, without referring to the answers provided by Assistant 1 and Assistant 2.\nAfterward, please examine the problem-solving process of Assistant 1 and Assistant 2 step-by-step to ensure their correctness, identifying any incorrect steps if present. Your evaluation should take into account not only the answer but also the problem-solving steps.\nFinally, please output a Python tuple containing two numerical scores for Assistant 1 and Assistant 2, ranging from 1 to 10, respectively. If applicable, explain the reasons for any variations in their scores and determine which assistant performed better."}, "description": "Prompt for math questions"} 4 | {"prompt_id": 4, "system_prompt": "You are a helpful and precise assistant for checking the quality of the answer.", "prompt_template": "[Visual Context]\n{context}\n[Question]\n{question}\n\n[Assistant 1]\n{answer_1}\n\n[End of Assistant 1]\n\n[Assistant 2]\n{answer_2}\n\n[End of Assistant 2]\n\n[System]\n{prompt}\n\n", "defaults": {"prompt": "We would like to request your feedback on the performance of two AI assistants in response to the user question displayed above. The user asks the question on observing an image. For your reference, the visual content in the image is represented with five descriptive sentences describing the same image and the bounding box coordinates of each object in the scene. 
These coordinates are in the form of bounding boxes, represented as (x1, y1, x2, y2) with floating numbers ranging from 0 to 1. These values correspond to the top left x, top left y, bottom right x, and bottom right y. \nPlease rate the helpfulness, relevance, accuracy, level of details of their responses. Each assistant receives an overall score on a scale of 1 to 10, where a higher score indicates better overall performance.\nPlease first output a single line containing only two values indicating the scores for Assistant 1 and 2, respectively. The two scores are separated by a space.\nIn the subsequent line, please provide a comprehensive explanation of your evaluation, avoiding any potential bias and ensuring that the order in which the responses were presented does not affect your judgment."}, "description": "Prompt for visual questions"} 5 | -------------------------------------------------------------------------------- /llava/eval/table/reviewer.jsonl: -------------------------------------------------------------------------------- 1 | {"reviewer_id": "gpt-4-0328-default", "prompt_id": 1, "metadata": {"temperature": 0.2, "max_tokens": 1024}, "description": "GPT-4 for general questions"} 2 | {"reviewer_id": "gpt-4-0328-coding", "prompt_id": 2, "metadata": {"temperature": 0.2, "max_tokens": 1024}, "description": "GPT-4 for coding questions"} 3 | {"reviewer_id": "gpt-4-0328-math", "prompt_id": 3, "metadata": {"temperature": 0.2, "max_tokens": 1024}, "description": "GPT-4 for math questions"} 4 | {"reviewer_id": "gpt-4-0417-visual", "prompt_id": 4, "metadata": {"temperature": 0.2, "max_tokens": 1024}, "description": "GPT-4 for math questions"} 5 | -------------------------------------------------------------------------------- /llava/eval/webpage/figures/alpaca.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/biansy000/ChatGarment/6ea4192383067938f1d1c63dc0cc4134ccc91b31/llava/eval/webpage/figures/alpaca.png -------------------------------------------------------------------------------- /llava/eval/webpage/figures/bard.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/biansy000/ChatGarment/6ea4192383067938f1d1c63dc0cc4134ccc91b31/llava/eval/webpage/figures/bard.jpg -------------------------------------------------------------------------------- /llava/eval/webpage/figures/chatgpt.svg: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /llava/eval/webpage/figures/llama.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/biansy000/ChatGarment/6ea4192383067938f1d1c63dc0cc4134ccc91b31/llava/eval/webpage/figures/llama.jpg -------------------------------------------------------------------------------- /llava/eval/webpage/figures/swords_FILL0_wght300_GRAD0_opsz48.svg: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /llava/eval/webpage/figures/vicuna.jpeg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/biansy000/ChatGarment/6ea4192383067938f1d1c63dc0cc4134ccc91b31/llava/eval/webpage/figures/vicuna.jpeg -------------------------------------------------------------------------------- 
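Note on the eval tables above: each entry in reviewer.jsonl points at a prompt_id in prompt.jsonl, and the corresponding prompt_template is filled with one question, the two assistants' answers, and the rating instructions stored under defaults["prompt"]. The following is a minimal sketch of that assembly, assuming the repository-relative paths shown above; the function names and the example question/answers are illustrative placeholders rather than the repo's actual eval scripts, and no API request is made — the sketch only builds the chat messages and reads the sampling settings from the reviewer metadata.

import json

# Paths as listed in the dump above, relative to the repository root (assumption).
PROMPT_FILE = "llava/eval/table/prompt.jsonl"
REVIEWER_FILE = "llava/eval/table/reviewer.jsonl"

def load_jsonl(path):
    # One JSON object per line, as in prompt.jsonl / reviewer.jsonl.
    with open(path) as f:
        return [json.loads(line) for line in f]

def build_review_messages(reviewer, prompts_by_id, question, answer_1, answer_2):
    """Fill a reviewer's prompt template with a question and two candidate answers."""
    prompt_cfg = prompts_by_id[reviewer["prompt_id"]]
    content = prompt_cfg["prompt_template"].format(
        question=question,
        answer_1=answer_1,
        answer_2=answer_2,
        **prompt_cfg["defaults"],  # supplies the {prompt} rating instructions
    )
    return [
        {"role": "system", "content": prompt_cfg["system_prompt"]},
        {"role": "user", "content": content},
    ]

if __name__ == "__main__":
    prompts = {p["prompt_id"]: p for p in load_jsonl(PROMPT_FILE)}
    reviewers = load_jsonl(REVIEWER_FILE)
    general_reviewer = reviewers[0]  # "gpt-4-0328-default", prompt_id 1 (general questions)
    messages = build_review_messages(
        general_reviewer, prompts,
        question="What is the capital of France?",          # illustrative only
        answer_1="Paris.",
        answer_2="The capital of France is Paris, on the Seine.",
    )
    # reviewer metadata carries the sampling settings (temperature, max_tokens)
    print(general_reviewer["metadata"])
    print(messages[1]["content"][:200])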
/llava/eval/webpage/styles.css: -------------------------------------------------------------------------------- 1 | body { 2 | font-family: 'Segoe UI', Tahoma, Geneva, Verdana, sans-serif; 3 | background-color: #f8f9fa; 4 | } 5 | 6 | .navbar-dark .navbar-nav .nav-link { 7 | color: #f1cf68; 8 | font-size: 1.1rem; 9 | padding: 0.5rem 0.6rem; 10 | } 11 | 12 | .card-header { 13 | font-weight: bold; 14 | } 15 | 16 | .card { 17 | box-shadow: 0 4px 8px rgba(0, 0, 0, 0.1); 18 | transition: 0.3s; 19 | } 20 | 21 | .card:hover { 22 | box-shadow: 0 8px 16px rgba(0, 0, 0, 0.2); 23 | } 24 | 25 | button { 26 | transition: background-color 0.3s; 27 | } 28 | 29 | button:hover { 30 | background-color: #007bff; 31 | } 32 | 33 | @media (max-width: 767px) { 34 | .form-row .form-group { 35 | margin-bottom: 10px; 36 | } 37 | } 38 | 39 | /* Extra styles */ 40 | 41 | .expandable-card .card-text-container { 42 | max-height: 200px; 43 | overflow-y: hidden; 44 | position: relative; 45 | } 46 | 47 | .expandable-card.expanded .card-text-container { 48 | max-height: none; 49 | } 50 | 51 | .expand-btn { 52 | position: relative; 53 | display: none; 54 | background-color: rgba(255, 255, 255, 0.8); 55 | color: #510c75; 56 | border-color: transparent; 57 | } 58 | 59 | .expand-btn:hover { 60 | background-color: rgba(200, 200, 200, 0.8); 61 | text-decoration: none; 62 | border-color: transparent; 63 | color: #510c75; 64 | } 65 | 66 | .expand-btn:focus { 67 | outline: none; 68 | text-decoration: none; 69 | } 70 | 71 | .expandable-card:not(.expanded) .card-text-container:after { 72 | content: ""; 73 | position: absolute; 74 | bottom: 0; 75 | left: 0; 76 | width: 100%; 77 | height: 90px; 78 | background: linear-gradient(rgba(255, 255, 255, 0.2), rgba(255, 255, 255, 1)); 79 | } 80 | 81 | .expandable-card:not(.expanded) .expand-btn { 82 | margin-top: -40px; 83 | } 84 | 85 | .card-body { 86 | padding-bottom: 5px; 87 | } 88 | 89 | .vertical-flex-layout { 90 | justify-content: center; 91 | align-items: center; 92 | height: 100%; 93 | display: flex; 94 | flex-direction: column; 95 | gap: 5px; 96 | } 97 | 98 | .figure-img { 99 | max-width: 100%; 100 | height: auto; 101 | } 102 | 103 | .adjustable-font-size { 104 | font-size: calc(0.5rem + 2vw); 105 | } 106 | -------------------------------------------------------------------------------- /llava/garmentcodeRC_utils.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import os 3 | import yaml 4 | from pathlib import Path 5 | from collections import OrderedDict 6 | import pickle as pkl 7 | import argparse 8 | import json 9 | import re 10 | import copy 11 | import torch 12 | import numpy as np 13 | 14 | wb_config_name = 'waistband' 15 | skirt_configs = { 16 | 'SkirtCircle': 'flare-skirt', 17 | 'AsymmSkirtCircle': 'flare-skirt', 18 | 'GodetSkirt': 'godet-skirt', 19 | 'Pants': 'pants', 20 | 'Skirt2': 'skirt', 21 | 'SkirtManyPanels': 'flare-skirt', 22 | 'PencilSkirt': 'pencil-skirt', 23 | 'SkirtLevels': 'levels-skirt', 24 | } 25 | all_skirt_configs = ['skirt', 'flare-skirt', 'godet-skirt', 'pencil-skirt', 'levels-skirt', 'pants'] 26 | 27 | 28 | 29 | def ordered(d, desired_key_order): 30 | return OrderedDict([(key, d[key]) for key in desired_key_order]) 31 | 32 | 33 | def recursive_simplify_params(cfg, is_used=True, unused_configs=[], parent_path='design', device='cpu'): 34 | # change float to 4 decimal places 35 | if cfg is None: 36 | print(parent_path) 37 | 38 | cfg_new = {} 39 | if ('type' not in cfg) or not isinstance(cfg['type'], str): 
40 | 41 | if 'enable_asym' in cfg: ############################################ 42 | enable_asym = bool(cfg['enable_asym']['v']) 43 | if not enable_asym: 44 | cfg_new['enable_asym'] = cfg['enable_asym']['v'] 45 | return cfg_new 46 | 47 | if parent_path == 'design.sleeve.cuff' and cfg['type']['v'] is None: 48 | return {'type': None} 49 | 50 | if parent_path == 'design.left.sleeve.cuff' and cfg['type']['v'] is None: 51 | return {'type': None} 52 | 53 | if parent_path == 'design.pants.cuff' and cfg['type']['v'] is None: 54 | return {'type': None} 55 | 56 | # if parent_path == 'design.sleeve' and cfg['sleeveless']['v']: 57 | # return {'type': None} 58 | 59 | # if parent_path == 'design.sleeve' 60 | 61 | for subpattern_n, subpattern_cfg in cfg.items(): 62 | if (subpattern_n in unused_configs) and ('meta' in cfg): 63 | continue 64 | else: 65 | subconfig = recursive_simplify_params(subpattern_cfg, is_used=is_used, parent_path=parent_path + '.' + subpattern_n, device=device) 66 | 67 | cfg_new[subpattern_n] = subconfig 68 | 69 | else: 70 | type_now = cfg['type'] 71 | if type_now == 'float': 72 | lower_bd = float(cfg['range'][0]) 73 | upper_bd = float(cfg['range'][1]) 74 | 75 | float_val = cfg['v'] 76 | float_val_normed = (float_val - lower_bd) / (upper_bd - lower_bd) 77 | cfg_new = torch.tensor([float_val_normed]).float().to(device) 78 | 79 | else: 80 | cfg_new = cfg['v'] 81 | 82 | return cfg_new 83 | 84 | 85 | def GarmentCodeRC_simplify_params(new_config, device='cpu'): 86 | if 'design' in new_config: 87 | new_config = new_config['design'] 88 | 89 | ################ get unused_configs 90 | unused_configs = [] 91 | ub_garment = new_config['meta']['upper']['v'] 92 | if ub_garment is None: 93 | unused_configs += ['shirt', 'collar', 'sleeve', 'left'] 94 | 95 | wb_garment = new_config['meta']['wb']['v'] 96 | if not wb_garment: 97 | unused_configs.append(wb_config_name) 98 | 99 | lower_garment = new_config['meta']['bottom']['v'] 100 | assert lower_garment != 'null', (lower_garment) 101 | if lower_garment is None: 102 | unused_configs += all_skirt_configs 103 | else: 104 | unused_configs += copy.deepcopy(all_skirt_configs) 105 | unused_configs.remove(skirt_configs[lower_garment]) 106 | 107 | if 'base' in new_config[skirt_configs[lower_garment]]: 108 | base_garment = new_config[skirt_configs[lower_garment]]['base']['v'] 109 | unused_configs.remove(skirt_configs[base_garment]) 110 | 111 | new_config = recursive_simplify_params(new_config, is_used=True, unused_configs=unused_configs, device=device) 112 | 113 | return new_config 114 | 115 | 116 | def update_design_ranges(): 117 | return -------------------------------------------------------------------------------- /llava/lisa_utils.py: -------------------------------------------------------------------------------- 1 | from enum import Enum 2 | 3 | import numpy as np 4 | import torch 5 | import torch.distributed as dist 6 | 7 | IGNORE_INDEX = -100 8 | IMAGE_TOKEN_INDEX = -200 9 | DEFAULT_IMAGE_TOKEN = "" 10 | DEFAULT_IMAGE_PATCH_TOKEN = "" 11 | DEFAULT_IM_START_TOKEN = "" 12 | DEFAULT_IM_END_TOKEN = "" 13 | 14 | SHORT_QUESTION_LIST = [ 15 | DEFAULT_IMAGE_TOKEN + "\n" + "Can you predict the SMPL pose of the person in this image?", 16 | DEFAULT_IMAGE_TOKEN + "\n" + "There is person in the middle of the image, please output this person's SMPL pose.", 17 | DEFAULT_IMAGE_TOKEN 18 | + "\n" 19 | + "What is the human pose in this image? Please respond with SMPL pose.", 20 | DEFAULT_IMAGE_TOKEN 21 | + "\n" 22 | + "What is the person doing in this image? 
Please output SMPL pose.", 23 | DEFAULT_IMAGE_TOKEN + "\n" + "There is person in the middle of the image, please output this person's SMPL pose.", 24 | ] 25 | 26 | LONG_QUESTION_LIST = [ 27 | DEFAULT_IMAGE_TOKEN + "\n" + "{sent} Please respond with SMPL pose.", 28 | DEFAULT_IMAGE_TOKEN + "\n" + "{sent} Please output SMPL pose.", 29 | ] 30 | 31 | EXPLANATORY_QUESTION_LIST = [ 32 | "Please output SMPL pose and explain the pose.", 33 | "Please output SMPL pose and explain the reason.", 34 | "Please output SMPL pose and give some explanation.", 35 | ] 36 | 37 | ANSWER_LIST = [ 38 | "It is [SEG].", 39 | "Sure, [SEG].", 40 | "Sure, it is [SEG].", 41 | "Sure, the SMPL pose is [SEG].", 42 | "[SEG].", 43 | ] 44 | 45 | 46 | class Summary(Enum): 47 | NONE = 0 48 | AVERAGE = 1 49 | SUM = 2 50 | COUNT = 3 51 | 52 | 53 | class AverageMeter(object): 54 | """Computes and stores the average and current value""" 55 | 56 | def __init__(self, name, fmt=":f", summary_type=Summary.AVERAGE): 57 | self.name = name 58 | self.fmt = fmt 59 | self.summary_type = summary_type 60 | self.reset() 61 | 62 | def reset(self): 63 | self.val = 0 64 | self.avg = 0 65 | self.sum = 0 66 | self.count = 0 67 | 68 | def update(self, val, n=1): 69 | self.val = val 70 | self.sum += val * n 71 | self.count += n 72 | self.avg = self.sum / self.count 73 | 74 | def all_reduce(self): 75 | device = "cuda" if torch.cuda.is_available() else "cpu" 76 | if isinstance(self.sum, np.ndarray): 77 | total = torch.tensor( 78 | self.sum.tolist() 79 | + [ 80 | self.count, 81 | ], 82 | dtype=torch.float32, 83 | device=device, 84 | ) 85 | else: 86 | total = torch.tensor( 87 | [self.sum, self.count], dtype=torch.float32, device=device 88 | ) 89 | 90 | dist.all_reduce(total, dist.ReduceOp.SUM, async_op=False) 91 | if total.shape[0] > 2: 92 | self.sum, self.count = total[:-1].cpu().numpy(), total[-1].cpu().item() 93 | else: 94 | self.sum, self.count = total.tolist() 95 | self.avg = self.sum / (self.count + 1e-5) 96 | 97 | def __str__(self): 98 | fmtstr = "{name} {val" + self.fmt + "} ({avg" + self.fmt + "})" 99 | return fmtstr.format(**self.__dict__) 100 | 101 | def summary(self): 102 | fmtstr = "" 103 | if self.summary_type is Summary.NONE: 104 | fmtstr = "" 105 | elif self.summary_type is Summary.AVERAGE: 106 | fmtstr = "{name} {avg:.3f}" 107 | elif self.summary_type is Summary.SUM: 108 | fmtstr = "{name} {sum:.3f}" 109 | elif self.summary_type is Summary.COUNT: 110 | fmtstr = "{name} {count:.3f}" 111 | else: 112 | raise ValueError("invalid summary type %r" % self.summary_type) 113 | 114 | return fmtstr.format(**self.__dict__) 115 | 116 | 117 | def intersectionAndUnionGPU(output, target, K, ignore_index=255): 118 | # 'K' classes, output and target sizes are N or N * L or N * H * W, each value in range 0 to K - 1. 
119 | assert output.dim() in [1, 2, 3] 120 | assert output.shape == target.shape 121 | output = output.view(-1) 122 | target = target.view(-1) 123 | output[target == ignore_index] = ignore_index 124 | intersection = output[output == target] 125 | area_intersection = torch.histc(intersection, bins=K, min=0, max=K - 1) 126 | area_output = torch.histc(output, bins=K, min=0, max=K - 1) 127 | area_target = torch.histc(target, bins=K, min=0, max=K - 1) 128 | area_union = area_output + area_target - area_intersection 129 | return area_intersection, area_union, area_target 130 | 131 | 132 | class ProgressMeter(object): 133 | def __init__(self, num_batches, meters, prefix=""): 134 | self.batch_fmtstr = self._get_batch_fmtstr(num_batches) 135 | self.meters = meters 136 | self.prefix = prefix 137 | 138 | def display(self, batch): 139 | entries = [self.prefix + self.batch_fmtstr.format(batch)] 140 | entries += [str(meter) for meter in self.meters] 141 | print("\t".join(entries)) 142 | 143 | def display_summary(self): 144 | entries = [" *"] 145 | entries += [meter.summary() for meter in self.meters] 146 | print(" ".join(entries)) 147 | 148 | def _get_batch_fmtstr(self, num_batches): 149 | num_digits = len(str(num_batches // 1)) 150 | fmt = "{:" + str(num_digits) + "d}" 151 | return "[" + fmt + "/" + fmt.format(num_batches) + "]" 152 | 153 | 154 | def dict_to_cuda(input_dict): 155 | for k, v in input_dict.items(): 156 | if isinstance(input_dict[k], torch.Tensor): 157 | input_dict[k] = v.cuda(non_blocking=True) 158 | elif ( 159 | isinstance(input_dict[k], list) 160 | and len(input_dict[k]) > 0 161 | and isinstance(input_dict[k][0], torch.Tensor) 162 | ): 163 | input_dict[k] = [ele.cuda(non_blocking=True) for ele in v] 164 | return input_dict 165 | -------------------------------------------------------------------------------- /llava/model/__init__.py: -------------------------------------------------------------------------------- 1 | try: 2 | from .language_model.llava_llama import LlavaLlamaForCausalLM, LlavaConfig 3 | from .language_model.llava_mpt import LlavaMptForCausalLM, LlavaMptConfig 4 | from .language_model.llava_mistral import LlavaMistralForCausalLM, LlavaMistralConfig 5 | from .language_model.llava_garment_float50 import GarmentGPTFloat50ForCausalLM 6 | except: 7 | pass 8 | -------------------------------------------------------------------------------- /llava/model/apply_delta.py: -------------------------------------------------------------------------------- 1 | """ 2 | Usage: 3 | python3 -m fastchat.model.apply_delta --base ~/model_weights/llama-7b --target ~/model_weights/vicuna-7b --delta lmsys/vicuna-7b-delta 4 | """ 5 | import argparse 6 | 7 | import torch 8 | from tqdm import tqdm 9 | from transformers import AutoTokenizer, AutoModelForCausalLM 10 | from llava import LlavaLlamaForCausalLM 11 | 12 | 13 | def apply_delta(base_model_path, target_model_path, delta_path): 14 | print("Loading base model") 15 | base = AutoModelForCausalLM.from_pretrained( 16 | base_model_path, torch_dtype=torch.float16, low_cpu_mem_usage=True) 17 | 18 | print("Loading delta") 19 | delta = LlavaLlamaForCausalLM.from_pretrained(delta_path, torch_dtype=torch.float16, low_cpu_mem_usage=True) 20 | delta_tokenizer = AutoTokenizer.from_pretrained(delta_path) 21 | 22 | print("Applying delta") 23 | for name, param in tqdm(delta.state_dict().items(), desc="Applying delta"): 24 | if name not in base.state_dict(): 25 | assert name in ['model.mm_projector.weight', 'model.mm_projector.bias'], f'{name} not in base 
model' 26 | continue 27 | if param.data.shape == base.state_dict()[name].shape: 28 | param.data += base.state_dict()[name] 29 | else: 30 | assert name in ['model.embed_tokens.weight', 'lm_head.weight'], \ 31 | f'{name} dimension mismatch: {param.data.shape} vs {base.state_dict()[name].shape}' 32 | bparam = base.state_dict()[name] 33 | param.data[:bparam.shape[0], :bparam.shape[1]] += bparam 34 | 35 | print("Saving target model") 36 | delta.save_pretrained(target_model_path) 37 | delta_tokenizer.save_pretrained(target_model_path) 38 | 39 | 40 | if __name__ == "__main__": 41 | parser = argparse.ArgumentParser() 42 | parser.add_argument("--base-model-path", type=str, required=True) 43 | parser.add_argument("--target-model-path", type=str, required=True) 44 | parser.add_argument("--delta-path", type=str, required=True) 45 | 46 | args = parser.parse_args() 47 | 48 | apply_delta(args.base_model_path, args.target_model_path, args.delta_path) 49 | -------------------------------------------------------------------------------- /llava/model/consolidate.py: -------------------------------------------------------------------------------- 1 | """ 2 | Usage: 3 | python3 -m llava.model.consolidate --src ~/model_weights/llava-7b --dst ~/model_weights/llava-7b_consolidate 4 | """ 5 | import argparse 6 | 7 | import torch 8 | from transformers import AutoTokenizer, AutoModelForCausalLM 9 | from llava.model import * 10 | from llava.model.utils import auto_upgrade 11 | 12 | 13 | def consolidate_ckpt(src_path, dst_path): 14 | print("Loading model") 15 | auto_upgrade(src_path) 16 | src_model = AutoModelForCausalLM.from_pretrained(src_path, torch_dtype=torch.float16, low_cpu_mem_usage=True) 17 | src_tokenizer = AutoTokenizer.from_pretrained(src_path, use_fast=False) 18 | src_model.save_pretrained(dst_path) 19 | src_tokenizer.save_pretrained(dst_path) 20 | 21 | 22 | if __name__ == "__main__": 23 | parser = argparse.ArgumentParser() 24 | parser.add_argument("--src", type=str, required=True) 25 | parser.add_argument("--dst", type=str, required=True) 26 | 27 | args = parser.parse_args() 28 | 29 | consolidate_ckpt(args.src, args.dst) 30 | -------------------------------------------------------------------------------- /llava/model/language_model/llava_llama.py: -------------------------------------------------------------------------------- 1 | # Copyright 2023 Haotian Liu 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | 15 | 16 | from typing import List, Optional, Tuple, Union 17 | 18 | import torch 19 | import torch.nn as nn 20 | 21 | from transformers import AutoConfig, AutoModelForCausalLM, \ 22 | LlamaConfig, LlamaModel, LlamaForCausalLM 23 | 24 | from transformers.modeling_outputs import CausalLMOutputWithPast 25 | from transformers.generation.utils import GenerateOutput 26 | 27 | from ..llava_arch import LlavaMetaModel, LlavaMetaForCausalLM 28 | 29 | 30 | class LlavaConfig(LlamaConfig): 31 | model_type = "llava_llama" 32 | 33 | 34 | class LlavaLlamaModel(LlavaMetaModel, LlamaModel): 35 | config_class = LlavaConfig 36 | 37 | def __init__(self, config: LlamaConfig): 38 | super(LlavaLlamaModel, self).__init__(config) 39 | 40 | 41 | class LlavaLlamaForCausalLM(LlamaForCausalLM, LlavaMetaForCausalLM): 42 | config_class = LlavaConfig 43 | 44 | def __init__(self, config): 45 | super(LlamaForCausalLM, self).__init__(config) 46 | self.model = LlavaLlamaModel(config) 47 | self.pretraining_tp = config.pretraining_tp 48 | self.vocab_size = config.vocab_size 49 | self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) 50 | 51 | # Initialize weights and apply final processing 52 | self.post_init() 53 | 54 | def get_model(self): 55 | return self.model 56 | 57 | def forward( 58 | self, 59 | input_ids: torch.LongTensor = None, 60 | attention_mask: Optional[torch.Tensor] = None, 61 | position_ids: Optional[torch.LongTensor] = None, 62 | past_key_values: Optional[List[torch.FloatTensor]] = None, 63 | inputs_embeds: Optional[torch.FloatTensor] = None, 64 | labels: Optional[torch.LongTensor] = None, 65 | use_cache: Optional[bool] = None, 66 | output_attentions: Optional[bool] = None, 67 | output_hidden_states: Optional[bool] = None, 68 | images: Optional[torch.FloatTensor] = None, 69 | image_sizes: Optional[List[List[int]]] = None, 70 | return_dict: Optional[bool] = None, 71 | ) -> Union[Tuple, CausalLMOutputWithPast]: 72 | 73 | if inputs_embeds is None: 74 | ( 75 | input_ids, 76 | position_ids, 77 | attention_mask, 78 | past_key_values, 79 | inputs_embeds, 80 | labels 81 | ) = self.prepare_inputs_labels_for_multimodal( 82 | input_ids, 83 | position_ids, 84 | attention_mask, 85 | past_key_values, 86 | labels, 87 | images, 88 | image_sizes 89 | ) 90 | 91 | return super().forward( 92 | input_ids=input_ids, 93 | attention_mask=attention_mask, 94 | position_ids=position_ids, 95 | past_key_values=past_key_values, 96 | inputs_embeds=inputs_embeds, 97 | labels=labels, 98 | use_cache=use_cache, 99 | output_attentions=output_attentions, 100 | output_hidden_states=output_hidden_states, 101 | return_dict=return_dict 102 | ) 103 | 104 | @torch.no_grad() 105 | def generate( 106 | self, 107 | inputs: Optional[torch.Tensor] = None, 108 | images: Optional[torch.Tensor] = None, 109 | image_sizes: Optional[torch.Tensor] = None, 110 | **kwargs, 111 | ) -> Union[GenerateOutput, torch.LongTensor]: 112 | position_ids = kwargs.pop("position_ids", None) 113 | attention_mask = kwargs.pop("attention_mask", None) 114 | if "inputs_embeds" in kwargs: 115 | raise NotImplementedError("`inputs_embeds` is not supported") 116 | 117 | if images is not None: 118 | ( 119 | inputs, 120 | position_ids, 121 | attention_mask, 122 | _, 123 | inputs_embeds, 124 | _ 125 | ) = self.prepare_inputs_labels_for_multimodal( 126 | inputs, 127 | position_ids, 128 | attention_mask, 129 | None, 130 | None, 131 | images, 132 | image_sizes=image_sizes 133 | ) 134 | else: 135 | inputs_embeds = self.get_model().embed_tokens(inputs) 136 | 137 | return 
super().generate( 138 | position_ids=position_ids, 139 | attention_mask=attention_mask, 140 | inputs_embeds=inputs_embeds, 141 | **kwargs 142 | ) 143 | 144 | def prepare_inputs_for_generation(self, input_ids, past_key_values=None, 145 | inputs_embeds=None, **kwargs): 146 | images = kwargs.pop("images", None) 147 | image_sizes = kwargs.pop("image_sizes", None) 148 | inputs = super().prepare_inputs_for_generation( 149 | input_ids, past_key_values=past_key_values, inputs_embeds=inputs_embeds, **kwargs 150 | ) 151 | if images is not None: 152 | inputs['images'] = images 153 | if image_sizes is not None: 154 | inputs['image_sizes'] = image_sizes 155 | return inputs 156 | 157 | AutoConfig.register("llava_llama", LlavaConfig) 158 | AutoModelForCausalLM.register(LlavaConfig, LlavaLlamaForCausalLM) 159 | -------------------------------------------------------------------------------- /llava/model/language_model/llava_mistral.py: -------------------------------------------------------------------------------- 1 | # Copyright 2023 Haotian Liu 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | 16 | from typing import List, Optional, Tuple, Union 17 | 18 | import torch 19 | import torch.nn as nn 20 | from torch.nn import CrossEntropyLoss 21 | 22 | from transformers import AutoConfig, AutoModelForCausalLM, \ 23 | MistralConfig, MistralModel, MistralForCausalLM 24 | 25 | from transformers.modeling_outputs import CausalLMOutputWithPast 26 | from transformers.generation.utils import GenerateOutput 27 | 28 | from ..llava_arch import LlavaMetaModel, LlavaMetaForCausalLM 29 | 30 | 31 | class LlavaMistralConfig(MistralConfig): 32 | model_type = "llava_mistral" 33 | 34 | 35 | class LlavaMistralModel(LlavaMetaModel, MistralModel): 36 | config_class = LlavaMistralConfig 37 | 38 | def __init__(self, config: MistralConfig): 39 | super(LlavaMistralModel, self).__init__(config) 40 | 41 | 42 | class LlavaMistralForCausalLM(MistralForCausalLM, LlavaMetaForCausalLM): 43 | config_class = LlavaMistralConfig 44 | 45 | def __init__(self, config): 46 | super(MistralForCausalLM, self).__init__(config) 47 | self.model = LlavaMistralModel(config) 48 | 49 | self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) 50 | 51 | # Initialize weights and apply final processing 52 | self.post_init() 53 | 54 | def get_model(self): 55 | return self.model 56 | 57 | def forward( 58 | self, 59 | input_ids: torch.LongTensor = None, 60 | attention_mask: Optional[torch.Tensor] = None, 61 | position_ids: Optional[torch.LongTensor] = None, 62 | past_key_values: Optional[List[torch.FloatTensor]] = None, 63 | inputs_embeds: Optional[torch.FloatTensor] = None, 64 | labels: Optional[torch.LongTensor] = None, 65 | use_cache: Optional[bool] = None, 66 | output_attentions: Optional[bool] = None, 67 | output_hidden_states: Optional[bool] = None, 68 | images: Optional[torch.FloatTensor] = None, 69 | image_sizes: Optional[List[List[int]]] = None, 70 | return_dict: Optional[bool] = None, 71 | ) 
-> Union[Tuple, CausalLMOutputWithPast]: 72 | 73 | if inputs_embeds is None: 74 | ( 75 | input_ids, 76 | position_ids, 77 | attention_mask, 78 | past_key_values, 79 | inputs_embeds, 80 | labels 81 | ) = self.prepare_inputs_labels_for_multimodal( 82 | input_ids, 83 | position_ids, 84 | attention_mask, 85 | past_key_values, 86 | labels, 87 | images, 88 | image_sizes 89 | ) 90 | 91 | return super().forward( 92 | input_ids=input_ids, 93 | attention_mask=attention_mask, 94 | position_ids=position_ids, 95 | past_key_values=past_key_values, 96 | inputs_embeds=inputs_embeds, 97 | labels=labels, 98 | use_cache=use_cache, 99 | output_attentions=output_attentions, 100 | output_hidden_states=output_hidden_states, 101 | return_dict=return_dict 102 | ) 103 | 104 | @torch.no_grad() 105 | def generate( 106 | self, 107 | inputs: Optional[torch.Tensor] = None, 108 | images: Optional[torch.Tensor] = None, 109 | image_sizes: Optional[torch.Tensor] = None, 110 | **kwargs, 111 | ) -> Union[GenerateOutput, torch.LongTensor]: 112 | position_ids = kwargs.pop("position_ids", None) 113 | attention_mask = kwargs.pop("attention_mask", None) 114 | if "inputs_embeds" in kwargs: 115 | raise NotImplementedError("`inputs_embeds` is not supported") 116 | 117 | if images is not None: 118 | ( 119 | inputs, 120 | position_ids, 121 | attention_mask, 122 | _, 123 | inputs_embeds, 124 | _ 125 | ) = self.prepare_inputs_labels_for_multimodal( 126 | inputs, 127 | position_ids, 128 | attention_mask, 129 | None, 130 | None, 131 | images, 132 | image_sizes=image_sizes 133 | ) 134 | else: 135 | inputs_embeds = self.get_model().embed_tokens(inputs) 136 | 137 | return super().generate( 138 | position_ids=position_ids, 139 | attention_mask=attention_mask, 140 | inputs_embeds=inputs_embeds, 141 | **kwargs 142 | ) 143 | 144 | def prepare_inputs_for_generation(self, input_ids, past_key_values=None, 145 | inputs_embeds=None, **kwargs): 146 | images = kwargs.pop("images", None) 147 | image_sizes = kwargs.pop("image_sizes", None) 148 | inputs = super().prepare_inputs_for_generation( 149 | input_ids, past_key_values=past_key_values, inputs_embeds=inputs_embeds, **kwargs 150 | ) 151 | if images is not None: 152 | inputs['images'] = images 153 | if image_sizes is not None: 154 | inputs['image_sizes'] = image_sizes 155 | return inputs 156 | 157 | AutoConfig.register("llava_mistral", LlavaMistralConfig) 158 | AutoModelForCausalLM.register(LlavaMistralConfig, LlavaMistralForCausalLM) 159 | -------------------------------------------------------------------------------- /llava/model/language_model/llava_mpt.py: -------------------------------------------------------------------------------- 1 | # Copyright 2023 Haotian Liu 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | 15 | 16 | from typing import Optional, Tuple 17 | 18 | import torch 19 | 20 | from transformers import AutoConfig, AutoModelForCausalLM, \ 21 | MptConfig, MptForCausalLM, MptModel 22 | from llava.model.llava_arch import LlavaMetaModel, LlavaMetaForCausalLM 23 | 24 | 25 | class LlavaMptConfig(MptConfig): 26 | model_type = "llava_mpt" 27 | 28 | 29 | class LlavaMptModel(LlavaMetaModel, MptModel): 30 | config_class = LlavaMptConfig 31 | 32 | def __init__(self, config: MptConfig): 33 | config.hidden_size = config.d_model 34 | super(LlavaMptModel, self).__init__(config) 35 | 36 | def embed_tokens(self, x): 37 | return self.wte(x) 38 | 39 | 40 | class LlavaMptForCausalLM(MptForCausalLM, LlavaMetaForCausalLM): 41 | config_class = LlavaMptConfig 42 | supports_gradient_checkpointing = True 43 | 44 | def __init__(self, config): 45 | super(MptForCausalLM, self).__init__(config) 46 | 47 | self.transformer = LlavaMptModel(config) 48 | self.lm_head = torch.nn.Linear(config.hidden_size, config.vocab_size, bias=False) 49 | 50 | # Initialize weights and apply final processing 51 | self.post_init() 52 | 53 | def get_model(self): 54 | return self.transformer 55 | 56 | def _set_gradient_checkpointing(self, module, value=False): 57 | if isinstance(module, LlavaMptModel): 58 | module.gradient_checkpointing = value 59 | 60 | def forward( 61 | self, 62 | input_ids: Optional[torch.LongTensor] = None, 63 | past_key_values: Optional[Tuple[Tuple[torch.Tensor, torch.Tensor], ...]] = None, 64 | attention_mask: Optional[torch.Tensor] = None, 65 | inputs_embeds: Optional[torch.Tensor] = None, 66 | labels: Optional[torch.Tensor] = None, 67 | use_cache: Optional[bool] = None, 68 | output_attentions: Optional[bool] = None, 69 | output_hidden_states: Optional[bool] = None, 70 | return_dict: Optional[bool] = None, 71 | images=None): 72 | 73 | input_ids, attention_mask, past_key_values, inputs_embeds, labels = self.prepare_inputs_labels_for_multimodal(input_ids, attention_mask, past_key_values, labels, images) 74 | 75 | return super().forward( 76 | input_ids, 77 | past_key_values=past_key_values, 78 | attention_mask=attention_mask, 79 | inputs_embeds=inputs_embeds, 80 | labels=labels, 81 | use_cache=use_cache, 82 | output_attentions=output_attentions, 83 | output_hidden_states=output_hidden_states, 84 | return_dict=return_dict, 85 | ) 86 | 87 | def prepare_inputs_for_generation(self, input_ids, past_key_values=None, inputs_embeds=None, **kwargs): 88 | images = kwargs.pop("images", None) 89 | _inputs = super().prepare_inputs_for_generation( 90 | input_ids, past_key_values=past_key_values, inputs_embeds=inputs_embeds, **kwargs 91 | ) 92 | _inputs['images'] = images 93 | return _inputs 94 | 95 | 96 | AutoConfig.register("llava_mpt", LlavaMptConfig) 97 | AutoModelForCausalLM.register(LlavaMptConfig, LlavaMptForCausalLM) 98 | -------------------------------------------------------------------------------- /llava/model/make_delta.py: -------------------------------------------------------------------------------- 1 | """ 2 | Usage: 3 | python3 -m llava.model.make_delta --base ~/model_weights/llama-7b --target ~/model_weights/llava-7b --delta ~/model_weights/llava-7b-delta --hub-repo-id liuhaotian/llava-7b-delta 4 | """ 5 | import argparse 6 | 7 | import torch 8 | from tqdm import tqdm 9 | from transformers import AutoTokenizer, AutoModelForCausalLM 10 | from llava.model.utils import auto_upgrade 11 | 12 | 13 | def make_delta(base_model_path, target_model_path, delta_path, hub_repo_id): 14 | print("Loading base model") 15 | base 
= AutoModelForCausalLM.from_pretrained( 16 | base_model_path, torch_dtype=torch.float16, low_cpu_mem_usage=True) 17 | 18 | print("Loading target model") 19 | auto_upgrade(target_model_path) 20 | target = AutoModelForCausalLM.from_pretrained(target_model_path, torch_dtype=torch.float16, low_cpu_mem_usage=True) 21 | 22 | print("Calculating delta") 23 | for name, param in tqdm(target.state_dict().items(), desc="Calculating delta"): 24 | if name not in base.state_dict(): 25 | assert name in ['model.mm_projector.weight', 'model.mm_projector.bias'], f'{name} not in base model' 26 | continue 27 | if param.data.shape == base.state_dict()[name].shape: 28 | param.data -= base.state_dict()[name] 29 | else: 30 | assert name in ['model.embed_tokens.weight', 'lm_head.weight'], f'{name} dimension mismatch: {param.data.shape} vs {base.state_dict()[name].shape}' 31 | bparam = base.state_dict()[name] 32 | param.data[:bparam.shape[0], :bparam.shape[1]] -= bparam 33 | 34 | print("Saving delta") 35 | if hub_repo_id: 36 | kwargs = {"push_to_hub": True, "repo_id": hub_repo_id} 37 | else: 38 | kwargs = {} 39 | target.save_pretrained(delta_path, **kwargs) 40 | target_tokenizer = AutoTokenizer.from_pretrained(target_model_path) 41 | target_tokenizer.save_pretrained(delta_path, **kwargs) 42 | 43 | 44 | if __name__ == "__main__": 45 | parser = argparse.ArgumentParser() 46 | parser.add_argument("--base-model-path", type=str, required=True) 47 | parser.add_argument("--target-model-path", type=str, required=True) 48 | parser.add_argument("--delta-path", type=str, required=True) 49 | parser.add_argument("--hub-repo-id", type=str, default=None) 50 | args = parser.parse_args() 51 | 52 | make_delta(args.base_model_path, args.target_model_path, args.delta_path, args.hub_repo_id) 53 | -------------------------------------------------------------------------------- /llava/model/multimodal_encoder/builder.py: -------------------------------------------------------------------------------- 1 | import os 2 | from .clip_encoder import CLIPVisionTower, CLIPVisionTowerS2 3 | 4 | 5 | def build_vision_tower(vision_tower_cfg, **kwargs): 6 | vision_tower = getattr(vision_tower_cfg, 'mm_vision_tower', getattr(vision_tower_cfg, 'vision_tower', None)) 7 | is_absolute_path_exists = os.path.exists(vision_tower) 8 | use_s2 = getattr(vision_tower_cfg, 's2', False) 9 | if is_absolute_path_exists or vision_tower.startswith("openai") or vision_tower.startswith("laion") or "ShareGPT4V" in vision_tower: 10 | if use_s2: 11 | return CLIPVisionTowerS2(vision_tower, args=vision_tower_cfg, **kwargs) 12 | else: 13 | return CLIPVisionTower(vision_tower, args=vision_tower_cfg, **kwargs) 14 | 15 | raise ValueError(f'Unknown vision tower: {vision_tower}') 16 | -------------------------------------------------------------------------------- /llava/model/multimodal_encoder/clip_encoder.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | from transformers import CLIPVisionModel, CLIPImageProcessor, CLIPVisionConfig 5 | 6 | 7 | class CLIPVisionTower(nn.Module): 8 | def __init__(self, vision_tower, args, delay_load=False): 9 | super().__init__() 10 | 11 | self.is_loaded = False 12 | 13 | self.vision_tower_name = vision_tower 14 | self.select_layer = args.mm_vision_select_layer 15 | self.select_feature = getattr(args, 'mm_vision_select_feature', 'patch') 16 | 17 | if not delay_load: 18 | self.load_model() 19 | elif getattr(args, 'unfreeze_mm_vision_tower', False): 20 | 
self.load_model() 21 | else: 22 | self.cfg_only = CLIPVisionConfig.from_pretrained(self.vision_tower_name) 23 | 24 | def load_model(self, device_map=None): 25 | if self.is_loaded: 26 | print('{} is already loaded, `load_model` called again, skipping.'.format(self.vision_tower_name)) 27 | return 28 | 29 | self.image_processor = CLIPImageProcessor.from_pretrained(self.vision_tower_name) 30 | self.vision_tower = CLIPVisionModel.from_pretrained(self.vision_tower_name, device_map=device_map) 31 | self.vision_tower.requires_grad_(False) 32 | 33 | self.is_loaded = True 34 | 35 | def feature_select(self, image_forward_outs): 36 | image_features = image_forward_outs.hidden_states[self.select_layer] 37 | if self.select_feature == 'patch': 38 | image_features = image_features[:, 1:] 39 | elif self.select_feature == 'cls_patch': 40 | image_features = image_features 41 | else: 42 | raise ValueError(f'Unexpected select feature: {self.select_feature}') 43 | return image_features 44 | 45 | @torch.no_grad() 46 | def forward(self, images): 47 | if type(images) is list: 48 | image_features = [] 49 | for image in images: 50 | image_forward_out = self.vision_tower(image.to(device=self.device, dtype=self.dtype).unsqueeze(0), output_hidden_states=True) 51 | image_feature = self.feature_select(image_forward_out).to(image.dtype) 52 | image_features.append(image_feature) 53 | else: 54 | image_forward_outs = self.vision_tower(images.to(device=self.device, dtype=self.dtype), output_hidden_states=True) 55 | image_features = self.feature_select(image_forward_outs).to(images.dtype) 56 | 57 | return image_features 58 | 59 | @property 60 | def dummy_feature(self): 61 | return torch.zeros(1, self.hidden_size, device=self.device, dtype=self.dtype) 62 | 63 | @property 64 | def dtype(self): 65 | return self.vision_tower.dtype 66 | 67 | @property 68 | def device(self): 69 | return self.vision_tower.device 70 | 71 | @property 72 | def config(self): 73 | if self.is_loaded: 74 | return self.vision_tower.config 75 | else: 76 | return self.cfg_only 77 | 78 | @property 79 | def hidden_size(self): 80 | return self.config.hidden_size 81 | 82 | @property 83 | def num_patches_per_side(self): 84 | return self.config.image_size // self.config.patch_size 85 | 86 | @property 87 | def num_patches(self): 88 | return (self.config.image_size // self.config.patch_size) ** 2 89 | 90 | 91 | 92 | class CLIPVisionTowerS2(CLIPVisionTower): 93 | def __init__(self, vision_tower, args, delay_load=False): 94 | super().__init__(vision_tower, args, delay_load) 95 | 96 | self.s2_scales = getattr(args, 's2_scales', '336,672,1008') 97 | self.s2_scales = list(map(int, self.s2_scales.split(','))) 98 | self.s2_scales.sort() 99 | self.s2_split_size = self.s2_scales[0] 100 | self.s2_image_size = self.s2_scales[-1] 101 | 102 | try: 103 | from s2wrapper import forward as multiscale_forward 104 | except ImportError: 105 | raise ImportError('Package s2wrapper not found! 
Please install by running: \npip install git+https://github.com/bfshi/scaling_on_scales.git') 106 | self.multiscale_forward = multiscale_forward 107 | 108 | # change resize/crop size in preprocessing to the largest image size in s2_scale 109 | if not delay_load or getattr(args, 'unfreeze_mm_vision_tower', False): 110 | self.image_processor.size['shortest_edge'] = self.s2_image_size 111 | self.image_processor.crop_size['height'] = self.image_processor.crop_size['width'] = self.s2_image_size 112 | 113 | def load_model(self, device_map=None): 114 | if self.is_loaded: 115 | print('{} is already loaded, `load_model` called again, skipping.'.format(self.vision_tower_name)) 116 | return 117 | 118 | self.image_processor = CLIPImageProcessor.from_pretrained(self.vision_tower_name) 119 | self.vision_tower = CLIPVisionModel.from_pretrained(self.vision_tower_name, device_map=device_map) 120 | self.vision_tower.requires_grad_(False) 121 | 122 | self.image_processor.size['shortest_edge'] = self.s2_image_size 123 | self.image_processor.crop_size['height'] = self.image_processor.crop_size['width'] = self.s2_image_size 124 | 125 | self.is_loaded = True 126 | 127 | @torch.no_grad() 128 | def forward_feature(self, images): 129 | image_forward_outs = self.vision_tower(images.to(device=self.device, dtype=self.dtype), output_hidden_states=True) 130 | image_features = self.feature_select(image_forward_outs).to(images.dtype) 131 | return image_features 132 | 133 | @torch.no_grad() 134 | def forward(self, images): 135 | if type(images) is list: 136 | image_features = [] 137 | for image in images: 138 | image_feature = self.multiscale_forward(self.forward_feature, image.unsqueeze(0), img_sizes=self.s2_scales, max_split_size=self.s2_split_size) 139 | image_features.append(image_feature) 140 | else: 141 | image_features = self.multiscale_forward(self.forward_feature, images, img_sizes=self.s2_scales, max_split_size=self.s2_split_size) 142 | 143 | return image_features 144 | 145 | @property 146 | def hidden_size(self): 147 | return self.config.hidden_size * len(self.s2_scales) 148 | -------------------------------------------------------------------------------- /llava/model/multimodal_projector/builder.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import re 4 | 5 | 6 | class IdentityMap(nn.Module): 7 | def __init__(self): 8 | super().__init__() 9 | 10 | def forward(self, x, *args, **kwargs): 11 | return x 12 | 13 | @property 14 | def config(self): 15 | return {"mm_projector_type": 'identity'} 16 | 17 | 18 | class SimpleResBlock(nn.Module): 19 | def __init__(self, channels): 20 | super().__init__() 21 | self.pre_norm = nn.LayerNorm(channels) 22 | 23 | self.proj = nn.Sequential( 24 | nn.Linear(channels, channels), 25 | nn.GELU(), 26 | nn.Linear(channels, channels) 27 | ) 28 | def forward(self, x): 29 | x = self.pre_norm(x) 30 | return x + self.proj(x) 31 | 32 | 33 | def build_vision_projector(config, delay_load=False, **kwargs): 34 | projector_type = getattr(config, 'mm_projector_type', 'linear') 35 | 36 | if projector_type == 'linear': 37 | return nn.Linear(config.mm_hidden_size, config.hidden_size) 38 | 39 | mlp_gelu_match = re.match(r'^mlp(\d+)x_gelu$', projector_type) 40 | if mlp_gelu_match: 41 | mlp_depth = int(mlp_gelu_match.group(1)) 42 | modules = [nn.Linear(config.mm_hidden_size, config.hidden_size)] 43 | for _ in range(1, mlp_depth): 44 | modules.append(nn.GELU()) 45 | modules.append(nn.Linear(config.hidden_size, 
config.hidden_size)) 46 | return nn.Sequential(*modules) 47 | 48 | if projector_type == 'identity': 49 | return IdentityMap() 50 | 51 | raise ValueError(f'Unknown projector type: {projector_type}') 52 | -------------------------------------------------------------------------------- /llava/model/smplx/joint_names.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | # Max-Planck-Gesellschaft zur Förderung der Wissenschaften e.V. (MPG) is 4 | # holder of all proprietary rights on this computer program. 5 | # You can only use this computer program if you have closed 6 | # a license agreement with MPG or you get the right to use the computer 7 | # program from someone who is authorized to grant you that right. 8 | # Any use of the computer program without a valid license is prohibited and 9 | # liable to prosecution. 10 | # 11 | # Copyright©2019 Max-Planck-Gesellschaft zur Förderung 12 | # der Wissenschaften e.V. (MPG). acting on behalf of its Max Planck Institute 13 | # for Intelligent Systems. All rights reserved. 14 | # 15 | # Contact: ps-license@tuebingen.mpg.de 16 | 17 | JOINT_NAMES = [ 18 | 'pelvis', # 0 19 | 'left_hip', # 1 20 | 'right_hip', # 2 21 | 'spine1', # 3 22 | 'left_knee', # 4 23 | 'right_knee', # 5 24 | 'spine2', # 6 25 | 'left_ankle', # 7 26 | 'right_ankle', # 8 27 | 'spine3', # 9, (10 - 1) 28 | 'left_foot', # 10 29 | 'right_foot', # 11 30 | 'neck', # 12 31 | 'left_collar', # 13 32 | 'right_collar', # 14, (15 - 1) 33 | 'head', # 15 34 | 'left_shoulder', # 16 35 | 'right_shoulder', # 17 36 | 'left_elbow', # 18 37 | 'right_elbow', # 19, (20 - 1) 38 | 'left_wrist', # 20 39 | 'right_wrist', # 21 40 | 'jaw', # 22, (23 - 1) 41 | 'left_eye_smplhf', # 23 42 | 'right_eye_smplhf', # 24, (25 - 1) 43 | 'left_index1', # 25 44 | 'left_index2', # 26 45 | 'left_index3', # 27 46 | 'left_middle1', # 28 47 | 'left_middle2', # 29 48 | 'left_middle3', # 30 49 | 'left_pinky1', # 31 50 | 'left_pinky2', # 32 51 | 'left_pinky3', # 33 52 | 'left_ring1', # 34 53 | 'left_ring2', # 35 54 | 'left_ring3', # 36 55 | 'left_thumb1', # 37 56 | 'left_thumb2', # 38 57 | 'left_thumb3', # 39, (40 - 1) 58 | 'right_index1', # 40 59 | 'right_index2', # 41 60 | 'right_index3', # 42 61 | 'right_middle1', # 43 62 | 'right_middle2', # 44 63 | 'right_middle3', # 45 64 | 'right_pinky1', # 46 65 | 'right_pinky2', # 47 66 | 'right_pinky3', # 48 67 | 'right_ring1', # 49 68 | 'right_ring2', # 50 69 | 'right_ring3', # 51 70 | 'right_thumb1', # 52 71 | 'right_thumb2', # 53 72 | 'right_thumb3', # 54, (55 - 1) 73 | 'nose', # 55 74 | 'right_eye', # 56 75 | 'left_eye', # 57 76 | 'right_ear', # 58 77 | 'left_ear', # 59 78 | 'left_big_toe', # 60 79 | 'left_small_toe', # 61 80 | 'left_heel', # 62 81 | 'right_big_toe', # 63 82 | 'right_small_toe', # 64, (65 - 1) 83 | 'right_heel', # 65 84 | 'left_thumb', # 66 85 | 'left_index', # 67 86 | 'left_middle', # 68 87 | 'left_ring', # 69, (70 - 1) 88 | 'left_pinky', # 70 89 | 'right_thumb', # 71 90 | 'right_index', # 72 91 | 'right_middle', # 73 92 | 'right_ring', # 74, (75 - 1) 93 | 'right_pinky', # 75, (76 - 1) 94 | # evaluated face jts (76 - 127) 95 | 'right_eye_brow1', # 76 96 | 'right_eye_brow2', 97 | 'right_eye_brow3', 98 | 'right_eye_brow4', 99 | 'right_eye_brow5', 100 | 'left_eye_brow5', 101 | 'left_eye_brow4', 102 | 'left_eye_brow3', 103 | 'left_eye_brow2', 104 | 'left_eye_brow1', 105 | 'nose1', 106 | 'nose2', 107 | 'nose3', 108 | 'nose4', 109 | 'right_nose_2', 110 | 'right_nose_1', 111 | 'nose_middle', 112 | 
'left_nose_1', 113 | 'left_nose_2', 114 | 'right_eye1', 115 | 'right_eye2', 116 | 'right_eye3', 117 | 'right_eye4', 118 | 'right_eye5', 119 | 'right_eye6', 120 | 'left_eye4', 121 | 'left_eye3', 122 | 'left_eye2', 123 | 'left_eye1', 124 | 'left_eye6', 125 | 'left_eye5', 126 | 'right_mouth_1', 127 | 'right_mouth_2', 128 | 'right_mouth_3', 129 | 'mouth_top', 130 | 'left_mouth_3', 131 | 'left_mouth_2', 132 | 'left_mouth_1', 133 | 'left_mouth_5', # 59 in OpenPose output 134 | 'left_mouth_4', # 58 in OpenPose output 135 | 'mouth_bottom', # 116 => 116 - 76 = 40 => the 40-index item in lmk_faces_idx 136 | 'right_mouth_4', 137 | 'right_mouth_5', 138 | 'right_lip_1', 139 | 'right_lip_2', 140 | 'lip_top', 141 | 'left_lip_2', 142 | 'left_lip_1', 143 | 'left_lip_3', 144 | 'lip_bottom', 145 | 'right_lip_3', 146 | # Face contour 147 | 'right_contour_1', 148 | 'right_contour_2', 149 | 'right_contour_3', 150 | 'right_contour_4', 151 | 'right_contour_5', 152 | 'right_contour_6', 153 | 'right_contour_7', 154 | 'right_contour_8', 155 | 'contour_middle', 156 | 'left_contour_8', 157 | 'left_contour_7', 158 | 'left_contour_6', 159 | 'left_contour_5', 160 | 'left_contour_4', 161 | 'left_contour_3', 162 | 'left_contour_2', 163 | 'left_contour_1', 164 | ] 165 | 166 | 167 | SMPLH_JOINT_NAMES = [ 168 | 'pelvis', 169 | 'left_hip', 170 | 'right_hip', 171 | 'spine1', 172 | 'left_knee', 173 | 'right_knee', 174 | 'spine2', 175 | 'left_ankle', 176 | 'right_ankle', 177 | 'spine3', 178 | 'left_foot', # 10 179 | 'right_foot', # 11 180 | 'neck', 181 | 'left_collar', 182 | 'right_collar', 183 | 'head', 184 | 'left_shoulder', 185 | 'right_shoulder', 186 | 'left_elbow', 187 | 'right_elbow', 188 | 'left_wrist', # 20 189 | 'right_wrist', # 21 190 | 'left_index1', 191 | 'left_index2', 192 | 'left_index3', 193 | 'left_middle1', # 25 194 | 'left_middle2', 195 | 'left_middle3', 196 | 'left_pinky1', 197 | 'left_pinky2', 198 | 'left_pinky3', # 30 199 | 'left_ring1', 200 | 'left_ring2', 201 | 'left_ring3', 202 | 'left_thumb1', 203 | 'left_thumb2', 204 | 'left_thumb3', 205 | 'right_index1', 206 | 'right_index2', 207 | 'right_index3', 208 | 'right_middle1', 209 | 'right_middle2', # 41 210 | 'right_middle3', 211 | 'right_pinky1', 212 | 'right_pinky2', 213 | 'right_pinky3', 214 | 'right_ring1', 215 | 'right_ring2', 216 | 'right_ring3', 217 | 'right_thumb1', 218 | 'right_thumb2', 219 | 'right_thumb3', 220 | 'nose', 221 | 'right_eye', 222 | 'left_eye', 223 | 'right_ear', 224 | 'left_ear', 225 | 'left_big_toe', 226 | 'left_small_toe', 227 | 'left_heel', 228 | 'right_big_toe', 229 | 'right_small_toe', 230 | 'right_heel', 231 | 'left_thumb', 232 | 'left_index', 233 | 'left_middle', 234 | 'left_ring', 235 | 'left_pinky', 236 | 'right_thumb', 237 | 'right_index', 238 | 'right_middle', 239 | 'right_ring', 240 | 'right_pinky', 241 | ] -------------------------------------------------------------------------------- /llava/model/smplx/smplx_utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | from .lbs import * 4 | 5 | def calculate_A(betas, pose, v_template, shapedirs, posedirs, J_regressor, parents, pose2rot=True): 6 | batch_size = max(betas.shape[0], pose.shape[0]) 7 | device, dtype = betas.device, betas.dtype 8 | 9 | # Add shape contribution 10 | v_shaped = v_template + blend_shapes(betas, shapedirs) 11 | 12 | # Get the joints 13 | # NxJx3 array 14 | J = vertices2joints(J_regressor, v_shaped) 15 | 16 | # 3. 
Add pose blend shapes 17 | # N x J x 3 x 3 18 | ident = torch.eye(3, dtype=dtype, device=device) 19 | if pose2rot: 20 | rot_mats = batch_rodrigues(pose.view(-1, 3)).view( 21 | [batch_size, -1, 3, 3]) 22 | 23 | pose_feature = (rot_mats[:, 1:, :, :] - ident).view([batch_size, -1]) 24 | # (N x P) x (P, V * 3) -> N x V x 3 25 | pose_offsets = torch.matmul( 26 | pose_feature, posedirs).view(batch_size, -1, 3) 27 | else: 28 | pose_feature = pose[:, 1:].view(batch_size, -1, 3, 3) - ident 29 | rot_mats = pose.view(batch_size, -1, 3, 3) 30 | 31 | pose_offsets = torch.matmul(pose_feature.view(batch_size, -1), 32 | posedirs).view(batch_size, -1, 3) 33 | 34 | v_posed = pose_offsets + v_shaped 35 | # 4. Get the global joint location 36 | J_transformed, A = batch_rigid_transform(rot_mats, J, parents, dtype=dtype) 37 | return A -------------------------------------------------------------------------------- /llava/model/smplx/utils.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | # Max-Planck-Gesellschaft zur Förderung der Wissenschaften e.V. (MPG) is 4 | # holder of all proprietary rights on this computer program. 5 | # You can only use this computer program if you have closed 6 | # a license agreement with MPG or you get the right to use the computer 7 | # program from someone who is authorized to grant you that right. 8 | # Any use of the computer program without a valid license is prohibited and 9 | # liable to prosecution. 10 | # 11 | # Copyright©2019 Max-Planck-Gesellschaft zur Förderung 12 | # der Wissenschaften e.V. (MPG). acting on behalf of its Max Planck Institute 13 | # for Intelligent Systems. All rights reserved. 14 | # 15 | # Contact: ps-license@tuebingen.mpg.de 16 | 17 | from typing import NewType, Union, Optional 18 | from dataclasses import dataclass, asdict, fields 19 | import numpy as np 20 | import torch 21 | 22 | Tensor = NewType('Tensor', torch.Tensor) 23 | Array = NewType('Array', np.ndarray) 24 | 25 | 26 | @dataclass 27 | class ModelOutput: 28 | vertices: Optional[Tensor] = None 29 | joints: Optional[Tensor] = None 30 | full_pose: Optional[Tensor] = None 31 | global_orient: Optional[Tensor] = None 32 | transl: Optional[Tensor] = None 33 | v_shaped: Optional[Tensor] = None 34 | 35 | def __getitem__(self, key): 36 | return getattr(self, key) 37 | 38 | def get(self, key, default=None): 39 | return getattr(self, key, default) 40 | 41 | def __iter__(self): 42 | return self.keys() 43 | 44 | def keys(self): 45 | keys = [t.name for t in fields(self)] 46 | return iter(keys) 47 | 48 | def values(self): 49 | values = [getattr(self, t.name) for t in fields(self)] 50 | return iter(values) 51 | 52 | def items(self): 53 | data = [(t.name, getattr(self, t.name)) for t in fields(self)] 54 | return iter(data) 55 | 56 | 57 | @dataclass 58 | class SMPLOutput(ModelOutput): 59 | betas: Optional[Tensor] = None 60 | body_pose: Optional[Tensor] = None 61 | 62 | 63 | @dataclass 64 | class SMPLHOutput(SMPLOutput): 65 | left_hand_pose: Optional[Tensor] = None 66 | right_hand_pose: Optional[Tensor] = None 67 | transl: Optional[Tensor] = None 68 | 69 | 70 | @dataclass 71 | class SMPLXOutput(SMPLHOutput): 72 | expression: Optional[Tensor] = None 73 | jaw_pose: Optional[Tensor] = None 74 | 75 | 76 | @dataclass 77 | class MANOOutput(ModelOutput): 78 | betas: Optional[Tensor] = None 79 | hand_pose: Optional[Tensor] = None 80 | 81 | 82 | @dataclass 83 | class FLAMEOutput(ModelOutput): 84 | betas: Optional[Tensor] = None 85 | expression: 
Optional[Tensor] = None 86 | jaw_pose: Optional[Tensor] = None 87 | neck_pose: Optional[Tensor] = None 88 | 89 | 90 | def find_joint_kin_chain(joint_id, kinematic_tree): 91 | kin_chain = [] 92 | curr_idx = joint_id 93 | while curr_idx != -1: 94 | kin_chain.append(curr_idx) 95 | curr_idx = kinematic_tree[curr_idx] 96 | return kin_chain 97 | 98 | 99 | def to_tensor( 100 | array: Union[Array, Tensor], dtype=torch.float32 101 | ) -> Tensor: 102 | if torch.is_tensor(array): 103 | return array.contiguous() 104 | else: 105 | return torch.tensor(array, dtype=dtype).contiguous() 106 | 107 | 108 | class Struct(object): 109 | def __init__(self, **kwargs): 110 | for key, val in kwargs.items(): 111 | setattr(self, key, val) 112 | 113 | 114 | def to_np(array, dtype=np.float32): 115 | if 'scipy.sparse' in str(type(array)): 116 | array = array.todense() 117 | return np.array(array, dtype=dtype) 118 | 119 | 120 | def rot_mat_to_euler(rot_mats): 121 | # Calculates rotation matrix to euler angles 122 | # Careful for extreme cases of eular angles like [0.0, pi, 0.0] 123 | 124 | sy = torch.sqrt(rot_mats[:, 0, 0] * rot_mats[:, 0, 0] + 125 | rot_mats[:, 1, 0] * rot_mats[:, 1, 0]) 126 | return torch.atan2(-rot_mats[:, 2, 0], sy) 127 | -------------------------------------------------------------------------------- /llava/model/smplx/vertex_ids.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | # Max-Planck-Gesellschaft zur Förderung der Wissenschaften e.V. (MPG) is 4 | # holder of all proprietary rights on this computer program. 5 | # You can only use this computer program if you have closed 6 | # a license agreement with MPG or you get the right to use the computer 7 | # program from someone who is authorized to grant you that right. 8 | # Any use of the computer program without a valid license is prohibited and 9 | # liable to prosecution. 10 | # 11 | # Copyright©2019 Max-Planck-Gesellschaft zur Förderung 12 | # der Wissenschaften e.V. (MPG). acting on behalf of its Max Planck Institute 13 | # for Intelligent Systems. All rights reserved. 14 | # 15 | # Contact: ps-license@tuebingen.mpg.de 16 | 17 | from __future__ import print_function 18 | from __future__ import absolute_import 19 | from __future__ import division 20 | 21 | # Joint name to vertex mapping. 
SMPL/SMPL-H/SMPL-X vertices that correspond to 22 | # MSCOCO and OpenPose joints 23 | vertex_ids = { 24 | 'smplh': { 25 | 'nose': 332, 26 | 'reye': 6260, 27 | 'leye': 2800, 28 | 'rear': 4071, 29 | 'lear': 583, 30 | 'rthumb': 6191, 31 | 'rindex': 5782, 32 | 'rmiddle': 5905, 33 | 'rring': 6016, 34 | 'rpinky': 6133, 35 | 'lthumb': 2746, 36 | 'lindex': 2319, 37 | 'lmiddle': 2445, 38 | 'lring': 2556, 39 | 'lpinky': 2673, 40 | 'LBigToe': 3216, 41 | 'LSmallToe': 3226, 42 | 'LHeel': 3387, 43 | 'RBigToe': 6617, 44 | 'RSmallToe': 6624, 45 | 'RHeel': 6787 46 | }, 47 | 'smplx': { 48 | 'nose': 9120, 49 | 'reye': 9929, 50 | 'leye': 9448, 51 | 'rear': 616, 52 | 'lear': 6, 53 | 'rthumb': 8079, 54 | 'rindex': 7669, 55 | 'rmiddle': 7794, 56 | 'rring': 7905, 57 | 'rpinky': 8022, 58 | 'lthumb': 5361, 59 | 'lindex': 4933, 60 | 'lmiddle': 5058, 61 | 'lring': 5169, 62 | 'lpinky': 5286, 63 | 'LBigToe': 5770, 64 | 'LSmallToe': 5780, 65 | 'LHeel': 8846, 66 | 'RBigToe': 8463, 67 | 'RSmallToe': 8474, 68 | 'RHeel': 8635 69 | }, 70 | 'mano': { 71 | 'thumb': 744, 72 | 'index': 320, 73 | 'middle': 443, 74 | 'ring': 554, 75 | 'pinky': 671, 76 | } 77 | } 78 | -------------------------------------------------------------------------------- /llava/model/smplx/vertex_joint_selector.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | # Max-Planck-Gesellschaft zur Förderung der Wissenschaften e.V. (MPG) is 4 | # holder of all proprietary rights on this computer program. 5 | # You can only use this computer program if you have closed 6 | # a license agreement with MPG or you get the right to use the computer 7 | # program from someone who is authorized to grant you that right. 8 | # Any use of the computer program without a valid license is prohibited and 9 | # liable to prosecution. 10 | # 11 | # Copyright©2019 Max-Planck-Gesellschaft zur Förderung 12 | # der Wissenschaften e.V. (MPG). acting on behalf of its Max Planck Institute 13 | # for Intelligent Systems. All rights reserved. 
14 | # 15 | # Contact: ps-license@tuebingen.mpg.de 16 | 17 | from __future__ import absolute_import 18 | from __future__ import print_function 19 | from __future__ import division 20 | 21 | import numpy as np 22 | 23 | import torch 24 | import torch.nn as nn 25 | 26 | from .utils import to_tensor 27 | 28 | 29 | class VertexJointSelector(nn.Module): 30 | 31 | def __init__(self, vertex_ids=None, 32 | use_hands=True, 33 | use_feet_keypoints=True, **kwargs): 34 | super(VertexJointSelector, self).__init__() 35 | 36 | extra_joints_idxs = [] 37 | 38 | face_keyp_idxs = np.array([ 39 | vertex_ids['nose'], 40 | vertex_ids['reye'], 41 | vertex_ids['leye'], 42 | vertex_ids['rear'], 43 | vertex_ids['lear']], dtype=np.int64) 44 | 45 | extra_joints_idxs = np.concatenate([extra_joints_idxs, 46 | face_keyp_idxs]) 47 | 48 | if use_feet_keypoints: 49 | feet_keyp_idxs = np.array([vertex_ids['LBigToe'], 50 | vertex_ids['LSmallToe'], 51 | vertex_ids['LHeel'], 52 | vertex_ids['RBigToe'], 53 | vertex_ids['RSmallToe'], 54 | vertex_ids['RHeel']], dtype=np.int32) 55 | 56 | extra_joints_idxs = np.concatenate( 57 | [extra_joints_idxs, feet_keyp_idxs]) 58 | 59 | if use_hands: 60 | self.tip_names = ['thumb', 'index', 'middle', 'ring', 'pinky'] 61 | 62 | tips_idxs = [] 63 | for hand_id in ['l', 'r']: 64 | for tip_name in self.tip_names: 65 | tips_idxs.append(vertex_ids[hand_id + tip_name]) 66 | 67 | extra_joints_idxs = np.concatenate( 68 | [extra_joints_idxs, tips_idxs]) 69 | 70 | self.register_buffer('extra_joints_idxs', 71 | to_tensor(extra_joints_idxs, dtype=torch.long)) 72 | 73 | def forward(self, vertices, joints): 74 | extra_joints = torch.index_select(vertices, 1, self.extra_joints_idxs) 75 | joints = torch.cat([joints, extra_joints], dim=1) 76 | 77 | return joints 78 | -------------------------------------------------------------------------------- /llava/model/utils.py: -------------------------------------------------------------------------------- 1 | from transformers import AutoConfig 2 | 3 | 4 | def auto_upgrade(config): 5 | cfg = AutoConfig.from_pretrained(config) 6 | if 'llava' in config and 'llava' not in cfg.model_type: 7 | assert cfg.model_type == 'llama' 8 | print("You are using newer LLaVA code base, while the checkpoint of v0 is from older code base.") 9 | print("You must upgrade the checkpoint to the new code base (this can be done automatically).") 10 | confirm = input("Please confirm that you want to upgrade the checkpoint. 
[Y/N]") 11 | if confirm.lower() in ["y", "yes"]: 12 | print("Upgrading checkpoint...") 13 | assert len(cfg.architectures) == 1 14 | setattr(cfg.__class__, "model_type", "llava") 15 | cfg.architectures[0] = 'LlavaLlamaForCausalLM' 16 | cfg.save_pretrained(config) 17 | print("Checkpoint upgraded.") 18 | else: 19 | print("Checkpoint upgrade aborted.") 20 | exit(1) 21 | -------------------------------------------------------------------------------- /llava/prompts_utils.py: -------------------------------------------------------------------------------- 1 | import json 2 | import re 3 | from llava.json_fixer import repair_json 4 | 5 | 6 | def get_text_labels(gpt4o_results): 7 | result_dict = repair_json(gpt4o_results, return_objects=True) 8 | 9 | used_config_new = {} 10 | used_config_text = [] 11 | 12 | if "upper garment" in result_dict: 13 | used_config_now = { 14 | 'garment_name': result_dict["upper garment"][0], 15 | 'geometry_styles': result_dict["upper garment"][1], 16 | } 17 | used_config_new['upperbody_garment'] = used_config_now 18 | used_config_text.append( 19 | result_dict["upper garment"][1] + ', ' + result_dict["upper garment"][0] 20 | ) 21 | 22 | if "lower garment" in result_dict: 23 | used_config_now = { 24 | 'garment_name': result_dict["lower garment"][0], 25 | 'geometry_styles': result_dict["lower garment"][1], 26 | } 27 | used_config_new['lowerbody_garment'] = used_config_now 28 | used_config_text.append( 29 | result_dict["lower garment"][1] + ', ' + result_dict["lower garment"][0] 30 | ) 31 | 32 | if "wholebody garment" in result_dict: 33 | used_config_now = { 34 | 'garment_name': result_dict["wholebody garment"][0], 35 | 'geometry_styles': result_dict["wholebody garment"][1], 36 | } 37 | used_config_new['wholebody_garment'] = used_config_now 38 | used_config_text.append( 39 | result_dict["wholebody garment"][1] + ', ' + result_dict["wholebody garment"][0] 40 | ) 41 | 42 | return used_config_new, used_config_text 43 | 44 | 45 | 46 | def get_text_labels_detailed(gpt4o_results): 47 | gpt4o_results = gpt4o_results.strip() 48 | if "```" in gpt4o_results: 49 | gpt4o_results = gpt4o_results.split("```")[1] 50 | gpt4o_results = gpt4o_results.strip() 51 | if gpt4o_results.startswith('json') or gpt4o_results.startswith('Json') or gpt4o_results.startswith('JSON'): 52 | gpt4o_results = gpt4o_results[4:].strip() 53 | 54 | results = repair_json(gpt4o_results, return_objects=True) 55 | if len(results) < 2: 56 | return None 57 | 58 | if isinstance(results[1], str): 59 | try: 60 | results[1] = eval(results[1]) 61 | except: 62 | print('????') 63 | return None 64 | 65 | if len(results) > 2: 66 | extra_styles = results[2].split(',') 67 | results[1]['extra'] = [item.strip() for item in extra_styles] 68 | else: 69 | results[1]['extra'] = [] 70 | 71 | used_config_now = { 72 | 'garment_name': results[0], 73 | 'geometry_styles': results[1], 74 | } 75 | 76 | return used_config_now 77 | 78 | 79 | def get_text_labels_foredit(gpt4o_results): 80 | gpt4o_results = gpt4o_results.strip() 81 | if "```" in gpt4o_results: 82 | gpt4o_results = gpt4o_results.split("```")[1] 83 | gpt4o_results = gpt4o_results.strip() 84 | if gpt4o_results.startswith('json') or gpt4o_results.startswith('Json') or gpt4o_results.startswith('JSON'): 85 | gpt4o_results = gpt4o_results[4:].strip() 86 | 87 | results = repair_json(gpt4o_results, return_objects=True) 88 | return results 89 | 90 | 91 | def get_gpt4o_textgen_prompt(garment_name, garment_description): 92 | txtgen_prompt_path = 
'docs/prompts/detailed_textbased_description.txt' 93 | with open(txtgen_prompt_path, 'r') as f: 94 | txtgen_prompt = f.read() 95 | 96 | txtgen_prompt = txtgen_prompt.replace('[TYPE]', garment_name) 97 | txtgen_prompt = txtgen_prompt.replace('[DESCRIPTION]', garment_description) 98 | return txtgen_prompt 99 | 100 | 101 | def get_gpt4o_edit_prompt(garment_name, prompt): 102 | edit_prompt_path = 'docs/prompts/prompt_garment_editing.txt' 103 | with open(edit_prompt_path, 'r') as f: 104 | edit_prompt = f.read() 105 | 106 | edit_prompt = edit_prompt.replace('[TYPE]', garment_name) 107 | edit_prompt = edit_prompt.replace('[DESCRIPTION]', prompt) 108 | return edit_prompt 109 | 110 | -------------------------------------------------------------------------------- /llava/pytorch3d_render_utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import numpy as np 4 | 5 | from pytorch3d.structures import Meshes 6 | from pytorch3d.renderer import ( 7 | PerspectiveCameras, 8 | OrthographicCameras, 9 | PointLights, 10 | RasterizationSettings, 11 | MeshRasterizer, 12 | HardPhongShader, 13 | MeshRenderer, 14 | SoftSilhouetteShader, 15 | TexturesUV, 16 | TexturesVertex, 17 | BlendParams) 18 | 19 | 20 | class TexturedIUVRenderer(nn.Module): 21 | def __init__(self, 22 | device='cuda', 23 | img_wh=256, 24 | blur_radius=0.0, 25 | faces_per_pixel=1, 26 | ): 27 | 28 | super().__init__() 29 | self.img_wh = img_wh 30 | 31 | raster_settings = RasterizationSettings(image_size=img_wh, 32 | blur_radius=blur_radius, 33 | faces_per_pixel=faces_per_pixel,) 34 | 35 | self.cameras = PerspectiveCameras() 36 | self.rasterizer = MeshRasterizer(cameras=self.cameras, raster_settings=raster_settings) # Specify camera in forward pass 37 | self.iuv_shader = SoftSilhouetteShader() 38 | 39 | self.to(device) 40 | 41 | def to(self, device): 42 | self.rasterizer.to(device) 43 | self.iuv_shader.to(device) 44 | 45 | def forward(self, vertices, faces, cam_t=None, cameras=None, focal_length=5000): 46 | img_wh = self.img_wh 47 | img_center=((img_wh * 0.5, img_wh * 0.5),) 48 | cameras = PerspectiveCameras(device=vertices.device, 49 | focal_length=focal_length, 50 | principal_point=img_center, 51 | image_size=((img_wh, img_wh),), 52 | in_ndc=False) 53 | device=vertices.device 54 | 55 | if cam_t is not None: 56 | vertices = vertices + cam_t[:, None, :] 57 | 58 | vertices = vertices * torch.tensor([-1., -1., 1.], device=device).float() 59 | 60 | textures_iuv = TexturesVertex(verts_features=torch.ones_like(vertices)) 61 | meshes_iuv = Meshes(verts=vertices, faces=faces, textures=textures_iuv) 62 | 63 | # Rasterize 64 | fragments = self.rasterizer(meshes_iuv, cameras=cameras) 65 | 66 | # Render RGB and IUV outputs 67 | iuv_image = self.iuv_shader(fragments, meshes_iuv) 68 | 69 | return iuv_image 70 | -------------------------------------------------------------------------------- /llava/serve/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/biansy000/ChatGarment/6ea4192383067938f1d1c63dc0cc4134ccc91b31/llava/serve/__init__.py -------------------------------------------------------------------------------- /llava/serve/cli.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import torch 3 | 4 | from llava.constants import IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN 5 | from llava.conversation 
import conv_templates, SeparatorStyle 6 | from llava.model.builder import load_pretrained_model 7 | from llava.utils import disable_torch_init 8 | from llava.mm_utils import process_images, tokenizer_image_token, get_model_name_from_path 9 | 10 | from PIL import Image 11 | 12 | import requests 13 | from PIL import Image 14 | from io import BytesIO 15 | from transformers import TextStreamer 16 | 17 | 18 | def load_image(image_file): 19 | if image_file.startswith('http://') or image_file.startswith('https://'): 20 | response = requests.get(image_file) 21 | image = Image.open(BytesIO(response.content)).convert('RGB') 22 | else: 23 | image = Image.open(image_file).convert('RGB') 24 | return image 25 | 26 | 27 | def main(args): 28 | # Model 29 | disable_torch_init() 30 | 31 | model_name = get_model_name_from_path(args.model_path) 32 | tokenizer, model, image_processor, context_len = load_pretrained_model(args.model_path, args.model_base, model_name, args.load_8bit, args.load_4bit, device=args.device) 33 | 34 | if "llama-2" in model_name.lower(): 35 | conv_mode = "llava_llama_2" 36 | elif "mistral" in model_name.lower(): 37 | conv_mode = "mistral_instruct" 38 | elif "v1.6-34b" in model_name.lower(): 39 | conv_mode = "chatml_direct" 40 | elif "v1" in model_name.lower(): 41 | conv_mode = "llava_v1" 42 | elif "mpt" in model_name.lower(): 43 | conv_mode = "mpt" 44 | else: 45 | conv_mode = "llava_v0" 46 | 47 | if args.conv_mode is not None and conv_mode != args.conv_mode: 48 | print('[WARNING] the auto inferred conversation mode is {}, while `--conv-mode` is {}, using {}'.format(conv_mode, args.conv_mode, args.conv_mode)) 49 | else: 50 | args.conv_mode = conv_mode 51 | 52 | conv = conv_templates[args.conv_mode].copy() 53 | if "mpt" in model_name.lower(): 54 | roles = ('user', 'assistant') 55 | else: 56 | roles = conv.roles 57 | 58 | image = load_image(args.image_file) 59 | image_size = image.size 60 | # Similar operation in model_worker.py 61 | image_tensor = process_images([image], image_processor, model.config) 62 | if type(image_tensor) is list: 63 | image_tensor = [image.to(model.device, dtype=torch.float16) for image in image_tensor] 64 | else: 65 | image_tensor = image_tensor.to(model.device, dtype=torch.float16) 66 | 67 | while True: 68 | try: 69 | inp = input(f"{roles[0]}: ") 70 | except EOFError: 71 | inp = "" 72 | if not inp: 73 | print("exit...") 74 | break 75 | 76 | print(f"{roles[1]}: ", end="") 77 | 78 | if image is not None: 79 | # first message 80 | if model.config.mm_use_im_start_end: 81 | inp = DEFAULT_IM_START_TOKEN + DEFAULT_IMAGE_TOKEN + DEFAULT_IM_END_TOKEN + '\n' + inp 82 | else: 83 | inp = DEFAULT_IMAGE_TOKEN + '\n' + inp 84 | image = None 85 | 86 | conv.append_message(conv.roles[0], inp) 87 | conv.append_message(conv.roles[1], None) 88 | prompt = conv.get_prompt() 89 | 90 | input_ids = tokenizer_image_token(prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors='pt').unsqueeze(0).to(model.device) 91 | stop_str = conv.sep if conv.sep_style != SeparatorStyle.TWO else conv.sep2 92 | keywords = [stop_str] 93 | streamer = TextStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True) 94 | 95 | with torch.inference_mode(): 96 | output_ids = model.generate( 97 | input_ids, 98 | images=image_tensor, 99 | image_sizes=[image_size], 100 | do_sample=True if args.temperature > 0 else False, 101 | temperature=args.temperature, 102 | max_new_tokens=args.max_new_tokens, 103 | streamer=streamer, 104 | use_cache=True) 105 | 106 | outputs = tokenizer.decode(output_ids[0]).strip() 107 | 
conv.messages[-1][-1] = outputs 108 | 109 | if args.debug: 110 | print("\n", {"prompt": prompt, "outputs": outputs}, "\n") 111 | 112 | 113 | if __name__ == "__main__": 114 | parser = argparse.ArgumentParser() 115 | parser.add_argument("--model-path", type=str, default="facebook/opt-350m") 116 | parser.add_argument("--model-base", type=str, default=None) 117 | parser.add_argument("--image-file", type=str, required=True) 118 | parser.add_argument("--device", type=str, default="cuda") 119 | parser.add_argument("--conv-mode", type=str, default=None) 120 | parser.add_argument("--temperature", type=float, default=0.2) 121 | parser.add_argument("--max-new-tokens", type=int, default=512) 122 | parser.add_argument("--load-8bit", action="store_true") 123 | parser.add_argument("--load-4bit", action="store_true") 124 | parser.add_argument("--debug", action="store_true") 125 | args = parser.parse_args() 126 | main(args) 127 | -------------------------------------------------------------------------------- /llava/serve/examples/extreme_ironing.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/biansy000/ChatGarment/6ea4192383067938f1d1c63dc0cc4134ccc91b31/llava/serve/examples/extreme_ironing.jpg -------------------------------------------------------------------------------- /llava/serve/examples/waterview.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/biansy000/ChatGarment/6ea4192383067938f1d1c63dc0cc4134ccc91b31/llava/serve/examples/waterview.jpg -------------------------------------------------------------------------------- /llava/serve/register_worker.py: -------------------------------------------------------------------------------- 1 | """ 2 | Manually register workers. 
3 | 4 | Usage: 5 | python3 -m fastchat.serve.register_worker --controller http://localhost:21001 --worker-name http://localhost:21002 6 | """ 7 | 8 | import argparse 9 | 10 | import requests 11 | 12 | if __name__ == "__main__": 13 | parser = argparse.ArgumentParser() 14 | parser.add_argument("--controller-address", type=str) 15 | parser.add_argument("--worker-name", type=str) 16 | parser.add_argument("--check-heart-beat", action="store_true") 17 | args = parser.parse_args() 18 | 19 | url = args.controller_address + "/register_worker" 20 | data = { 21 | "worker_name": args.worker_name, 22 | "check_heart_beat": args.check_heart_beat, 23 | "worker_status": None, 24 | } 25 | r = requests.post(url, json=data) 26 | assert r.status_code == 200 27 | -------------------------------------------------------------------------------- /llava/serve/test_message.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import json 3 | 4 | import requests 5 | 6 | from llava.conversation import default_conversation 7 | 8 | 9 | def main(): 10 | if args.worker_address: 11 | worker_addr = args.worker_address 12 | else: 13 | controller_addr = args.controller_address 14 | ret = requests.post(controller_addr + "/refresh_all_workers") 15 | ret = requests.post(controller_addr + "/list_models") 16 | models = ret.json()["models"] 17 | models.sort() 18 | print(f"Models: {models}") 19 | 20 | ret = requests.post(controller_addr + "/get_worker_address", 21 | json={"model": args.model_name}) 22 | worker_addr = ret.json()["address"] 23 | print(f"worker_addr: {worker_addr}") 24 | 25 | if worker_addr == "": 26 | return 27 | 28 | conv = default_conversation.copy() 29 | conv.append_message(conv.roles[0], args.message) 30 | prompt = conv.get_prompt() 31 | 32 | headers = {"User-Agent": "LLaVA Client"} 33 | pload = { 34 | "model": args.model_name, 35 | "prompt": prompt, 36 | "max_new_tokens": args.max_new_tokens, 37 | "temperature": 0.7, 38 | "stop": conv.sep, 39 | } 40 | response = requests.post(worker_addr + "/worker_generate_stream", headers=headers, 41 | json=pload, stream=True) 42 | 43 | print(prompt.replace(conv.sep, "\n"), end="") 44 | for chunk in response.iter_lines(chunk_size=8192, decode_unicode=False, delimiter=b"\0"): 45 | if chunk: 46 | data = json.loads(chunk.decode("utf-8")) 47 | output = data["text"].split(conv.sep)[-1] 48 | print(output, end="\r") 49 | print("") 50 | 51 | 52 | if __name__ == "__main__": 53 | parser = argparse.ArgumentParser() 54 | parser.add_argument("--controller-address", type=str, default="http://localhost:21001") 55 | parser.add_argument("--worker-address", type=str) 56 | parser.add_argument("--model-name", type=str, default="facebook/opt-350m") 57 | parser.add_argument("--max-new-tokens", type=int, default=32) 58 | parser.add_argument("--message", type=str, default= 59 | "Tell me a story with more than 1000 words.") 60 | args = parser.parse_args() 61 | 62 | main() 63 | -------------------------------------------------------------------------------- /llava/train/train_mem_garmentcode_outfit.py: -------------------------------------------------------------------------------- 1 | from llava.train.train_garmentcode_outfit import train 2 | 3 | if __name__ == "__main__": 4 | train(attn_implementation="flash_attention_2") 5 | -------------------------------------------------------------------------------- /llava/utils.py: -------------------------------------------------------------------------------- 1 | import datetime 2 | import logging 3 | 
import logging.handlers 4 | import os 5 | import sys 6 | 7 | import requests 8 | 9 | from llava.constants import LOGDIR 10 | 11 | server_error_msg = "**NETWORK ERROR DUE TO HIGH TRAFFIC. PLEASE REGENERATE OR REFRESH THIS PAGE.**" 12 | moderation_msg = "YOUR INPUT VIOLATES OUR CONTENT MODERATION GUIDELINES. PLEASE TRY AGAIN." 13 | 14 | handler = None 15 | 16 | 17 | def build_logger(logger_name, logger_filename): 18 | global handler 19 | 20 | formatter = logging.Formatter( 21 | fmt="%(asctime)s | %(levelname)s | %(name)s | %(message)s", 22 | datefmt="%Y-%m-%d %H:%M:%S", 23 | ) 24 | 25 | # Set the format of root handlers 26 | if not logging.getLogger().handlers: 27 | logging.basicConfig(level=logging.INFO) 28 | logging.getLogger().handlers[0].setFormatter(formatter) 29 | 30 | # Redirect stdout and stderr to loggers 31 | stdout_logger = logging.getLogger("stdout") 32 | stdout_logger.setLevel(logging.INFO) 33 | sl = StreamToLogger(stdout_logger, logging.INFO) 34 | sys.stdout = sl 35 | 36 | stderr_logger = logging.getLogger("stderr") 37 | stderr_logger.setLevel(logging.ERROR) 38 | sl = StreamToLogger(stderr_logger, logging.ERROR) 39 | sys.stderr = sl 40 | 41 | # Get logger 42 | logger = logging.getLogger(logger_name) 43 | logger.setLevel(logging.INFO) 44 | 45 | # Add a file handler for all loggers 46 | if handler is None: 47 | os.makedirs(LOGDIR, exist_ok=True) 48 | filename = os.path.join(LOGDIR, logger_filename) 49 | handler = logging.handlers.TimedRotatingFileHandler( 50 | filename, when='D', utc=True, encoding='UTF-8') 51 | handler.setFormatter(formatter) 52 | 53 | for name, item in logging.root.manager.loggerDict.items(): 54 | if isinstance(item, logging.Logger): 55 | item.addHandler(handler) 56 | 57 | return logger 58 | 59 | 60 | class StreamToLogger(object): 61 | """ 62 | Fake file-like stream object that redirects writes to a logger instance. 63 | """ 64 | def __init__(self, logger, log_level=logging.INFO): 65 | self.terminal = sys.stdout 66 | self.logger = logger 67 | self.log_level = log_level 68 | self.linebuf = '' 69 | 70 | def __getattr__(self, attr): 71 | return getattr(self.terminal, attr) 72 | 73 | def write(self, buf): 74 | temp_linebuf = self.linebuf + buf 75 | self.linebuf = '' 76 | for line in temp_linebuf.splitlines(True): 77 | # From the io.TextIOWrapper docs: 78 | # On output, if newline is None, any '\n' characters written 79 | # are translated to the system default line separator. 80 | # By default sys.stdout.write() expects '\n' newlines and then 81 | # translates them so this is still cross platform. 82 | if line[-1] == '\n': 83 | self.logger.log(self.log_level, line.rstrip()) 84 | else: 85 | self.linebuf += line 86 | 87 | def flush(self): 88 | if self.linebuf != '': 89 | self.logger.log(self.log_level, self.linebuf.rstrip()) 90 | self.linebuf = '' 91 | 92 | 93 | def disable_torch_init(): 94 | """ 95 | Disable the redundant torch default initialization to accelerate model creation. 96 | """ 97 | import torch 98 | setattr(torch.nn.Linear, "reset_parameters", lambda self: None) 99 | setattr(torch.nn.LayerNorm, "reset_parameters", lambda self: None) 100 | 101 | 102 | def violates_moderation(text): 103 | """ 104 | Check whether the text violates OpenAI moderation API. 
105 | """ 106 | url = "https://api.openai.com/v1/moderations" 107 | headers = {"Content-Type": "application/json", 108 | "Authorization": "Bearer " + os.environ["OPENAI_API_KEY"]} 109 | text = text.replace("\n", "") 110 | data = "{" + '"input": ' + f'"{text}"' + "}" 111 | data = data.encode("utf-8") 112 | try: 113 | ret = requests.post(url, headers=headers, data=data, timeout=5) 114 | flagged = ret.json()["results"][0]["flagged"] 115 | except requests.exceptions.RequestException as e: 116 | flagged = False 117 | except KeyError as e: 118 | flagged = False 119 | 120 | return flagged 121 | 122 | 123 | def pretty_print_semaphore(semaphore): 124 | if semaphore is None: 125 | return "None" 126 | return f"Semaphore(value={semaphore._value}, locked={semaphore.locked()})" 127 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [build-system] 2 | requires = ["setuptools>=61.0"] 3 | build-backend = "setuptools.build_meta" 4 | 5 | [project] 6 | name = "llava" 7 | version = "1.2.2.post1" 8 | description = "Towards GPT-4 like large language and visual assistant." 9 | readme = "README.md" 10 | requires-python = ">=3.8" 11 | classifiers = [ 12 | "Programming Language :: Python :: 3", 13 | "License :: OSI Approved :: Apache Software License", 14 | ] 15 | dependencies = [ 16 | "torch==2.1.2", "torchvision==0.16.2", 17 | "transformers==4.37.2", "tokenizers==0.15.1", "sentencepiece==0.1.99", "shortuuid", 18 | "accelerate==0.32.0", "peft", "bitsandbytes", 19 | "pydantic", "markdown2[all]", "numpy", "scikit-learn==1.2.2", 20 | "gradio==4.16.0", "gradio_client==0.8.1", 21 | "requests", "httpx", "uvicorn", "fastapi", 22 | "einops==0.6.1", "einops-exts==0.0.4", "timm==0.6.13", 23 | "opencv-python", "easydict", "tensorboard", "peft==0.10.0" 24 | ] 25 | 26 | [project.optional-dependencies] 27 | train = ["deepspeed==0.12.6", "ninja", "wandb"] 28 | build = ["build", "twine"] 29 | 30 | [project.urls] 31 | "Homepage" = "https://chatgarment.github.io/" 32 | "Bug Tracker" = "https://github.com/biansy000/ChatGarment/issues" 33 | 34 | [tool.setuptools.packages.find] 35 | exclude = ["assets*", "benchmark*", "docs", "dist*", "playground*", "scripts*", "tests*"] 36 | 37 | [tool.wheel] 38 | exclude = ["assets*", "benchmark*", "docs", "dist*", "playground*", "scripts*", "tests*"] 39 | -------------------------------------------------------------------------------- /run_garmentcode_sim.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | import argparse 4 | import json 5 | from pathlib import Path 6 | 7 | # add the path of GarmentCode 8 | sys.path.insert(1, '/is/cluster/fast/sbian/github/GarmentCodeV2/') 9 | from assets.garment_programs.meta_garment import MetaGarment 10 | from assets.bodies.body_params import BodyParameters 11 | 12 | def run_simultion_warp(pattern_spec, sim_config, output_path, easy_texture_path): 13 | from pygarment.meshgen.boxmeshgen import BoxMesh 14 | from pygarment.meshgen.simulation import run_sim 15 | import pygarment.data_config as data_config 16 | from pygarment.meshgen.sim_config import PathCofig 17 | 18 | props = data_config.Properties(sim_config) 19 | props.set_section_stats('sim', fails={}, sim_time={}, spf={}, fin_frame={}, body_collisions={}, self_collisions={}) 20 | props.set_section_stats('render', render_time={}) 21 | 22 | spec_path = Path(pattern_spec) 23 | garment_name, _, _ = 
spec_path.stem.rpartition('_') # assuming ending in '_specification' 24 | 25 | paths = PathCofig( 26 | in_element_path=spec_path.parent, 27 | out_path=output_path, 28 | in_name=garment_name, 29 | body_name='mean_all', # 'f_smpl_average_A40' 30 | smpl_body=False, # NOTE: depends on chosen body model 31 | add_timestamp=False, 32 | system_path='/is/cluster/fast/sbian/github/GarmentCodeV2/system.json', 33 | easy_texture_path=easy_texture_path 34 | ) 35 | 36 | # Generate and save garment box mesh (if not existent) 37 | print(f"Generate box mesh of {garment_name} with resolution {props['sim']['config']['resolution_scale']}...") 38 | print('\nGarment load: ', paths.in_g_spec) 39 | 40 | garment_box_mesh = BoxMesh(paths.in_g_spec, props['sim']['config']['resolution_scale']) 41 | garment_box_mesh.load() 42 | garment_box_mesh.serialize( 43 | paths, store_panels=False, uv_config=props['render']['config']['uv_texture']) 44 | 45 | props.serialize(paths.element_sim_props) 46 | 47 | run_sim( 48 | garment_box_mesh.name, 49 | props, 50 | paths, 51 | save_v_norms=False, 52 | store_usd=False, # NOTE: False for fast simulation! 53 | optimize_storage=False, # props['sim']['config']['optimize_storage'], 54 | verbose=False 55 | ) 56 | 57 | props.serialize(paths.element_sim_props) 58 | 59 | 60 | parser = argparse.ArgumentParser() 61 | parser.add_argument("--all_paths_json", type=str, default='', help="path to a results folder containing vis_new/all_json_spec_files.json") 62 | parser.add_argument("--json_spec_file", type=str, default='', help="path to a single garment sewing pattern specification JSON file") 63 | parser.add_argument("--easy_texture_path", type=str, default='', help="garment texture path forwarded to the simulation config (PathCofig)") 64 | args = parser.parse_args() 65 | 66 | if len(args.all_paths_json) > 1: 67 | garment_json_path = os.path.join(args.all_paths_json, 'vis_new/all_json_spec_files.json') 68 | 69 | with open(garment_json_path) as f: 70 | garment_json_paths = json.load(f) 71 | 72 | elif args.json_spec_file: 73 | garment_json_paths = [args.json_spec_file] 74 | 75 | print(len(garment_json_paths)) 76 | for json_spec_file in garment_json_paths: 77 | print(json_spec_file) 78 | json_spec_file = json_spec_file.replace('validate_garment', 'valid_garment') 79 | saved_folder = os.path.dirname(json_spec_file) 80 | run_simultion_warp( 81 | json_spec_file, 82 | 'assets/Sim_props/default_sim_props.yaml', 83 | saved_folder, 84 | easy_texture_path=args.easy_texture_path 85 | ) 86 | -------------------------------------------------------------------------------- /scripts/v1_5/evaluate_garment_v2_demo_edit.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | export LD_LIBRARY_PATH=/is/software/nvidia/cuda-12.1/lib64 4 | export PATH=$PATH:/is/software/nvidia/cuda-12.1/bin 5 | export CUDA_HOME=/is/software/nvidia/cuda-12.1 6 | 7 | export CPATH=/is/software/nvidia/cudnn-8.4.1-cu11.6/include 8 | export C_INCLUDE_PATH=/is/software/nvidia/cudnn-8.4.1-cu11.6/include 9 | export LIBRARY_PATH=/is/software/nvidia/cudnn-8.4.1-cu11.6/lib64 10 | export LD_LIBRARY_PATH=$LIBRARY_PATH:$LD_LIBRARY_PATH 11 | 12 | export EGL_DEVICE_ID=$GPU_DEVICE_ORDINAL 13 | # export TCNN_CUDA_ARCHITECTURES=80 14 | 15 | deepspeed scripts/evaluate_garment_v2_demo_edit_1float.py \ 16 | --lora_enable True --lora_r 128 --lora_alpha 256 --mm_projector_lr 2e-5 \ 17 | --deepspeed ./scripts/zero2.json \ 18 | --model_name_or_path liuhaotian/llava-v1.5-7b \ 19 | --version v1 \ 20 | --data_path ./ \ 21 | --data_path_eval
example_data/example_jsons/example_edit_prompts.json \ 22 | --image_folder ./ \ 23 | --vision_tower openai/clip-vit-large-patch14-336 \ 24 | --mm_projector_type mlp2x_gelu \ 25 | --mm_vision_select_layer -2 \ 26 | --mm_use_im_start_end False \ 27 | --mm_use_im_patch_token False \ 28 | --image_aspect_ratio pad \ 29 | --group_by_modality_length True \ 30 | --bf16 True \ 31 | --output_dir ./checkpoints/llava-v1.5-7b-task-lora \ 32 | --num_train_epochs 1 \ 33 | --per_device_train_batch_size 16 \ 34 | --per_device_eval_batch_size 4 \ 35 | --gradient_accumulation_steps 1 \ 36 | --evaluation_strategy "no" \ 37 | --save_strategy "steps" \ 38 | --save_steps 50000 \ 39 | --save_total_limit 1 \ 40 | --learning_rate 2e-4 \ 41 | --weight_decay 0. \ 42 | --warmup_ratio 0.03 \ 43 | --lr_scheduler_type "cosine" \ 44 | --logging_steps 1 \ 45 | --tf32 True \ 46 | --model_max_length 3072 \ 47 | --gradient_checkpointing True \ 48 | --dataloader_num_workers 4 \ 49 | --lazy_preprocess True \ 50 | --report_to wandb 51 | 52 | -------------------------------------------------------------------------------- /scripts/v1_5/evaluate_garment_v2_eva_edit.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | export LD_LIBRARY_PATH=/is/software/nvidia/cuda-12.1/lib64 4 | export PATH=$PATH:/is/software/nvidia/cuda-12.1/bin 5 | export CUDA_HOME=/is/software/nvidia/cuda-12.1 6 | 7 | export CPATH=/is/software/nvidia/cudnn-8.4.1-cu11.6/include 8 | export C_INCLUDE_PATH=/is/software/nvidia/cudnn-8.4.1-cu11.6/include 9 | export LIBRARY_PATH=/is/software/nvidia/cudnn-8.4.1-cu11.6/lib64 10 | export LD_LIBRARY_PATH=$LIBRARY_PATH:$LD_LIBRARY_PATH 11 | 12 | export EGL_DEVICE_ID=$GPU_DEVICE_ORDINAL 13 | # export TCNN_CUDA_ARCHITECTURES=80 14 | 15 | deepspeed scripts/evaluate_garment_v2_eva_edit_1float.py \ 16 | --lora_enable True --lora_r 128 --lora_alpha 256 --mm_projector_lr 2e-5 \ 17 | --deepspeed ./scripts/zero2.json \ 18 | --model_name_or_path liuhaotian/llava-v1.5-7b \ 19 | --version v1 \ 20 | --data_path ./ \ 21 | --data_path_eval data/llava_preprocess.json \ 22 | --image_folder ./ \ 23 | --vision_tower openai/clip-vit-large-patch14-336 \ 24 | --mm_projector_type mlp2x_gelu \ 25 | --mm_vision_select_layer -2 \ 26 | --mm_use_im_start_end False \ 27 | --mm_use_im_patch_token False \ 28 | --image_aspect_ratio pad \ 29 | --group_by_modality_length True \ 30 | --bf16 True \ 31 | --output_dir ./checkpoints/llava-v1.5-7b-task-lora \ 32 | --num_train_epochs 1 \ 33 | --per_device_train_batch_size 16 \ 34 | --per_device_eval_batch_size 4 \ 35 | --gradient_accumulation_steps 1 \ 36 | --evaluation_strategy "no" \ 37 | --save_strategy "steps" \ 38 | --save_steps 50000 \ 39 | --save_total_limit 1 \ 40 | --learning_rate 2e-4 \ 41 | --weight_decay 0. 
\ 42 | --warmup_ratio 0.03 \ 43 | --lr_scheduler_type "cosine" \ 44 | --logging_steps 1 \ 45 | --tf32 True \ 46 | --model_max_length 3072 \ 47 | --gradient_checkpointing True \ 48 | --dataloader_num_workers 4 \ 49 | --lazy_preprocess True \ 50 | --report_to wandb 51 | 52 | -------------------------------------------------------------------------------- /scripts/v1_5/evaluate_garment_v2_imggen_2step.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | export LD_LIBRARY_PATH=/is/software/nvidia/cuda-12.1/lib64 4 | export PATH=$PATH:/is/software/nvidia/cuda-12.1/bin 5 | export CUDA_HOME=/is/software/nvidia/cuda-12.1 6 | 7 | export CPATH=/is/software/nvidia/cudnn-8.4.1-cu11.6/include 8 | export C_INCLUDE_PATH=/is/software/nvidia/cudnn-8.4.1-cu11.6/include 9 | export LIBRARY_PATH=/is/software/nvidia/cudnn-8.4.1-cu11.6/lib64 10 | export LD_LIBRARY_PATH=$LIBRARY_PATH:$LD_LIBRARY_PATH 11 | 12 | export EGL_DEVICE_ID=$GPU_DEVICE_ORDINAL 13 | # export TCNN_CUDA_ARCHITECTURES=80 14 | 15 | deepspeed scripts/evaluate_garment_v2_imggen_1float.py \ 16 | --lora_enable True --lora_r 128 --lora_alpha 256 --mm_projector_lr 2e-5 \ 17 | --deepspeed ./scripts/zero2.json \ 18 | --model_name_or_path liuhaotian/llava-v1.5-7b \ 19 | --version v1 \ 20 | --data_path ./ \ 21 | --data_path_eval $1 \ 22 | --image_folder ./ \ 23 | --vision_tower openai/clip-vit-large-patch14-336 \ 24 | --mm_projector_type mlp2x_gelu \ 25 | --mm_vision_select_layer -2 \ 26 | --mm_use_im_start_end False \ 27 | --mm_use_im_patch_token False \ 28 | --image_aspect_ratio pad \ 29 | --group_by_modality_length True \ 30 | --bf16 True \ 31 | --output_dir ./checkpoints/llava-v1.5-7b-task-lora \ 32 | --num_train_epochs 1 \ 33 | --per_device_train_batch_size 16 \ 34 | --per_device_eval_batch_size 4 \ 35 | --gradient_accumulation_steps 1 \ 36 | --evaluation_strategy "no" \ 37 | --save_strategy "steps" \ 38 | --save_steps 50000 \ 39 | --save_total_limit 1 \ 40 | --learning_rate 2e-4 \ 41 | --weight_decay 0. 
\ 42 | --warmup_ratio 0.03 \ 43 | --lr_scheduler_type "cosine" \ 44 | --logging_steps 1 \ 45 | --tf32 True \ 46 | --model_max_length 3072 \ 47 | --gradient_checkpointing True \ 48 | --dataloader_num_workers 4 \ 49 | --lazy_preprocess True \ 50 | --report_to wandb 51 | 52 | -------------------------------------------------------------------------------- /scripts/v1_5/evaluate_garment_v2_textgen.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | export LD_LIBRARY_PATH=/is/software/nvidia/cuda-12.1/lib64 4 | export PATH=$PATH:/is/software/nvidia/cuda-12.1/bin 5 | export CUDA_HOME=/is/software/nvidia/cuda-12.1 6 | 7 | export CPATH=/is/software/nvidia/cudnn-8.4.1-cu11.6/include 8 | export C_INCLUDE_PATH=/is/software/nvidia/cudnn-8.4.1-cu11.6/include 9 | export LIBRARY_PATH=/is/software/nvidia/cudnn-8.4.1-cu11.6/lib64 10 | export LD_LIBRARY_PATH=$LIBRARY_PATH:$LD_LIBRARY_PATH 11 | 12 | export EGL_DEVICE_ID=$GPU_DEVICE_ORDINAL 13 | # export TCNN_CUDA_ARCHITECTURES=80 14 | 15 | deepspeed scripts/evaluate_garment_v2_textgen_1float.py \ 16 | --lora_enable True --lora_r 128 --lora_alpha 256 --mm_projector_lr 2e-5 \ 17 | --deepspeed ./scripts/zero2.json \ 18 | --model_name_or_path liuhaotian/llava-v1.5-7b \ 19 | --version v1 \ 20 | --data_path ./ \ 21 | --data_path_eval $1 \ 22 | --image_folder ./ \ 23 | --vision_tower openai/clip-vit-large-patch14-336 \ 24 | --mm_projector_type mlp2x_gelu \ 25 | --mm_vision_select_layer -2 \ 26 | --mm_use_im_start_end False \ 27 | --mm_use_im_patch_token False \ 28 | --image_aspect_ratio pad \ 29 | --group_by_modality_length True \ 30 | --bf16 True \ 31 | --output_dir ./checkpoints/llava-v1.5-7b-task-lora \ 32 | --num_train_epochs 1 \ 33 | --per_device_train_batch_size 16 \ 34 | --per_device_eval_batch_size 4 \ 35 | --gradient_accumulation_steps 1 \ 36 | --evaluation_strategy "no" \ 37 | --save_strategy "steps" \ 38 | --save_steps 50000 \ 39 | --save_total_limit 1 \ 40 | --learning_rate 2e-4 \ 41 | --weight_decay 0. 
\ 42 | --warmup_ratio 0.03 \ 43 | --lr_scheduler_type "cosine" \ 44 | --logging_steps 1 \ 45 | --tf32 True \ 46 | --model_max_length 3072 \ 47 | --gradient_checkpointing True \ 48 | --dataloader_num_workers 4 \ 49 | --lazy_preprocess True \ 50 | --report_to wandb 51 | 52 | -------------------------------------------------------------------------------- /scripts/v1_5/evaluate_garment_v2_textgen_fromimg.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | export LD_LIBRARY_PATH=/is/software/nvidia/cuda-12.1/lib64 4 | export PATH=$PATH:/is/software/nvidia/cuda-12.1/bin 5 | export CUDA_HOME=/is/software/nvidia/cuda-12.1 6 | 7 | export CPATH=/is/software/nvidia/cudnn-8.4.1-cu11.6/include 8 | export C_INCLUDE_PATH=/is/software/nvidia/cudnn-8.4.1-cu11.6/include 9 | export LIBRARY_PATH=/is/software/nvidia/cudnn-8.4.1-cu11.6/lib64 10 | export LD_LIBRARY_PATH=$LIBRARY_PATH:$LD_LIBRARY_PATH 11 | 12 | export EGL_DEVICE_ID=$GPU_DEVICE_ORDINAL 13 | # export TCNN_CUDA_ARCHITECTURES=80 14 | 15 | deepspeed scripts/evaluate_garment_v2_textgen_fromimg_1float.py \ 16 | --lora_enable True --lora_r 128 --lora_alpha 256 --mm_projector_lr 2e-5 \ 17 | --deepspeed ./scripts/zero2.json \ 18 | --model_name_or_path liuhaotian/llava-v1.5-7b \ 19 | --version v1 \ 20 | --data_path ./ \ 21 | --data_path_eval $1 \ 22 | --image_folder ./ \ 23 | --vision_tower openai/clip-vit-large-patch14-336 \ 24 | --mm_projector_type mlp2x_gelu \ 25 | --mm_vision_select_layer -2 \ 26 | --mm_use_im_start_end False \ 27 | --mm_use_im_patch_token False \ 28 | --image_aspect_ratio pad \ 29 | --group_by_modality_length True \ 30 | --bf16 True \ 31 | --output_dir ./checkpoints/llava-v1.5-7b-task-lora \ 32 | --num_train_epochs 1 \ 33 | --per_device_train_batch_size 16 \ 34 | --per_device_eval_batch_size 4 \ 35 | --gradient_accumulation_steps 1 \ 36 | --evaluation_strategy "no" \ 37 | --save_strategy "steps" \ 38 | --save_steps 50000 \ 39 | --save_total_limit 1 \ 40 | --learning_rate 2e-4 \ 41 | --weight_decay 0. 
\ 42 | --warmup_ratio 0.03 \ 43 | --lr_scheduler_type "cosine" \ 44 | --logging_steps 1 \ 45 | --tf32 True \ 46 | --model_max_length 3072 \ 47 | --gradient_checkpointing True \ 48 | --dataloader_num_workers 4 \ 49 | --lazy_preprocess True \ 50 | --report_to wandb 51 | 52 | -------------------------------------------------------------------------------- /scripts/v1_5/finetune_task_lora_garmentcode_outfit.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | export LD_LIBRARY_PATH=/is/software/nvidia/cuda-12.1/lib64 4 | export PATH=$PATH:/is/software/nvidia/cuda-12.1/bin 5 | export CUDA_HOME=/is/software/nvidia/cuda-12.1 6 | 7 | export CPATH=/is/software/nvidia/cudnn-8.4.1-cu11.6/include 8 | export C_INCLUDE_PATH=/is/software/nvidia/cudnn-8.4.1-cu11.6/include 9 | export LIBRARY_PATH=/is/software/nvidia/cudnn-8.4.1-cu11.6/lib64 10 | export LD_LIBRARY_PATH=$LIBRARY_PATH:$LD_LIBRARY_PATH 11 | 12 | export EGL_DEVICE_ID=$GPU_DEVICE_ORDINAL 13 | # export TCNN_CUDA_ARCHITECTURES=80 14 | 15 | deepspeed llava/train/train_mem_garmentcode_outfit.py \ 16 | --lora_enable True --lora_r 128 --lora_alpha 256 --mm_projector_lr 2e-5 \ 17 | --deepspeed ./scripts/zero2.json \ 18 | --model_name_or_path liuhaotian/llava-v1.5-7b \ 19 | --version v1 \ 20 | --data_path ./ \ 21 | --data_path_eval ./ \ 22 | --image_folder ./ \ 23 | --vision_tower openai/clip-vit-large-patch14-336 \ 24 | --mm_projector_type mlp2x_gelu \ 25 | --mm_vision_select_layer -2 \ 26 | --mm_use_im_start_end False \ 27 | --mm_use_im_patch_token False \ 28 | --image_aspect_ratio pad \ 29 | --group_by_modality_length True \ 30 | --bf16 True \ 31 | --output_dir ./checkpoints/llava-v1.5-7b-task-lora \ 32 | --num_train_epochs 1 \ 33 | --per_device_train_batch_size 16 \ 34 | --per_device_eval_batch_size 4 \ 35 | --gradient_accumulation_steps 1 \ 36 | --evaluation_strategy "no" \ 37 | --save_strategy "steps" \ 38 | --save_steps 50000 \ 39 | --save_total_limit 1 \ 40 | --learning_rate 2e-4 \ 41 | --weight_decay 0. 
\ 42 | --warmup_ratio 0.03 \ 43 | --lr_scheduler_type "cosine" \ 44 | --logging_steps 1 \ 45 | --tf32 True \ 46 | --model_max_length 3072 \ 47 | --gradient_checkpointing True \ 48 | --dataloader_num_workers 4 \ 49 | --lazy_preprocess True \ 50 | --report_to wandb 51 | 52 | -------------------------------------------------------------------------------- /scripts/zero2.json: -------------------------------------------------------------------------------- 1 | { 2 | "fp16": { 3 | "enabled": "auto", 4 | "loss_scale": 0, 5 | "loss_scale_window": 1000, 6 | "initial_scale_power": 16, 7 | "hysteresis": 2, 8 | "min_loss_scale": 1 9 | }, 10 | "bf16": { 11 | "enabled": "auto" 12 | }, 13 | "train_micro_batch_size_per_gpu": "auto", 14 | "train_batch_size": "auto", 15 | "gradient_accumulation_steps": "auto", 16 | "zero_optimization": { 17 | "stage": 2, 18 | "overlap_comm": true, 19 | "contiguous_gradients": true, 20 | "sub_group_size": 1e9, 21 | "reduce_bucket_size": "auto" 22 | } 23 | } -------------------------------------------------------------------------------- /scripts/zero3.json: -------------------------------------------------------------------------------- 1 | { 2 | "fp16": { 3 | "enabled": "auto", 4 | "loss_scale": 0, 5 | "loss_scale_window": 1000, 6 | "initial_scale_power": 16, 7 | "hysteresis": 2, 8 | "min_loss_scale": 1 9 | }, 10 | "bf16": { 11 | "enabled": "auto" 12 | }, 13 | "train_micro_batch_size_per_gpu": "auto", 14 | "train_batch_size": "auto", 15 | "gradient_accumulation_steps": "auto", 16 | "zero_optimization": { 17 | "stage": 3, 18 | "overlap_comm": true, 19 | "contiguous_gradients": true, 20 | "sub_group_size": 1e9, 21 | "reduce_bucket_size": "auto", 22 | "stage3_prefetch_bucket_size": "auto", 23 | "stage3_param_persistence_threshold": "auto", 24 | "stage3_max_live_parameters": 1e9, 25 | "stage3_max_reuse_distance": 1e9, 26 | "stage3_gather_16bit_weights_on_model_save": true 27 | } 28 | } -------------------------------------------------------------------------------- /scripts/zero3_offload.json: -------------------------------------------------------------------------------- 1 | { 2 | "fp16": { 3 | "enabled": "auto", 4 | "loss_scale": 0, 5 | "loss_scale_window": 1000, 6 | "initial_scale_power": 16, 7 | "hysteresis": 2, 8 | "min_loss_scale": 1 9 | }, 10 | "bf16": { 11 | "enabled": "auto" 12 | }, 13 | "optimizer": { 14 | "type": "AdamW", 15 | "params": { 16 | "lr": "auto", 17 | "betas": "auto", 18 | "eps": "auto", 19 | "weight_decay": "auto" 20 | } 21 | }, 22 | "scheduler": { 23 | "type": "WarmupLR", 24 | "params": { 25 | "warmup_min_lr": "auto", 26 | "warmup_max_lr": "auto", 27 | "warmup_num_steps": "auto" 28 | } 29 | }, 30 | "zero_optimization": { 31 | "stage": 3, 32 | "offload_optimizer": { 33 | "device": "cpu", 34 | "pin_memory": true 35 | }, 36 | "offload_param": { 37 | "device": "cpu", 38 | "pin_memory": true 39 | }, 40 | "overlap_comm": true, 41 | "contiguous_gradients": true, 42 | "sub_group_size": 1e9, 43 | "reduce_bucket_size": "auto", 44 | "stage3_prefetch_bucket_size": "auto", 45 | "stage3_param_persistence_threshold": "auto", 46 | "stage3_max_live_parameters": 1e9, 47 | "stage3_max_reuse_distance": 1e9, 48 | "gather_16bit_weights_on_model_save": true 49 | }, 50 | "gradient_accumulation_steps": "auto", 51 | "gradient_clipping": "auto", 52 | "train_batch_size": "auto", 53 | "train_micro_batch_size_per_gpu": "auto", 54 | "steps_per_print": 1e5, 55 | "wall_clock_breakdown": false 56 | } --------------------------------------------------------------------------------