├── README.md
├── __init__.py
├── constants.py
├── conversation_v01.py
├── convert_from_ckpt.py
├── convert_original_stable_diffusion_to_diffusers.py
├── data
│   ├── ConversationTemplateEditing_use.txt
│   ├── LLMSD_InstructDiffusion_color.txt
│   ├── LLMSD_InstructDiffusion_seg.txt
│   ├── __init__.py
│   ├── adapter_config.json
│   └── conv_template_cap_to_img.txt
├── dataset
│   ├── EditMLLMSD_dataset.py
│   ├── LLaVAMLLMSD_dataset.py
│   ├── ReasonEditMLLMSD_dataset.py
│   ├── ReasonSegMLLMSD_dataset.py
│   ├── SegMLLMSD_dataset.py
│   └── __init__.py
├── llava
│   ├── __init__.py
│   ├── constants.py
│   ├── conversation.py
│   ├── eval
│   │   ├── eval_gpt_review.py
│   │   ├── eval_gpt_review_bench.py
│   │   ├── eval_gpt_review_visual.py
│   │   ├── eval_science_qa.py
│   │   ├── eval_science_qa_gpt4.py
│   │   ├── eval_science_qa_gpt4_requery.py
│   │   ├── generate_webpage_data_from_table.py
│   │   ├── model_qa.py
│   │   ├── model_vqa.py
│   │   ├── model_vqa_science.py
│   │   ├── qa_baseline_gpt35.py
│   │   ├── run_llava.py
│   │   ├── summarize_gpt_review.py
│   │   ├── table
│   │   │   ├── answer
│   │   │   │   ├── answer_alpaca-13b.jsonl
│   │   │   │   ├── answer_bard.jsonl
│   │   │   │   ├── answer_gpt35.jsonl
│   │   │   │   ├── answer_llama-13b.jsonl
│   │   │   │   └── answer_vicuna-13b.jsonl
│   │   │   ├── caps_boxes_coco2014_val_80.jsonl
│   │   │   ├── model.jsonl
│   │   │   ├── prompt.jsonl
│   │   │   ├── question.jsonl
│   │   │   ├── results
│   │   │   │   └── test_sqa_llava_13b_v0.json
│   │   │   ├── review
│   │   │   │   ├── review_alpaca-13b_vicuna-13b.jsonl
│   │   │   │   ├── review_bard_vicuna-13b.jsonl
│   │   │   │   ├── review_gpt35_vicuna-13b.jsonl
│   │   │   │   └── review_llama-13b_vicuna-13b.jsonl
│   │   │   ├── reviewer.jsonl
│   │   │   └── rule.json
│   │   └── webpage
│   │       ├── figures
│   │       │   ├── alpaca.png
│   │       │   ├── bard.jpg
│   │       │   ├── chatgpt.svg
│   │       │   ├── llama.jpg
│   │       │   ├── swords_FILL0_wght300_GRAD0_opsz48.svg
│   │       │   └── vicuna.jpeg
│   │       ├── index.html
│   │       ├── script.js
│   │       └── styles.css
│   ├── mm_utils.py
│   ├── model
│   │   ├── __init__.py
│   │   ├── apply_delta.py
│   │   ├── builder.py
│   │   ├── consolidate.py
│   │   ├── language_model
│   │   │   ├── llava_llama.py
│   │   │   ├── llava_mpt.py
│   │   │   └── mpt
│   │   │       ├── adapt_tokenizer.py
│   │   │       ├── attention.py
│   │   │       ├── blocks.py
│   │   │       ├── configuration_mpt.py
│   │   │       ├── custom_embedding.py
│   │   │       ├── flash_attn_triton.py
│   │   │       ├── hf_prefixlm_converter.py
│   │   │       ├── meta_init_context.py
│   │   │       ├── modeling_mpt.py
│   │   │       ├── norm.py
│   │   │       └── param_init_fns.py
│   │   ├── llava_arch.py
│   │   ├── make_delta.py
│   │   ├── multimodal_encoder
│   │   │   ├── builder.py
│   │   │   └── clip_encoder.py
│   │   └── utils.py
│   ├── serve
│   │   ├── __init__.py
│   │   ├── cli.py
│   │   ├── controller.py
│   │   ├── examples
│   │   │   ├── extreme_ironing.jpg
│   │   │   └── waterview.jpg
│   │   ├── gradio_web_server.py
│   │   ├── model_worker.py
│   │   ├── register_worker.py
│   │   └── test_message.py
│   ├── train
│   │   ├── llama_flash_attn_monkey_patch.py
│   │   ├── llava_trainer.py
│   │   ├── train.py
│   │   └── train_mem.py
│   └── utils.py
├── model
│   ├── DS_LoraLLaMAUnetPeftModel_new.py
│   ├── DS_MLLMSD11_model.py
│   ├── DS_SmartEdit_model.py
│   ├── LLMSD_QFormerv01.py
│   ├── LLMSD_modelv01_conv.py
│   ├── __init__.py
│   ├── two_way_transformer.py
│   └── unet_2d_condition_ZeroConv.py
├── process_HF.py
├── requirements.txt
├── scripts
│   ├── MLLMSD_13b.sh
│   ├── MLLMSD_7b.sh
│   ├── SmartEdit_13b.sh
│   ├── SmartEdit_7b.sh
│   ├── TrainStage1_13b.sh
│   ├── TrainStage1_7b.sh
│   ├── zero2_mixed.json
│   └── zero2_offload_mixed.json
├── test
│   ├── DS_MLLMSD11_test.py
│   ├── DS_PeftForLoRA.py
│   ├── DS_SmartEdit_test.py
│   ├── InstructPix2PixSD_SM.py
│   ├── SDPipeIP2P_variant1.py
│   ├── TrainStage1_inference.py
│   └── metrics_evaluation.py
├── train
│   ├── DS_MLLMSD11_train.py
│   ├── DS_SmartEdit_train.py
│   ├── TrainStage1.py
│   ├── __init__.py
│   └── llama_flash_attn_monkey_patch.py
└── utils.py

/__init__.py: 
-------------------------------------------------------------------------------- 1 | __version__ = "0.2.11" 2 | -------------------------------------------------------------------------------- /constants.py: -------------------------------------------------------------------------------- 1 | from enum import IntEnum 2 | import os 3 | 4 | # For the gradio web server 5 | SERVER_ERROR_MSG = ( 6 | "**NETWORK ERROR DUE TO HIGH TRAFFIC. PLEASE REGENERATE OR REFRESH THIS PAGE.**" 7 | ) 8 | MODERATION_MSG = "YOUR INPUT VIOLATES OUR CONTENT MODERATION GUIDELINES. PLEASE FIX YOUR INPUT AND TRY AGAIN." 9 | CONVERSATION_LIMIT_MSG = "YOU HAVE REACHED THE CONVERSATION LENGTH LIMIT. PLEASE CLEAR HISTORY AND START A NEW CONVERSATION." 10 | INPUT_CHAR_LEN_LIMIT = 2560 11 | CONVERSATION_LEN_LIMIT = 50 12 | LOGDIR = "." 13 | 14 | # For the controller and workers(could be overwritten through ENV variables.) 15 | CONTROLLER_HEART_BEAT_EXPIRATION = int( 16 | os.getenv("FASTCHAT_CONTROLLER_HEART_BEAT_EXPIRATION", 90) 17 | ) 18 | WORKER_HEART_BEAT_INTERVAL = int(os.getenv("FASTCHAT_WORKER_HEART_BEAT_INTERVAL", 30)) 19 | WORKER_API_TIMEOUT = int(os.getenv("FASTCHAT_WORKER_API_TIMEOUT", 100)) 20 | WORKER_API_EMBEDDING_BATCH_SIZE = int(os.getenv("WORKER_API_EMBEDDING_BATCH_SIZE", 4)) 21 | 22 | 23 | class ErrorCode(IntEnum): 24 | """ 25 | https://platform.openai.com/docs/guides/error-codes/api-errors 26 | """ 27 | 28 | VALIDATION_TYPE_ERROR = 40001 29 | 30 | INVALID_AUTH_KEY = 40101 31 | INCORRECT_AUTH_KEY = 40102 32 | NO_PERMISSION = 40103 33 | 34 | INVALID_MODEL = 40301 35 | PARAM_OUT_OF_RANGE = 40302 36 | CONTEXT_OVERFLOW = 40303 37 | 38 | RATE_LIMIT = 42901 39 | QUOTA_EXCEEDED = 42902 40 | ENGINE_OVERLOADED = 42903 41 | 42 | INTERNAL_ERROR = 50001 43 | CUDA_OUT_OF_MEMORY = 50002 44 | GRADIO_REQUEST_ERROR = 50003 45 | GRADIO_STREAM_UNKNOWN_ERROR = 50004 46 | CONTROLLER_NO_WORKER = 50005 47 | CONTROLLER_WORKER_TIMEOUT = 50006 48 | -------------------------------------------------------------------------------- /convert_original_stable_diffusion_to_diffusers.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2023 The HuggingFace Inc. team. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | """ Conversion script for the LDM checkpoints. 
""" 16 | 17 | """ 18 | pip install diffusers==0.20.2 19 | python convert_original_stable_diffusion_to_diffusers.py --checkpoint_path "./InstructDiffusion/v1-5-pruned-emaonly-adaption-task.ckpt" --original_config_file "./InstructDiffusion/configs/instruct_diffusion.yaml" --dump_path "./InstructDiffusion_diffusers" 20 | """ 21 | 22 | import argparse 23 | import importlib 24 | import torch 25 | from convert_from_ckpt import download_from_original_stable_diffusion_ckpt 26 | # from diffusers.pipelines.stable_diffusion.convert_from_ckpt import download_from_original_stable_diffusion_ckpt 27 | 28 | if __name__ == "__main__": 29 | parser = argparse.ArgumentParser() 30 | 31 | parser.add_argument( 32 | "--checkpoint_path", default=None, type=str, required=True, help="Path to the checkpoint to convert." 33 | ) 34 | # !wget https://raw.githubusercontent.com/CompVis/stable-diffusion/main/configs/stable-diffusion/v1-inference.yaml 35 | parser.add_argument( 36 | "--original_config_file", 37 | default=None, 38 | type=str, 39 | help="The YAML config file corresponding to the original architecture.", 40 | ) 41 | parser.add_argument( 42 | "--num_in_channels", 43 | default=None, 44 | type=int, 45 | help="The number of input channels. If `None` number of input channels will be automatically inferred.", 46 | ) 47 | parser.add_argument( 48 | "--scheduler_type", 49 | default="pndm", 50 | type=str, 51 | help="Type of scheduler to use. Should be one of ['pndm', 'lms', 'ddim', 'euler', 'euler-ancestral', 'dpm']", 52 | ) 53 | parser.add_argument( 54 | "--pipeline_type", 55 | default=None, 56 | type=str, 57 | help=( 58 | "The pipeline type. One of 'FrozenOpenCLIPEmbedder', 'FrozenCLIPEmbedder', 'PaintByExample'" 59 | ". If `None` pipeline will be automatically inferred." 60 | ), 61 | ) 62 | parser.add_argument( 63 | "--image_size", 64 | default=None, 65 | type=int, 66 | help=( 67 | "The image size that the model was trained on. Use 512 for Stable Diffusion v1.X and Stable Siffusion v2" 68 | " Base. Use 768 for Stable Diffusion v2." 69 | ), 70 | ) 71 | parser.add_argument( 72 | "--prediction_type", 73 | default=None, 74 | type=str, 75 | help=( 76 | "The prediction type that the model was trained on. Use 'epsilon' for Stable Diffusion v1.X and Stable" 77 | " Diffusion v2 Base. Use 'v_prediction' for Stable Diffusion v2." 78 | ), 79 | ) 80 | parser.add_argument( 81 | "--extract_ema", 82 | action="store_true", 83 | help=( 84 | "Only relevant for checkpoints that have both EMA and non-EMA weights. Whether to extract the EMA weights" 85 | " or not. Defaults to `False`. Add `--extract_ema` to extract the EMA weights. EMA weights usually yield" 86 | " higher quality images for inference. Non-EMA weights are usually better to continue fine-tuning." 87 | ), 88 | ) 89 | parser.add_argument( 90 | "--upcast_attention", 91 | action="store_true", 92 | help=( 93 | "Whether the attention computation should always be upcasted. This is necessary when running stable" 94 | " diffusion 2.1." 
95 | ), 96 | ) 97 | parser.add_argument( 98 | "--from_safetensors", 99 | action="store_true", 100 | help="If `--checkpoint_path` is in `safetensors` format, load checkpoint with safetensors instead of PyTorch.", 101 | ) 102 | parser.add_argument( 103 | "--to_safetensors", 104 | action="store_true", 105 | help="Whether to store pipeline in safetensors format or not.", 106 | ) 107 | parser.add_argument("--dump_path", default=None, type=str, required=True, help="Path to the output model.") 108 | parser.add_argument("--device", type=str, help="Device to use (e.g. cpu, cuda:0, cuda:1, etc.)") 109 | parser.add_argument( 110 | "--stable_unclip", 111 | type=str, 112 | default=None, 113 | required=False, 114 | help="Set if this is a stable unCLIP model. One of 'txt2img' or 'img2img'.", 115 | ) 116 | parser.add_argument( 117 | "--stable_unclip_prior", 118 | type=str, 119 | default=None, 120 | required=False, 121 | help="Set if this is a stable unCLIP txt2img model. Selects which prior to use. If `--stable_unclip` is set to `txt2img`, the karlo prior (https://huggingface.co/kakaobrain/karlo-v1-alpha/tree/main/prior) is selected by default.", 122 | ) 123 | parser.add_argument( 124 | "--clip_stats_path", 125 | type=str, 126 | help="Path to the clip stats file. Only required if the stable unclip model's config specifies `model.params.noise_aug_config.params.clip_stats_path`.", 127 | required=False, 128 | ) 129 | parser.add_argument( 130 | "--controlnet", action="store_true", default=None, help="Set flag if this is a controlnet checkpoint." 131 | ) 132 | parser.add_argument("--half", action="store_true", help="Save weights in half precision.") 133 | parser.add_argument( 134 | "--vae_path", 135 | type=str, 136 | default=None, 137 | required=False, 138 | help="Set to a path, hub id to an already converted vae to not convert it again.", 139 | ) 140 | parser.add_argument( 141 | "--pipeline_class_name", 142 | type=str, 143 | default=None, 144 | required=False, 145 | help="Specify the pipeline class name", 146 | ) 147 | 148 | args = parser.parse_args() 149 | 150 | if args.pipeline_class_name is not None: 151 | library = importlib.import_module("diffusers") 152 | class_obj = getattr(library, args.pipeline_class_name) 153 | pipeline_class = class_obj 154 | else: 155 | pipeline_class = None 156 | 157 | pipe = download_from_original_stable_diffusion_ckpt( 158 | checkpoint_path=args.checkpoint_path, 159 | original_config_file=args.original_config_file, 160 | image_size=args.image_size, 161 | prediction_type=args.prediction_type, 162 | model_type=args.pipeline_type, 163 | extract_ema=args.extract_ema, 164 | scheduler_type=args.scheduler_type, 165 | num_in_channels=args.num_in_channels, 166 | upcast_attention=args.upcast_attention, 167 | from_safetensors=args.from_safetensors, 168 | device=args.device, 169 | stable_unclip=args.stable_unclip, 170 | stable_unclip_prior=args.stable_unclip_prior, 171 | clip_stats_path=args.clip_stats_path, 172 | controlnet=args.controlnet, 173 | vae_path=args.vae_path, 174 | pipeline_class=pipeline_class, 175 | ) 176 | 177 | if args.half: 178 | pipe.to(torch_dtype=torch.float16) 179 | 180 | if args.controlnet: 181 | # only save the controlnet model 182 | pipe.controlnet.save_pretrained(args.dump_path, safe_serialization=args.to_safetensors) 183 | else: 184 | pipe.save_pretrained(args.dump_path, safe_serialization=args.to_safetensors) -------------------------------------------------------------------------------- /data/ConversationTemplateEditing_use.txt: 
-------------------------------------------------------------------------------- 1 | ############################################################################################################################# 2 | 1. 3 | Human: Could you produce an image based on the previous image and text description [cap]? 4 | GPT: Certainly, I can create the image you requested [img]. 5 | 6 | 2. 7 | Human: Please generate an image that corresponds to the previous image and text description [cap]. 8 | GPT: Sure, I can generate the image you're looking for [img]. 9 | 10 | 3. 11 | Human: Can you create an image that represents the previous image and text description [cap]? 12 | GPT: Of course, I have created the image you requested [img]. 13 | 14 | 4. 15 | Human: I'm looking for a new image that matches the conditional image and caption [cap]. 16 | GPT: Understood, the image you need is [img]. 17 | 18 | 5. 19 | Human: Please provide me with an image that corresponds to the conditional image and caption [cap]. 20 | GPT: Sure, here is the image that corresponds to the conditional image and caption [img]. 21 | 22 | 6. 23 | Human: Could you create an image that matches the conditional image and caption [cap]? 24 | GPT: Absolutely, I have created the image that matches the conditional image and caption [img]. 25 | 26 | 7. 27 | Human: Can you make a visual based on the prior image and caption [cap]? 28 | GPT: Yes, I can make the visual you requested [img]. 29 | 30 | 8. 31 | Human: Please generate a visual that corresponds to the prior image and caption [cap]. 32 | GPT: Of course, I can generate the visual you're looking for [img]. 33 | 34 | 9. 35 | Human: I need an image that represents the prior image and caption [cap]. Can you create it? 36 | GPT: Certainly, I have created the image that represents the prior image and caption [img]. 37 | 38 | 10. 39 | Human: Please provide me with an image that depicts the former image and concept [cap]. 40 | GPT: Sure, here is the image that depicts the former image and concept [img]. 41 | 42 | 11. 43 | Human: Can you create an image that represents the former image and concept [cap]? 44 | GPT: Absolutely, I have created the image that represents the former image and concept [img]. 45 | 46 | 12. 47 | Human: I'm looking for an image that captures the former image and concept [cap]. Can you make it? 48 | GPT: Yes, I have created the image that captures the former image and concept [img]. 49 | 50 | 13. 51 | Human: Could you create a graphic that matches the earlier visual and description [cap]? 52 | GPT: Of course, I can create the matching graphic you're looking for [img]. 53 | 54 | 14. 55 | Human: Please generate a graphic that corresponds to the earlier visual and description [cap]. 56 | GPT: Certainly, I can generate the graphic you requested [img]. 57 | 58 | 15. 59 | Human: I need a graphic that goes with the earlier visual and description [cap]. Can you create it? 60 | GPT: Yes, I have created the graphic that goes with the earlier visual and description [img]. 61 | 62 | 16. 63 | Human: Can you design an image according to the older visual and idea [cap]? 64 | GPT: Yes, I can design the image you requested [img]. 65 | 66 | 17. 67 | Human: Please create an image that corresponds to the older visual and idea [cap]. 68 | GPT: Sure, I can create the image you're looking for [img]. 69 | 70 | 18. 71 | Human: I need an image that represents the older visual and idea [cap]. Can you design it? 
72 | GPT: Absolutely, I have designed the image that represents the older visual and idea [img]. 73 | 74 | 19. 75 | Human: Please create a visual representation using the past graphic and theme [cap]. 76 | GPT: Understood, the visual representation you requested is [img]. 77 | 78 | 20. 79 | Human: Can you generate a visual representation that uses the past graphic and theme [cap]? 80 | GPT: Yes, I have generated the visual representation you requested [img]. 81 | 82 | 21. 83 | Human: I'm looking for a visual representation that uses the past graphic and theme [cap]. Can you create it? 84 | GPT: Certainly, I have created the visual representation you requested [img]. 85 | 86 | 22. 87 | Human: Could you create an image that conveys the previous visual and notion of [cap]? 88 | GPT: Of course, I have created the image that conveys the previous visual and notion you requested [img]. 89 | 90 | 23. 91 | Human: Please generate an image that represents the previous visual and notion of [cap]. 92 | GPT: Yes, I can generate the image you're looking for [img]. 93 | 94 | 24. 95 | Human: I need an image that embodies the previous visual and notion of [cap]. Can you create it? 96 | GPT: Certainly, I have created the image that embodies the previous visual and notion of [img]. 97 | 98 | 25. 99 | Human: Can you create an image that embodies the preceding picture and idea [cap]? 100 | GPT: Certainly, I have created the image that embodies the preceding picture and idea [img]. 101 | 102 | 26. 103 | Human: I'm looking for an image that captures the preceding picture and idea [cap]. Can you generate it? 104 | GPT: Yes, I have generated the image that captures the preceding picture and idea [img]. 105 | 106 | 27. 107 | Human: Please provide me with an image that represents the preceding picture and idea [cap]. 108 | GPT: Of course, here is the image that represents the preceding picture and idea [img]. 109 | 110 | 28. 111 | Human: Could you design a visual that captures the former picture and message [cap]? 112 | GPT: Yes, I can design the visual that captures the former picture and message [img]. 113 | 114 | 29. 115 | Human: Please generate a visual that represents the former picture and message [cap]. 116 | GPT: Certainly, I can generate the visual you're looking for [img]. 117 | 118 | 30. 119 | Human: I need a visual that embodies the former picture and message [cap]. Can you design it? 120 | GPT: Of course, I have designed the visual that embodies the former picture and message [img]. 121 | -------------------------------------------------------------------------------- /data/LLMSD_InstructDiffusion_color.txt: -------------------------------------------------------------------------------- 1 | Red 纯红 #FF0000 255,0,0 2 | 3 | Purple 紫色 #800080 128,0,128 4 | 5 | Blue 纯蓝 #0000FF 0,0,255 6 | 7 | Green 纯绿 #008000 0,128,0 8 | 9 | Yellow 纯黄 #FFFF00 255,255,0 10 | 11 | White 纯白 #FFFFFF 255,255,255 12 | 13 | Black 纯黑 #000000 0,0,0 14 | 15 | Gray 灰色 #808080 128,128,128 16 | 17 | 18 | -------------------------------------------------------------------------------- /data/LLMSD_InstructDiffusion_seg.txt: -------------------------------------------------------------------------------- 1 | Mark the pixels of {object} in {color} and leave the rest unchanged. 2 | Color the {object}'s pixels in {color}, keeping the remaining pixels unaltered. 3 | Apply {color} to the pixels of {object} while maintaining the current state of other pixels. 4 | Assign {color} to the pixels belonging to {object}, preserving the rest as they are. 
5 | For {object}, set its pixels to {color} and let the others remain the same. 6 | Modify the pixels of {object} to {color} without affecting any other pixels. 7 | Set the {object} pixels to {color} and keep the other pixels in their original state. 8 | Update the pixels of {object} to {color}, but leave the other pixels untouched. 9 | Fill in the pixels of {object} with {color}, retaining the existing colors of the remaining pixels. 10 | Change the {object} pixels to {color}, while keeping the other pixels constant. 11 | Paint the pixels of {object} in {color} and maintain the current appearance of the other pixels. -------------------------------------------------------------------------------- /data/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/TencentARC/SmartEdit/a65f1262dfcba68c138ea95fe9936df1bd2c111d/data/__init__.py -------------------------------------------------------------------------------- /data/adapter_config.json: -------------------------------------------------------------------------------- 1 | { 2 | "auto_mapping": null, 3 | "base_model_name_or_path": "vicuna", 4 | "bias": "none", 5 | "fan_in_fan_out": false, 6 | "inference_mode": true, 7 | "init_lora_weights": true, 8 | "layers_pattern": null, 9 | "layers_to_transform": null, 10 | "lora_alpha": 16, 11 | "lora_dropout": 0.05, 12 | "modules_to_save": null, 13 | "peft_type": "LORA", 14 | "r": 8, 15 | "revision": null, 16 | "target_modules": [ 17 | "q_proj", 18 | "v_proj" 19 | ], 20 | "task_type": "CAUSAL_LM" 21 | } -------------------------------------------------------------------------------- /data/conv_template_cap_to_img.txt: -------------------------------------------------------------------------------- 1 | 1. 2 | human: Can you create a picture based on the description [cap]? 3 | gpt: Absolutely, the resulting picture is [img]. 4 | 5 | 2. 6 | human: I need an image that matches the caption [cap]. 7 | gpt: Of course, the image you're looking for is [img]. 8 | 9 | 3. 10 | human: Please make a visual based on this caption [cap]. 11 | gpt: No problem, the visual you requested is [img]. 12 | 13 | 4. 14 | human: I'd like an image that represents the concept [cap]. 15 | gpt: Sure thing, the image depicting the concept is [img]. 16 | 17 | 5. 18 | human: Can you produce a picture that illustrates [cap]? 19 | gpt: Certainly, the picture illustrating that is [img]. 20 | 21 | 6. 22 | human: I need a graphic that goes with the description [cap]. 23 | gpt: Alright, the matching graphic is [img]. 24 | 25 | 7. 26 | human: Please design an image according to the idea [cap]. 27 | gpt: You got it, the designed image is [img]. 28 | 29 | 8. 30 | human: Can you create a visual representation with the theme [cap]? 31 | gpt: Sure, the visual representation is [img]. 32 | 33 | 9. 34 | human: I'm looking for an image that conveys the notion of [cap]. 35 | gpt: Understood, the image that conveys the notion is [img]. 36 | 37 | 10. 38 | human: Could you generate an image that embodies the idea [cap]? 39 | gpt: Yes, the image that embodies the idea is [img]. 40 | 41 | 11. 42 | human: I'd like you to develop an image based on the idea [cap]. 43 | gpt: Sure, the image that captures the idea is [img]. 44 | 45 | 12. 46 | human: Please create an image that conveys the message [cap]. 47 | gpt: Certainly, the image that expresses the message is [img]. 48 | 49 | 13. 50 | human: Can you generate a visual inspired by the words [cap]? 
51 | gpt: Of course, the inspired visual is [img]. 52 | 53 | 14. 54 | human: I require a picture that exemplifies the text [cap]. 55 | gpt: No problem, the picture that represents the text is [img]. 56 | 57 | 15. 58 | human: Make an image that illustrates the concept [cap]. 59 | gpt: All right, the image that demonstrates the concept is [img]. 60 | 61 | 16. 62 | human: Please provide a graphic that embodies the idea [cap]. 63 | gpt: Absolutely, the graphic that personifies the idea is [img]. 64 | 65 | 17. 66 | human: Create a visual depiction based on [cap]. 67 | gpt: Definitely, the visual depiction based on [cap] is [img]. 68 | 69 | 18. 70 | human: I'd like an illustration that captures the essence of [cap]. 71 | gpt: Sure thing, the illustration that portrays the essence is [img]. 72 | 73 | 19. 74 | human: Can you come up with a visual interpretation of the phrase [cap]? 75 | gpt: Yes, the visual interpretation of the phrase is [img]. 76 | 77 | 20. 78 | human: Design an image that reflects the meaning of [cap]. 79 | gpt: Of course, the image that mirrors the meaning is [img]. 80 | 81 | 21. 82 | human: Please generate an artwork that encompasses the theme [cap]. 83 | gpt: No problem, the artwork that includes the theme is [img]. 84 | 85 | 22. 86 | human: I need a picture that symbolizes the caption [cap]. 87 | gpt: All right, the picture that signifies the caption is [img]. 88 | 89 | 23. 90 | human: Produce a graphic that encapsulates the concept [cap]. 91 | gpt: Certainly, the graphic that summarizes the concept is [img]. 92 | 93 | 24. 94 | human: Can you create a visual representation for the description [cap]? 95 | gpt: Absolutely, the visual representation for the description is [img]. 96 | 97 | 25. 98 | human: Make an illustration that defines the context of [cap]. 99 | gpt: Definitely, the illustration that represents the context is [img]. 100 | 101 | 26. 102 | human: I'd like a picture that showcases the idea [cap]. 103 | gpt: Sure thing, the picture that highlights the idea is [img]. 104 | 105 | 27. 106 | human: Create a visual that translates the thought [cap]. 107 | gpt: Of course, the visual that conveys the thought is [img]. 108 | 109 | 28. 110 | human: Please design a graphic inspired by the expression [cap]. 111 | gpt: Yes, the graphic inspired by the expression is [img]. 112 | 113 | 29. 114 | human: Generate an image that visually represents the word(s) [cap]. 115 | gpt: No problem, the image that corresponds to the word(s) is [img]. 116 | 117 | 30. 118 | human: I need an illustration that brings to life the concept [cap]. 119 | gpt: All right, the illustration that vividly portrays the concept is [img]. 
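
Editor's note: the templates above pair a "human:" request containing a [cap] placeholder with a "gpt:" reply containing an [img] placeholder. As a rough, hypothetical sketch only (the function names below are invented for illustration; the repository's own parsing lives in the dataset/ modules, which apply the same prefix-splitting and replace('[cap]', ...) pattern to the editing templates), such a file could be consumed like this:

# Hypothetical illustration, not a file from this repo: parse a "human:/gpt:" template
# file with [cap] and [img] placeholders and fill in a caption.
import random

def load_templates(path):
    """Collect (human, gpt) pairs from a numbered template file in the format above."""
    templates = []
    with open(path, "r") as f:
        for line in f:
            line = line.strip()
            if line.startswith("human: "):
                templates.append({"human": line[len("human: "):]})
            elif line.startswith("gpt: "):
                templates[-1]["gpt"] = line[len("gpt: "):]
    return templates

def fill_template(templates, caption):
    """Pick a random template and substitute the caption for the [cap] placeholder."""
    t = random.choice(templates)
    human = t["human"].replace("[cap]", f'"{caption}"')
    # "[img]" marks where the dataset code appends its learnable image tokens.
    gpt = t["gpt"]
    return human, gpt

# Example usage (paths assumed):
# templates = load_templates("data/conv_template_cap_to_img.txt")
# human_turn, gpt_turn = fill_template(templates, "a dog surfing a wave")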
-------------------------------------------------------------------------------- /dataset/LLaVAMLLMSD_dataset.py: -------------------------------------------------------------------------------- 1 | import pdb 2 | import os 3 | import torch 4 | from torch.utils.data import Dataset 5 | from transformers.trainer_pt_utils import LabelSmoother 6 | from conversation_v01 import SeparatorStyle, get_conv_template 7 | from PIL import Image 8 | IGNORE_TOKEN_ID = LabelSmoother.ignore_index 9 | 10 | import json 11 | import copy 12 | DEFAULT_IMAGE_TOKEN = '' 13 | IGNORE_INDEX = -100 14 | 15 | def tokenizer_image_token_(prompt, tokenizer, image_token_index, return_tensors=None): 16 | prompt_chunks = [tokenizer(chunk).input_ids for chunk in prompt.split('')] 17 | # len(prompt_chunks)=2 18 | 19 | def insert_separator(X, sep): 20 | return [ele for sublist in zip(X, [sep]*len(X)) for ele in sublist][:-1] 21 | 22 | input_ids = [] 23 | offset = 0 24 | if len(prompt_chunks) > 0 and len(prompt_chunks[0]) > 0 and prompt_chunks[0][0] == tokenizer.bos_token_id: 25 | offset = 1 26 | input_ids.append(prompt_chunks[0][0]) 27 | 28 | for x in insert_separator(prompt_chunks, [image_token_index] * (offset + 1)): 29 | input_ids.extend(x[offset:]) 30 | 31 | # return_tensors or not 32 | if return_tensors is not None: 33 | if return_tensors == 'pt': 34 | input_ids = torch.tensor(input_ids, dtype=torch.long) 35 | return input_ids 36 | else: 37 | return input_ids 38 | 39 | # LLaVA dataset 40 | class LLaVADataset_for_instruction_tuning(Dataset): 41 | """ LLAVA-Dataset for instruction tuning """ 42 | def __init__(self, 43 | data_path, 44 | image_folder, 45 | LLM_tokenizer, 46 | CLIPImageProcessor, 47 | is_LLaMA, 48 | LLaVADataset_resolution_ViT 49 | ): 50 | super(LLaVADataset_for_instruction_tuning, self).__init__() 51 | # LLaVA dataset 52 | list_data_dict = json.load(open(data_path, "r")) 53 | self.list_data_dict = list_data_dict 54 | self.image_folder = image_folder 55 | 56 | # LLM tokenizer 57 | self.LLM_tokenizer = LLM_tokenizer 58 | self.is_LLaMA = is_LLaMA 59 | 60 | # CLIPImageProcessor 61 | self.CLIPImageProcessor = CLIPImageProcessor 62 | self.LLaVADataset_resolution_ViT = LLaVADataset_resolution_ViT 63 | 64 | def __len__(self): 65 | return len(self.list_data_dict) 66 | 67 | def __getitem__(self, i): 68 | sources = self.list_data_dict[i] 69 | sources = [sources] 70 | assert len(sources) == 1, "Don't know why it is wrapped to a list" 71 | 72 | # 1. image -> [3, 224, 224] 73 | image_file = self.list_data_dict[i]['image'] 74 | image_folder = self.image_folder 75 | image = Image.open(os.path.join(image_folder, image_file)).convert('RGB') 76 | image = image.resize((self.LLaVADataset_resolution_ViT, self.LLaVADataset_resolution_ViT), resample=Image.Resampling.BICUBIC) 77 | image = self.CLIPImageProcessor.preprocess(image, return_tensors='pt')['pixel_values'] 78 | image = image[0] 79 | 80 | # 2. preprocess_multimodal function 81 | # DEFAULT_IMAGE_TOKEN='', DEFAULT_IM_START_TOKEN='', DEFAULT_IM_END_TOKEN='' 82 | sources = copy.deepcopy([e["conversations"] for e in sources]) 83 | for source in sources: 84 | for sentence in source: 85 | if DEFAULT_IMAGE_TOKEN in sentence['value']: 86 | sentence['value'] = sentence['value'].replace(DEFAULT_IMAGE_TOKEN, '').strip() 87 | sentence['value'] = DEFAULT_IMAGE_TOKEN + '\n' + sentence['value'] 88 | sentence['value'] = sentence['value'].strip() 89 | # '\nWhat are the colors of the bus in the image?' 
90 | replace_token = DEFAULT_IMAGE_TOKEN 91 | sentence["value"] = sentence["value"].replace(DEFAULT_IMAGE_TOKEN, replace_token) 92 | # ' \nWhat are the colors of the bus in the image?' 93 | 94 | # 3. preprocess function 95 | # Step-1: choose conversation system message 96 | assert ('image' in self.list_data_dict[i]) == True 97 | conv = get_conv_template("vicuna_v1.3") 98 | roles = {"human": conv.roles[0], "gpt": conv.roles[1]} 99 | # A chat between a curious user and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the user's questions. -> {'human': 'USER', 'gpt': 'ASSISTANT'} 100 | 101 | # Step-2: Apply prompt templates for conversation 102 | conversations = [] 103 | for i, source in enumerate(sources): 104 | if roles[source[0]["from"]] != conv.roles[0]: 105 | # Skip the first one if it is not from human 106 | source = source[1:] 107 | 108 | conv.messages = [] 109 | for j, sentence in enumerate(source): 110 | role = roles[sentence["from"]] 111 | assert role == conv.roles[j % 2], f"{i}" 112 | conv.append_message(role, sentence["value"]) 113 | conversations.append(conv.get_prompt()) 114 | 115 | # data processing for LLaMA 116 | input_ids_for_LLM, targets_for_LLM = None, None 117 | if self.is_LLaMA == True: 118 | # Step-3: Tokenize conversations 119 | input_ids_for_LLM = torch.stack([tokenizer_image_token_(prompt, self.LLM_tokenizer, 120 | image_token_index=self.LLM_tokenizer.img_start_token_id, 121 | return_tensors='pt') for prompt in conversations], dim=0) 122 | targets_for_LLM = input_ids_for_LLM.clone() 123 | assert conv.sep_style == SeparatorStyle.ADD_COLON_TWO 124 | 125 | # Step-4: Mask targets 126 | sep = conv.sep + conv.roles[1] + ": " 127 | for conversation, target_for_LLM in zip(conversations, targets_for_LLM): 128 | total_len = int(target_for_LLM.ne(self.LLM_tokenizer.pad_token_id).sum()) 129 | rounds = conversation.split(conv.sep2) 130 | cur_len = 1 131 | target_for_LLM[:cur_len] = IGNORE_INDEX 132 | for i, rou in enumerate(rounds): 133 | if rou == "": 134 | break 135 | 136 | parts = rou.split(sep) 137 | if len(parts) != 2: 138 | break 139 | parts[0] += sep 140 | 141 | round_len = len(tokenizer_image_token_(rou, self.LLM_tokenizer, image_token_index=self.LLM_tokenizer.img_start_token_id)) 142 | instruction_len = len(tokenizer_image_token_(parts[0], self.LLM_tokenizer, image_token_index=self.LLM_tokenizer.img_start_token_id)) - 2 143 | 144 | target_for_LLM[cur_len: cur_len + instruction_len] = IGNORE_INDEX 145 | cur_len += round_len 146 | target_for_LLM[cur_len:] = IGNORE_INDEX 147 | 148 | if cur_len < self.LLM_tokenizer.model_max_length: 149 | if cur_len != total_len: 150 | target_for_LLM[:] = IGNORE_INDEX 151 | print(f"WARNING: tokenization mismatch: {cur_len} vs. {total_len}." 
f" (ignored)") 152 | 153 | # return dataloader -> 'original_img_for_vae' and 'edited_img' are placeholders 154 | original_img = image 155 | original_img_for_vae = torch.zeros([3, 256, 256], dtype=torch.float32) 156 | edited_img = torch.zeros([3, 256, 256], dtype=torch.float32) 157 | 158 | # For input_ids -> insert '' and '' 159 | input_ids_ = input_ids_for_LLM[0] 160 | LLM_img_start_token_id = self.LLM_tokenizer.img_start_token_id 161 | LLM_img_start_token_id_pos = (torch.where(input_ids_ == LLM_img_start_token_id)[0])[0].item() 162 | new_input_ids_ = torch.cat([input_ids_[:LLM_img_start_token_id_pos], 163 | torch.tensor([self.LLM_tokenizer.DEFAULT_IM_START_TOKEN]), 164 | input_ids_[LLM_img_start_token_id_pos:(LLM_img_start_token_id_pos + 1)], 165 | torch.tensor([self.LLM_tokenizer.DEFAULT_IM_END_TOKEN]), 166 | input_ids_[(LLM_img_start_token_id_pos + 1):]], dim=0) 167 | input_attention_mask = new_input_ids_.ne(self.LLM_tokenizer.pad_token_id) 168 | 169 | # For generated_caption_targets -> insert 2*IGNORE_INDEX 170 | generated_caption_targets = torch.cat([torch.tensor([IGNORE_INDEX]), torch.tensor([IGNORE_INDEX]), 171 | targets_for_LLM[0]], dim=0) 172 | generated_caption_encoder_attention_mask = new_input_ids_.ge(self.LLM_tokenizer.img_start_token_id) 173 | 174 | # task choosing 175 | is_editing_task = torch.zeros(1) 176 | 177 | # LLaVA-Dataset dataloader 178 | return {'original_img': original_img, 179 | 'original_img_for_vae': original_img_for_vae, 180 | 'edited_img': edited_img, 181 | 'input_ids': new_input_ids_, 182 | 'input_attention_mask': input_attention_mask, 183 | 'generated_caption_targets': generated_caption_targets, 184 | 'generated_caption_encoder_attention_mask': generated_caption_encoder_attention_mask, 185 | 'is_editing_task': is_editing_task} 186 | -------------------------------------------------------------------------------- /dataset/ReasonEditMLLMSD_dataset.py: -------------------------------------------------------------------------------- 1 | import pdb 2 | 3 | import copy 4 | import json 5 | import numpy as np 6 | from conversation_v01 import SeparatorStyle, get_conv_template 7 | from PIL import Image 8 | import random 9 | import torch 10 | from torch.utils.data import Dataset 11 | from transformers.trainer_pt_utils import LabelSmoother 12 | IGNORE_TOKEN_ID = LabelSmoother.ignore_index 13 | 14 | def convert_to_np(image, resolution): 15 | image = image.convert("RGB") 16 | image = image.resize((resolution, resolution), resample=Image.Resampling.BICUBIC) 17 | return np.array(image).transpose(2, 0, 1) 18 | 19 | # ReasoningEditing dataset 20 | class ReasoningEditing_Dataset(Dataset): 21 | def __init__(self, 22 | ReasoningEditingDataset_path, 23 | ReasoningEditingDataset_resolution_ViT, 24 | ReasoningEditingDataset_resolution_for_SD, 25 | CLIPImageProcessor, 26 | mm_projection_length, 27 | editing_template, 28 | editing_max_length, 29 | llm_tokenizer=None 30 | ): 31 | 32 | # ReasoningEditing Dataset path 33 | with open(ReasoningEditingDataset_path, 'r') as f: 34 | self.ReasoningEditing_data = json.load(f) 35 | 36 | # 224, 256 37 | self.ReasoningEditingDataset_resolution_ViT = ReasoningEditingDataset_resolution_ViT 38 | self.ReasoningEditingDataset_resolution_for_SD = ReasoningEditingDataset_resolution_for_SD 39 | 40 | # CLIPImageProcessor -> 没有flip操作 41 | self.CLIPImageProcessor = CLIPImageProcessor 42 | 43 | # LLM tokenizer 44 | self.llm_tokenizer = llm_tokenizer 45 | self.llm_tokenizer.padding_side = "right" 46 | self.llm_tokenizer.truncation_side = 'right' 47 | 48 | # 
Vicuna conversation system for editing 49 | self.editing_template = editing_template 50 | self.editing_max_length = editing_max_length 51 | self.mm_projection_length = mm_projection_length 52 | 53 | def __len__(self,): 54 | return len(self.ReasoningEditing_data) 55 | 56 | def __getitem__(self, index): 57 | # load variables from json file 58 | key = f'{index:04d}' 59 | original_img_path = self.ReasoningEditing_data[key]['origin_img_path'] 60 | original_image = Image.open(original_img_path).convert('RGB') 61 | target_img_path = self.ReasoningEditing_data[key]['target_img_path'] 62 | target_image = Image.open(target_img_path).convert('RGB') 63 | 64 | # random select an instruction 65 | instruction_list = self.ReasoningEditing_data[key]['instruction'] 66 | instruction = random.choice(instruction_list) 67 | 68 | # 1. Original Image for ViT input 69 | RE_original_image = copy.deepcopy(original_image) 70 | RE_original_image = RE_original_image.resize((self.ReasoningEditingDataset_resolution_ViT, self.ReasoningEditingDataset_resolution_ViT), 71 | resample=Image.Resampling.BICUBIC) 72 | RE_original_image = self.CLIPImageProcessor.preprocess(RE_original_image, return_tensors='pt')['pixel_values'] 73 | RE_original_image = RE_original_image[0] 74 | 75 | # 2. Original Image & 3. Edited Image for SD input 76 | RE_original_image_2 = convert_to_np(original_image, self.ReasoningEditingDataset_resolution_for_SD) 77 | RE_target_image = convert_to_np(target_image, self.ReasoningEditingDataset_resolution_for_SD) 78 | RE_SD_input = np.concatenate([RE_original_image_2, RE_target_image]) 79 | RE_SD_input = torch.tensor(RE_SD_input) 80 | RE_SD_input = 2 * (RE_SD_input / 255) - 1 81 | RE_original_image_2, RE_target_image = RE_SD_input.chunk(2) 82 | 83 | #################################################################################### 84 | # Vicuna conversation system construction for image editing task... 85 | # Step 1. Choose Human-GPT templates 86 | conversation_templates = [] 87 | with open(self.editing_template, 'r') as f: 88 | lines = f.readlines() 89 | for line in lines: 90 | line = line.strip() 91 | if line.startswith('Human: '): 92 | d = dict() 93 | d['Human'] = line[len("Human: "):] 94 | conversation_templates.append(d) 95 | elif line.startswith('GPT: '): 96 | conversation_templates[-1]['GPT'] = line[len("GPT: "):] 97 | 98 | # Step 2. Choose Vicuna_v1.3 system message 99 | conv = get_conv_template("vicuna_v1.3") 100 | roles = {"Human": conv.roles[0], "GPT": conv.roles[1]} 101 | 102 | # tokens -> num_new_tokens=35: ""(system message) + " ... " 103 | num_new_tokens = len(self.llm_tokenizer) - self.llm_tokenizer.vocab_size 104 | append_str = "" 105 | for i in range(num_new_tokens - 3): 106 | append_str += f" " 107 | 108 | # Step 3. Vicuna conversation system construction 109 | """ "" is a placeholder to find the text position and insert image embeddings """ 110 | edited_prompt = instruction 111 | DEFAULT_IM_START_TOKEN = '' 112 | DEFAULT_IM_END_TOKEN = '' 113 | edited_prompt = DEFAULT_IM_START_TOKEN + f" " + DEFAULT_IM_END_TOKEN + edited_prompt 114 | conversation_template = random.choice(conversation_templates) 115 | conv.messages = [] 116 | conv.append_message(roles["Human"], conversation_template["Human"].replace('[cap]', f'"{edited_prompt}"')) 117 | conv.append_message(roles["GPT"], conversation_template["GPT"].replace(' [img].', append_str)) 118 | conversation = conv.get_prompt() 119 | conversation = conversation.replace("\n", "") 120 | 121 | # 4. 
Edited Prompt input_ids -> Tokenize conversations 122 | input_ids_max_len = self.editing_max_length - self.mm_projection_length 123 | input_ids = self.llm_tokenizer( 124 | conversation, 125 | return_tensors="pt", 126 | padding="max_length", 127 | max_length=input_ids_max_len, 128 | truncation=True, 129 | ).input_ids[0] 130 | # [(editing_max_length-mm_projection_length)=256] 131 | 132 | # Step 4. Only show up tokens after 'ASSISTANT:' 133 | # IGNORE_TOKEN_ID=-100 134 | generated_caption_targets = input_ids.clone() 135 | sep = conv.sep + conv.roles[1] + ": " 136 | generated_caption_targets[:1] = IGNORE_TOKEN_ID 137 | total_padding_len = int(generated_caption_targets.ne(self.llm_tokenizer.pad_token_id).sum()) 138 | parts = conversation.split(sep) 139 | parts[0] += sep 140 | 141 | # 5. Generated caption targets for Language Model loss 142 | instruction_len = len( 143 | self.llm_tokenizer( 144 | parts[0], 145 | max_length=input_ids_max_len, 146 | truncation=True, 147 | ).input_ids) - 2 148 | generated_caption_targets[1:(1 + instruction_len)] = IGNORE_TOKEN_ID 149 | generated_caption_targets[total_padding_len:] = IGNORE_TOKEN_ID 150 | # [(editing_max_length-mm_projection_length)=256] 151 | #################################################################################### 152 | 153 | # 6. Edited Prompt attention_mask 154 | # ne(a, b) is a != b 155 | RE_instruction_attention_mask = input_ids.ne(self.llm_tokenizer.pad_token_id) 156 | 157 | # 7. Generated caption targets attention mask 158 | # ge(a, b) is a >= b 159 | generated_caption_encoder_attention_mask = input_ids.ge(self.llm_tokenizer.img_start_token_id) 160 | 161 | # 8. task choosing 162 | is_editing_task = torch.ones(1) 163 | 164 | # Reasoning-Editing dataloader -> 3 parts -> [bs, 3, 224, 224] + [bs, 3, 256, 256], [bs, 3, 256, 256], ['let the asparagus be replaced with sausages'] 165 | return {'original_img': RE_original_image, 166 | 'original_img_for_vae': RE_original_image_2, 167 | 'edited_img': RE_target_image, 168 | 'input_ids': input_ids, 169 | 'input_attention_mask': RE_instruction_attention_mask, 170 | 'generated_caption_targets': generated_caption_targets, 171 | 'generated_caption_encoder_attention_mask': generated_caption_encoder_attention_mask, 172 | 'is_editing_task': is_editing_task} 173 | -------------------------------------------------------------------------------- /dataset/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/TencentARC/SmartEdit/a65f1262dfcba68c138ea95fe9936df1bd2c111d/dataset/__init__.py -------------------------------------------------------------------------------- /llava/__init__.py: -------------------------------------------------------------------------------- 1 | from .model import LlavaLlamaForCausalLM 2 | -------------------------------------------------------------------------------- /llava/constants.py: -------------------------------------------------------------------------------- 1 | CONTROLLER_HEART_BEAT_EXPIRATION = 30 2 | WORKER_HEART_BEAT_INTERVAL = 15 3 | 4 | LOGDIR = "." 
5 | 6 | # Model Constants 7 | IGNORE_INDEX = -100 8 | IMAGE_TOKEN_INDEX = -200 9 | DEFAULT_IMAGE_TOKEN = "" 10 | DEFAULT_IMAGE_PATCH_TOKEN = "" 11 | DEFAULT_IM_START_TOKEN = "" 12 | DEFAULT_IM_END_TOKEN = "" 13 | -------------------------------------------------------------------------------- /llava/eval/eval_gpt_review.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import json 3 | import os 4 | 5 | import openai 6 | import tqdm 7 | import ray 8 | import time 9 | 10 | NUM_SECONDS_TO_SLEEP = 3 11 | 12 | @ray.remote(num_cpus=4) 13 | def get_eval(content: str, max_tokens: int): 14 | while True: 15 | try: 16 | response = openai.ChatCompletion.create( 17 | model='gpt-4', 18 | messages=[{ 19 | 'role': 'system', 20 | 'content': 'You are a helpful and precise assistant for checking the quality of the answer.' 21 | }, { 22 | 'role': 'user', 23 | 'content': content, 24 | }], 25 | temperature=0.2, # TODO: figure out which temperature is best for evaluation 26 | max_tokens=max_tokens, 27 | ) 28 | break 29 | except openai.error.RateLimitError: 30 | pass 31 | except Exception as e: 32 | print(e) 33 | time.sleep(NUM_SECONDS_TO_SLEEP) 34 | 35 | print('success!') 36 | return response['choices'][0]['message']['content'] 37 | 38 | 39 | def parse_score(review): 40 | try: 41 | score_pair = review.split('\n')[0] 42 | score_pair = score_pair.replace(',', ' ') 43 | sp = score_pair.split(' ') 44 | if len(sp) == 2: 45 | return [float(sp[0]), float(sp[1])] 46 | else: 47 | print('error', review) 48 | return [-1, -1] 49 | except Exception as e: 50 | print(e) 51 | print('error', review) 52 | return [-1, -1] 53 | 54 | 55 | if __name__ == '__main__': 56 | parser = argparse.ArgumentParser(description='ChatGPT-based QA evaluation.') 57 | parser.add_argument('-q', '--question') 58 | # parser.add_argument('-a', '--answer') 59 | parser.add_argument('-a', '--answer-list', nargs='+', default=[]) 60 | parser.add_argument('-r', '--rule') 61 | parser.add_argument('-o', '--output') 62 | parser.add_argument('--max-tokens', type=int, default=1024, help='maximum number of tokens produced in the output') 63 | args = parser.parse_args() 64 | 65 | ray.init() 66 | 67 | f_q = open(os.path.expanduser(args.question)) 68 | f_ans1 = open(os.path.expanduser(args.answer_list[0])) 69 | f_ans2 = open(os.path.expanduser(args.answer_list[1])) 70 | rule_dict = json.load(open(os.path.expanduser(args.rule), 'r')) 71 | 72 | review_file = open(f'{args.output}', 'w') 73 | 74 | js_list = [] 75 | handles = [] 76 | idx = 0 77 | for ques_js, ans1_js, ans2_js in zip(f_q, f_ans1, f_ans2): 78 | # if idx == 1: 79 | # break 80 | 81 | ques = json.loads(ques_js) 82 | ans1 = json.loads(ans1_js) 83 | ans2 = json.loads(ans2_js) 84 | 85 | category = json.loads(ques_js)['category'] 86 | if category in rule_dict: 87 | rule = rule_dict[category] 88 | else: 89 | rule = rule_dict['default'] 90 | prompt = rule['prompt'] 91 | role = rule['role'] 92 | content = (f'[Question]\n{ques["text"]}\n\n' 93 | f'[{role} 1]\n{ans1["text"]}\n\n[End of {role} 1]\n\n' 94 | f'[{role} 2]\n{ans2["text"]}\n\n[End of {role} 2]\n\n' 95 | f'[System]\n{prompt}\n\n') 96 | js_list.append({ 97 | 'id': idx+1, 98 | 'question_id': ques['question_id'], 99 | 'answer1_id': ans1['answer_id'], 100 | 'answer2_id': ans2['answer_id'], 101 | 'category': category}) 102 | idx += 1 103 | handles.append(get_eval.remote(content, args.max_tokens)) 104 | # To avoid the rate limit set by OpenAI 105 | time.sleep(NUM_SECONDS_TO_SLEEP) 106 | 107 | reviews = 
ray.get(handles) 108 | for idx, review in enumerate(reviews): 109 | scores = parse_score(review) 110 | js_list[idx]['content'] = review 111 | js_list[idx]['tuple'] = scores 112 | review_file.write(json.dumps(js_list[idx]) + '\n') 113 | review_file.close() 114 | -------------------------------------------------------------------------------- /llava/eval/eval_gpt_review_bench.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import json 3 | import os 4 | 5 | import openai 6 | import time 7 | 8 | NUM_SECONDS_TO_SLEEP = 0.5 9 | 10 | 11 | def get_eval(content: str, max_tokens: int): 12 | while True: 13 | try: 14 | response = openai.ChatCompletion.create( 15 | model='gpt-4-0314', 16 | messages=[{ 17 | 'role': 'system', 18 | 'content': 'You are a helpful and precise assistant for checking the quality of the answer.' 19 | }, { 20 | 'role': 'user', 21 | 'content': content, 22 | }], 23 | temperature=0.2, # TODO: figure out which temperature is best for evaluation 24 | max_tokens=max_tokens, 25 | ) 26 | break 27 | except openai.error.RateLimitError: 28 | pass 29 | except Exception as e: 30 | print(e) 31 | time.sleep(NUM_SECONDS_TO_SLEEP) 32 | 33 | return response['choices'][0]['message']['content'] 34 | 35 | 36 | def parse_score(review): 37 | try: 38 | score_pair = review.split('\n')[0] 39 | score_pair = score_pair.replace(',', ' ') 40 | sp = score_pair.split(' ') 41 | if len(sp) == 2: 42 | return [float(sp[0]), float(sp[1])] 43 | else: 44 | print('error', review) 45 | return [-1, -1] 46 | except Exception as e: 47 | print(e) 48 | print('error', review) 49 | return [-1, -1] 50 | 51 | 52 | if __name__ == '__main__': 53 | parser = argparse.ArgumentParser(description='ChatGPT-based QA evaluation.') 54 | parser.add_argument('-q', '--question') 55 | parser.add_argument('-c', '--context') 56 | parser.add_argument('-a', '--answer-list', nargs='+', default=[]) 57 | parser.add_argument('-r', '--rule') 58 | parser.add_argument('-o', '--output') 59 | parser.add_argument('--max-tokens', type=int, default=1024, help='maximum number of tokens produced in the output') 60 | args = parser.parse_args() 61 | 62 | f_q = open(os.path.expanduser(args.question)) 63 | f_ans1 = open(os.path.expanduser(args.answer_list[0])) 64 | f_ans2 = open(os.path.expanduser(args.answer_list[1])) 65 | rule_dict = json.load(open(os.path.expanduser(args.rule), 'r')) 66 | 67 | if os.path.isfile(os.path.expanduser(args.output)): 68 | cur_reviews = [json.loads(line) for line in open(os.path.expanduser(args.output))] 69 | else: 70 | cur_reviews = [] 71 | 72 | review_file = open(f'{args.output}', 'a') 73 | 74 | context_list = [json.loads(line) for line in open(os.path.expanduser(args.context))] 75 | image_to_context = {context['image']: context for context in context_list} 76 | 77 | handles = [] 78 | idx = 0 79 | for ques_js, ans1_js, ans2_js in zip(f_q, f_ans1, f_ans2): 80 | ques = json.loads(ques_js) 81 | ans1 = json.loads(ans1_js) 82 | ans2 = json.loads(ans2_js) 83 | 84 | inst = image_to_context[ques['image']] 85 | cap_str = '\n'.join(inst['caption']) 86 | 87 | category = 'llava_bench_' + json.loads(ques_js)['category'] 88 | if category in rule_dict: 89 | rule = rule_dict[category] 90 | else: 91 | assert False, f"Visual QA category not found in rule file: {category}." 
92 | prompt = rule['prompt'] 93 | role = rule['role'] 94 | content = (f'[Context]\n{cap_str}\n\n' 95 | f'[Question]\n{ques["text"]}\n\n' 96 | f'[{role} 1]\n{ans1["text"]}\n\n[End of {role} 1]\n\n' 97 | f'[{role} 2]\n{ans2["text"]}\n\n[End of {role} 2]\n\n' 98 | f'[System]\n{prompt}\n\n') 99 | cur_js = { 100 | 'id': idx+1, 101 | 'question_id': ques['question_id'], 102 | 'answer1_id': ans1.get('answer_id', ans1['question_id']), 103 | 'answer2_id': ans2.get('answer_id', ans2['answer_id']), 104 | 'category': category 105 | } 106 | if idx >= len(cur_reviews): 107 | review = get_eval(content, args.max_tokens) 108 | scores = parse_score(review) 109 | cur_js['content'] = review 110 | cur_js['tuple'] = scores 111 | review_file.write(json.dumps(cur_js) + '\n') 112 | review_file.flush() 113 | else: 114 | print(f'Skipping {idx} as we already have it.') 115 | idx += 1 116 | print(idx) 117 | review_file.close() 118 | -------------------------------------------------------------------------------- /llava/eval/eval_gpt_review_visual.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import json 3 | import os 4 | 5 | import openai 6 | import time 7 | 8 | NUM_SECONDS_TO_SLEEP = 0.5 9 | 10 | 11 | def get_eval(content: str, max_tokens: int): 12 | while True: 13 | try: 14 | response = openai.ChatCompletion.create( 15 | model='gpt-4-0314', 16 | messages=[{ 17 | 'role': 'system', 18 | 'content': 'You are a helpful and precise assistant for checking the quality of the answer.' 19 | }, { 20 | 'role': 'user', 21 | 'content': content, 22 | }], 23 | temperature=0.2, # TODO: figure out which temperature is best for evaluation 24 | max_tokens=max_tokens, 25 | ) 26 | break 27 | except openai.error.RateLimitError: 28 | pass 29 | except Exception as e: 30 | print(e) 31 | time.sleep(NUM_SECONDS_TO_SLEEP) 32 | 33 | return response['choices'][0]['message']['content'] 34 | 35 | 36 | def parse_score(review): 37 | try: 38 | score_pair = review.split('\n')[0] 39 | score_pair = score_pair.replace(',', ' ') 40 | sp = score_pair.split(' ') 41 | if len(sp) == 2: 42 | return [float(sp[0]), float(sp[1])] 43 | else: 44 | print('error', review) 45 | return [-1, -1] 46 | except Exception as e: 47 | print(e) 48 | print('error', review) 49 | return [-1, -1] 50 | 51 | 52 | if __name__ == '__main__': 53 | parser = argparse.ArgumentParser(description='ChatGPT-based QA evaluation.') 54 | parser.add_argument('-q', '--question') 55 | parser.add_argument('-c', '--context') 56 | parser.add_argument('-a', '--answer-list', nargs='+', default=[]) 57 | parser.add_argument('-r', '--rule') 58 | parser.add_argument('-o', '--output') 59 | parser.add_argument('--max-tokens', type=int, default=1024, help='maximum number of tokens produced in the output') 60 | args = parser.parse_args() 61 | 62 | f_q = open(os.path.expanduser(args.question)) 63 | f_ans1 = open(os.path.expanduser(args.answer_list[0])) 64 | f_ans2 = open(os.path.expanduser(args.answer_list[1])) 65 | rule_dict = json.load(open(os.path.expanduser(args.rule), 'r')) 66 | 67 | if os.path.isfile(os.path.expanduser(args.output)): 68 | cur_reviews = [json.loads(line) for line in open(os.path.expanduser(args.output))] 69 | else: 70 | cur_reviews = [] 71 | 72 | review_file = open(f'{args.output}', 'a') 73 | 74 | context_list = [json.loads(line) for line in open(os.path.expanduser(args.context))] 75 | image_to_context = {context['image']: context for context in context_list} 76 | 77 | handles = [] 78 | idx = 0 79 | for ques_js, ans1_js, ans2_js in 
zip(f_q, f_ans1, f_ans2): 80 | ques = json.loads(ques_js) 81 | ans1 = json.loads(ans1_js) 82 | ans2 = json.loads(ans2_js) 83 | 84 | inst = image_to_context[ques['image']] 85 | cap_str = '\n'.join(inst['captions']) 86 | box_str = '\n'.join([f'{instance["category"]}: {instance["bbox"]}' for instance in inst['instances']]) 87 | 88 | category = json.loads(ques_js)['category'] 89 | if category in rule_dict: 90 | rule = rule_dict[category] 91 | else: 92 | assert False, f"Visual QA category not found in rule file: {category}." 93 | prompt = rule['prompt'] 94 | role = rule['role'] 95 | content = (f'[Context]\n{cap_str}\n\n{box_str}\n\n' 96 | f'[Question]\n{ques["text"]}\n\n' 97 | f'[{role} 1]\n{ans1["text"]}\n\n[End of {role} 1]\n\n' 98 | f'[{role} 2]\n{ans2["text"]}\n\n[End of {role} 2]\n\n' 99 | f'[System]\n{prompt}\n\n') 100 | cur_js = { 101 | 'id': idx+1, 102 | 'question_id': ques['question_id'], 103 | 'answer1_id': ans1.get('answer_id', ans1['question_id']), 104 | 'answer2_id': ans2.get('answer_id', ans2['answer_id']), 105 | 'category': category 106 | } 107 | if idx >= len(cur_reviews): 108 | review = get_eval(content, args.max_tokens) 109 | scores = parse_score(review) 110 | cur_js['content'] = review 111 | cur_js['tuple'] = scores 112 | review_file.write(json.dumps(cur_js) + '\n') 113 | review_file.flush() 114 | else: 115 | print(f'Skipping {idx} as we already have it.') 116 | idx += 1 117 | print(idx) 118 | review_file.close() 119 | -------------------------------------------------------------------------------- /llava/eval/eval_science_qa.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import json 3 | import os 4 | import re 5 | import random 6 | 7 | 8 | def get_args(): 9 | parser = argparse.ArgumentParser() 10 | parser.add_argument('--base-dir', type=str) 11 | parser.add_argument('--result-file', type=str) 12 | parser.add_argument('--output-file', type=str) 13 | parser.add_argument('--output-result', type=str) 14 | parser.add_argument('--split', type=str, default='test') 15 | parser.add_argument('--options', type=list, default=["A", "B", "C", "D", "E"]) 16 | return parser.parse_args() 17 | 18 | 19 | def convert_caps(results): 20 | fakecaps = [] 21 | for result in results: 22 | image_id = result['question_id'] 23 | caption = result['text'] 24 | fakecaps.append({"image_id": int(image_id), "caption": caption}) 25 | return fakecaps 26 | 27 | 28 | def get_pred_idx(prediction, choices, options): 29 | """ 30 | Get the index (e.g. 2) from the prediction (e.g. 
'C') 31 | """ 32 | if prediction in options[:len(choices)]: 33 | return options.index(prediction) 34 | else: 35 | return random.choice(range(len(choices))) 36 | 37 | 38 | if __name__ == "__main__": 39 | args = get_args() 40 | 41 | base_dir = args.base_dir 42 | split_indices = json.load(open(os.path.join(base_dir, "pid_splits.json")))[args.split] 43 | problems = json.load(open(os.path.join(base_dir, "problems.json"))) 44 | predictions = [json.loads(line) for line in open(args.result_file)] 45 | predictions = {pred['question_id']: pred for pred in predictions} 46 | split_problems = {idx: problems[idx] for idx in split_indices} 47 | 48 | results = {'correct': [], 'incorrect': []} 49 | sqa_results = {} 50 | sqa_results['acc'] = None 51 | sqa_results['correct'] = None 52 | sqa_results['count'] = None 53 | sqa_results['results'] = {} 54 | sqa_results['outputs'] = {} 55 | 56 | for prob_id, prob in split_problems.items(): 57 | if prob_id not in predictions: 58 | continue 59 | pred = predictions[prob_id] 60 | pred_text = pred['text'] 61 | 62 | pattern = re.compile(r'The answer is ([A-Z]).') 63 | res = pattern.findall(pred_text) 64 | if len(res) == 1: 65 | answer = res[0] # 'A', 'B', ... 66 | else: 67 | answer = "FAILED" 68 | 69 | pred_idx = get_pred_idx(answer, prob['choices'], args.options) 70 | 71 | analysis = { 72 | 'question_id': prob_id, 73 | 'parsed_ans': answer, 74 | 'ground_truth': args.options[prob['answer']], 75 | 'question': pred['prompt'], 76 | 'pred': pred_text, 77 | 'is_multimodal': '' in pred['prompt'], 78 | } 79 | 80 | sqa_results['results'][prob_id] = get_pred_idx(answer, prob['choices'], args.options) 81 | sqa_results['outputs'][prob_id] = pred_text 82 | 83 | if pred_idx == prob['answer']: 84 | results['correct'].append(analysis) 85 | else: 86 | results['incorrect'].append(analysis) 87 | 88 | correct = len(results['correct']) 89 | total = len(results['correct']) + len(results['incorrect']) 90 | print(f'Total: {total}, Correct: {correct}, Accuracy: {correct / total * 100:.2f}%') 91 | 92 | sqa_results['acc'] = correct / total * 100 93 | sqa_results['correct'] = correct 94 | sqa_results['count'] = total 95 | 96 | with open(args.output_file, 'w') as f: 97 | json.dump(results, f, indent=2) 98 | with open(args.output_result, 'w') as f: 99 | json.dump(sqa_results, f, indent=2) 100 | -------------------------------------------------------------------------------- /llava/eval/eval_science_qa_gpt4.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import json 3 | import os 4 | import re 5 | import random 6 | from collections import defaultdict 7 | 8 | 9 | def get_args(): 10 | parser = argparse.ArgumentParser() 11 | parser.add_argument('--base-dir', type=str) 12 | parser.add_argument('--gpt4-result', type=str) 13 | parser.add_argument('--our-result', type=str) 14 | parser.add_argument('--split', type=str, default='test') 15 | parser.add_argument('--options', type=list, default=["A", "B", "C", "D", "E"]) 16 | return parser.parse_args() 17 | 18 | 19 | def convert_caps(results): 20 | fakecaps = [] 21 | for result in results: 22 | image_id = result['question_id'] 23 | caption = result['text'] 24 | fakecaps.append({"image_id": int(image_id), "caption": caption}) 25 | return fakecaps 26 | 27 | 28 | def get_pred_idx(prediction, choices, options): 29 | """ 30 | Get the index (e.g. 2) from the prediction (e.g. 
'C') 31 | """ 32 | if prediction in options[:len(choices)]: 33 | return options.index(prediction) 34 | else: 35 | return random.choice(range(len(choices))) 36 | 37 | 38 | if __name__ == "__main__": 39 | args = get_args() 40 | 41 | base_dir = args.base_dir 42 | split_indices = json.load(open(os.path.join(base_dir, "pid_splits.json")))[args.split] 43 | problems = json.load(open(os.path.join(base_dir, "problems.json"))) 44 | our_predictions = [json.loads(line) for line in open(args.our_result)] 45 | our_predictions = {pred['question_id']: pred for pred in our_predictions} 46 | split_problems = {idx: problems[idx] for idx in split_indices} 47 | 48 | gpt4_predictions = json.load(open(args.gpt4_result))['outputs'] 49 | 50 | results = defaultdict(lambda: 0) 51 | 52 | for prob_id, prob in split_problems.items(): 53 | if prob_id not in our_predictions: 54 | continue 55 | if prob_id not in gpt4_predictions: 56 | continue 57 | our_pred = our_predictions[prob_id]['text'] 58 | gpt4_pred = gpt4_predictions[prob_id] 59 | 60 | pattern = re.compile(r'The answer is ([A-Z]).') 61 | our_res = pattern.findall(our_pred) 62 | if len(our_res) == 1: 63 | our_answer = our_res[0] # 'A', 'B', ... 64 | else: 65 | our_answer = "FAILED" 66 | gpt4_res = pattern.findall(gpt4_pred) 67 | if len(gpt4_res) == 1: 68 | gpt4_answer = gpt4_res[0] # 'A', 'B', ... 69 | else: 70 | gpt4_answer = "FAILED" 71 | 72 | our_pred_idx = get_pred_idx(our_answer, prob['choices'], args.options) 73 | gpt4_pred_idx = get_pred_idx(gpt4_answer, prob['choices'], args.options) 74 | 75 | if gpt4_answer == 'FAILED': 76 | results['gpt4_failed'] += 1 77 | # continue 78 | gpt4_pred_idx = our_pred_idx 79 | # if our_pred_idx != prob['answer']: 80 | # print(our_predictions[prob_id]['prompt']) 81 | # print('-----------------') 82 | # print(f'LECTURE: {prob["lecture"]}') 83 | # print(f'SOLUTION: {prob["solution"]}') 84 | # print('=====================') 85 | else: 86 | # continue 87 | pass 88 | # gpt4_pred_idx = our_pred_idx 89 | 90 | if gpt4_pred_idx == prob['answer']: 91 | results['correct'] += 1 92 | else: 93 | results['incorrect'] += 1 94 | 95 | 96 | if gpt4_pred_idx == prob['answer'] or our_pred_idx == prob['answer']: 97 | results['correct_upperbound'] += 1 98 | 99 | correct = results['correct'] 100 | total = results['correct'] + results['incorrect'] 101 | print(f'Total: {total}, Correct: {correct}, Accuracy: {correct / total * 100:.2f}%') 102 | print(f'Total: {total}, Correct (upper): {results["correct_upperbound"]}, Accuracy: {results["correct_upperbound"] / total * 100:.2f}%') 103 | print(f'Total: {total}, GPT-4 NO-ANS (RANDOM): {results["gpt4_failed"]}, Percentage: {results["gpt4_failed"] / total * 100:.2f}%') 104 | 105 | -------------------------------------------------------------------------------- /llava/eval/eval_science_qa_gpt4_requery.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import json 3 | import os 4 | import re 5 | import random 6 | from collections import defaultdict 7 | 8 | 9 | def get_args(): 10 | parser = argparse.ArgumentParser() 11 | parser.add_argument('--base-dir', type=str) 12 | parser.add_argument('--gpt4-result', type=str) 13 | parser.add_argument('--requery-result', type=str) 14 | parser.add_argument('--our-result', type=str) 15 | parser.add_argument('--output-result', type=str) 16 | parser.add_argument('--split', type=str, default='test') 17 | parser.add_argument('--options', type=list, default=["A", "B", "C", "D", "E"]) 18 | return parser.parse_args() 19 | 20 | 
21 | def convert_caps(results): 22 | fakecaps = [] 23 | for result in results: 24 | image_id = result['question_id'] 25 | caption = result['text'] 26 | fakecaps.append({"image_id": int(image_id), "caption": caption}) 27 | return fakecaps 28 | 29 | 30 | def get_pred_idx(prediction, choices, options): 31 | """ 32 | Get the index (e.g. 2) from the prediction (e.g. 'C') 33 | """ 34 | if prediction in options[:len(choices)]: 35 | return options.index(prediction) 36 | else: 37 | return random.choice(range(len(choices))) 38 | 39 | 40 | if __name__ == "__main__": 41 | args = get_args() 42 | 43 | base_dir = args.base_dir 44 | split_indices = json.load(open(os.path.join(base_dir, "pid_splits.json")))[args.split] 45 | problems = json.load(open(os.path.join(base_dir, "problems.json"))) 46 | our_predictions = [json.loads(line) for line in open(args.our_result)] 47 | our_predictions = {pred['question_id']: pred for pred in our_predictions} 48 | split_problems = {idx: problems[idx] for idx in split_indices} 49 | 50 | requery_predictions = [json.loads(line) for line in open(args.requery_result)] 51 | requery_predictions = {pred['question_id']: pred for pred in requery_predictions} 52 | 53 | gpt4_predictions = json.load(open(args.gpt4_result))['outputs'] 54 | 55 | results = defaultdict(lambda: 0) 56 | 57 | sqa_results = {} 58 | sqa_results['acc'] = None 59 | sqa_results['correct'] = None 60 | sqa_results['count'] = None 61 | sqa_results['results'] = {} 62 | sqa_results['outputs'] = {} 63 | 64 | for prob_id, prob in split_problems.items(): 65 | if prob_id not in our_predictions: 66 | assert False 67 | if prob_id not in gpt4_predictions: 68 | assert False 69 | our_pred = our_predictions[prob_id]['text'] 70 | gpt4_pred = gpt4_predictions[prob_id] 71 | if prob_id not in requery_predictions: 72 | results['missing_requery'] += 1 73 | requery_pred = "MISSING" 74 | else: 75 | requery_pred = requery_predictions[prob_id]['text'] 76 | 77 | pattern = re.compile(r'The answer is ([A-Z]).') 78 | our_res = pattern.findall(our_pred) 79 | if len(our_res) == 1: 80 | our_answer = our_res[0] # 'A', 'B', ... 81 | else: 82 | our_answer = "FAILED" 83 | 84 | requery_res = pattern.findall(requery_pred) 85 | if len(requery_res) == 1: 86 | requery_answer = requery_res[0] # 'A', 'B', ... 87 | else: 88 | requery_answer = "FAILED" 89 | 90 | gpt4_res = pattern.findall(gpt4_pred) 91 | if len(gpt4_res) == 1: 92 | gpt4_answer = gpt4_res[0] # 'A', 'B', ... 
93 | else: 94 | gpt4_answer = "FAILED" 95 | 96 | our_pred_idx = get_pred_idx(our_answer, prob['choices'], args.options) 97 | gpt4_pred_idx = get_pred_idx(gpt4_answer, prob['choices'], args.options) 98 | requery_pred_idx = get_pred_idx(requery_answer, prob['choices'], args.options) 99 | 100 | results['total'] += 1 101 | 102 | if gpt4_answer == 'FAILED': 103 | results['gpt4_failed'] += 1 104 | if gpt4_pred_idx == prob['answer']: 105 | results['gpt4_correct'] += 1 106 | if our_pred_idx == prob['answer']: 107 | results['gpt4_ourvisual_correct'] += 1 108 | elif gpt4_pred_idx == prob['answer']: 109 | results['gpt4_correct'] += 1 110 | results['gpt4_ourvisual_correct'] += 1 111 | 112 | if our_pred_idx == prob['answer']: 113 | results['our_correct'] += 1 114 | 115 | if requery_answer == 'FAILED': 116 | sqa_results['results'][prob_id] = our_pred_idx 117 | if our_pred_idx == prob['answer']: 118 | results['requery_correct'] += 1 119 | else: 120 | sqa_results['results'][prob_id] = requery_pred_idx 121 | if requery_pred_idx == prob['answer']: 122 | results['requery_correct'] += 1 123 | else: 124 | print(f""" 125 | Question ({args.options[prob['answer']]}): {our_predictions[prob_id]['prompt']} 126 | Our ({our_answer}): {our_pred} 127 | GPT-4 ({gpt4_answer}): {gpt4_pred} 128 | Requery ({requery_answer}): {requery_pred} 129 | print("=====================================") 130 | """) 131 | 132 | if gpt4_pred_idx == prob['answer'] or our_pred_idx == prob['answer']: 133 | results['correct_upperbound'] += 1 134 | 135 | total = results['total'] 136 | print(f'Total: {total}, Our-Correct: {results["our_correct"]}, Accuracy: {results["our_correct"] / total * 100:.2f}%') 137 | print(f'Total: {total}, GPT-4-Correct: {results["gpt4_correct"]}, Accuracy: {results["gpt4_correct"] / total * 100:.2f}%') 138 | print(f'Total: {total}, GPT-4 NO-ANS (RANDOM): {results["gpt4_failed"]}, Percentage: {results["gpt4_failed"] / total * 100:.2f}%') 139 | print(f'Total: {total}, GPT-4-OursVisual-Correct: {results["gpt4_ourvisual_correct"]}, Accuracy: {results["gpt4_ourvisual_correct"] / total * 100:.2f}%') 140 | print(f'Total: {total}, Requery-Correct: {results["requery_correct"]}, Accuracy: {results["requery_correct"] / total * 100:.2f}%') 141 | print(f'Total: {total}, Correct upper: {results["correct_upperbound"]}, Accuracy: {results["correct_upperbound"] / total * 100:.2f}%') 142 | 143 | sqa_results['acc'] = results["requery_correct"] / total * 100 144 | sqa_results['correct'] = results["requery_correct"] 145 | sqa_results['count'] = total 146 | 147 | with open(args.output_result, 'w') as f: 148 | json.dump(sqa_results, f, indent=2) 149 | 150 | -------------------------------------------------------------------------------- /llava/eval/generate_webpage_data_from_table.py: -------------------------------------------------------------------------------- 1 | """Generate json file for webpage.""" 2 | import json 3 | import os 4 | import re 5 | 6 | # models = ['llama', 'alpaca', 'gpt35', 'bard'] 7 | models = ['vicuna'] 8 | 9 | 10 | def read_jsonl(path: str, key: str=None): 11 | data = [] 12 | with open(os.path.expanduser(path)) as f: 13 | for line in f: 14 | if not line: 15 | continue 16 | data.append(json.loads(line)) 17 | if key is not None: 18 | data.sort(key=lambda x: x[key]) 19 | data = {item[key]: item for item in data} 20 | return data 21 | 22 | 23 | def trim_hanging_lines(s: str, n: int) -> str: 24 | s = s.strip() 25 | for _ in range(n): 26 | s = s.split('\n', 1)[1].strip() 27 | return s 28 | 29 | 30 | if __name__ == 
'__main__': 31 | questions = read_jsonl('table/question.jsonl', key='question_id') 32 | 33 | # alpaca_answers = read_jsonl('table/answer/answer_alpaca-13b.jsonl', key='question_id') 34 | # bard_answers = read_jsonl('table/answer/answer_bard.jsonl', key='question_id') 35 | # gpt35_answers = read_jsonl('table/answer/answer_gpt35.jsonl', key='question_id') 36 | # llama_answers = read_jsonl('table/answer/answer_llama-13b.jsonl', key='question_id') 37 | vicuna_answers = read_jsonl('table/answer/answer_vicuna-13b.jsonl', key='question_id') 38 | ours_answers = read_jsonl('table/results/llama-13b-hf-alpaca.jsonl', key='question_id') 39 | 40 | review_vicuna = read_jsonl('table/review/review_vicuna-13b_llama-13b-hf-alpaca.jsonl', key='question_id') 41 | # review_alpaca = read_jsonl('table/review/review_alpaca-13b_vicuna-13b.jsonl', key='question_id') 42 | # review_bard = read_jsonl('table/review/review_bard_vicuna-13b.jsonl', key='question_id') 43 | # review_gpt35 = read_jsonl('table/review/review_gpt35_vicuna-13b.jsonl', key='question_id') 44 | # review_llama = read_jsonl('table/review/review_llama-13b_vicuna-13b.jsonl', key='question_id') 45 | 46 | records = [] 47 | for qid in questions.keys(): 48 | r = { 49 | 'id': qid, 50 | 'category': questions[qid]['category'], 51 | 'question': questions[qid]['text'], 52 | 'answers': { 53 | # 'alpaca': alpaca_answers[qid]['text'], 54 | # 'llama': llama_answers[qid]['text'], 55 | # 'bard': bard_answers[qid]['text'], 56 | # 'gpt35': gpt35_answers[qid]['text'], 57 | 'vicuna': vicuna_answers[qid]['text'], 58 | 'ours': ours_answers[qid]['text'], 59 | }, 60 | 'evaluations': { 61 | # 'alpaca': review_alpaca[qid]['text'], 62 | # 'llama': review_llama[qid]['text'], 63 | # 'bard': review_bard[qid]['text'], 64 | 'vicuna': review_vicuna[qid]['content'], 65 | # 'gpt35': review_gpt35[qid]['text'], 66 | }, 67 | 'scores': { 68 | 'vicuna': review_vicuna[qid]['tuple'], 69 | # 'alpaca': review_alpaca[qid]['score'], 70 | # 'llama': review_llama[qid]['score'], 71 | # 'bard': review_bard[qid]['score'], 72 | # 'gpt35': review_gpt35[qid]['score'], 73 | }, 74 | } 75 | 76 | # cleanup data 77 | cleaned_evals = {} 78 | for k, v in r['evaluations'].items(): 79 | v = v.strip() 80 | lines = v.split('\n') 81 | # trim the first line if it's a pair of numbers 82 | if re.match(r'\d+[, ]+\d+', lines[0]): 83 | lines = lines[1:] 84 | v = '\n'.join(lines) 85 | cleaned_evals[k] = v.replace('Assistant 1', "**Assistant 1**").replace('Assistant 2', '**Assistant 2**') 86 | 87 | r['evaluations'] = cleaned_evals 88 | records.append(r) 89 | 90 | # Reorder the records, this is optional 91 | for r in records: 92 | if r['id'] <= 20: 93 | r['id'] += 60 94 | else: 95 | r['id'] -= 20 96 | for r in records: 97 | if r['id'] <= 50: 98 | r['id'] += 10 99 | elif 50 < r['id'] <= 60: 100 | r['id'] -= 50 101 | for r in records: 102 | if r['id'] == 7: 103 | r['id'] = 1 104 | elif r['id'] < 7: 105 | r['id'] += 1 106 | 107 | records.sort(key=lambda x: x['id']) 108 | 109 | # Write to file 110 | with open('webpage/data.json', 'w') as f: 111 | json.dump({'questions': records, 'models': models}, f, indent=2) 112 | -------------------------------------------------------------------------------- /llava/eval/model_qa.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | from transformers import AutoTokenizer, AutoModelForCausalLM, StoppingCriteria 3 | import torch 4 | import os 5 | import json 6 | from tqdm import tqdm 7 | import shortuuid 8 | 9 | from llava.conversation 
import default_conversation 10 | from llava.utils import disable_torch_init 11 | 12 | 13 | # new stopping implementation 14 | class KeywordsStoppingCriteria(StoppingCriteria): 15 | def __init__(self, keywords, tokenizer, input_ids): 16 | self.keywords = keywords 17 | self.tokenizer = tokenizer 18 | self.start_len = None 19 | self.input_ids = input_ids 20 | 21 | def __call__(self, output_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool: 22 | if self.start_len is None: 23 | self.start_len = self.input_ids.shape[1] 24 | else: 25 | outputs = self.tokenizer.batch_decode(output_ids[:, self.start_len:], skip_special_tokens=True)[0] 26 | for keyword in self.keywords: 27 | if keyword in outputs: 28 | return True 29 | return False 30 | 31 | 32 | @torch.inference_mode() 33 | def eval_model(model_name, questions_file, answers_file): 34 | # Model 35 | disable_torch_init() 36 | model_name = os.path.expanduser(model_name) 37 | tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=False) 38 | model = AutoModelForCausalLM.from_pretrained(model_name, 39 | torch_dtype=torch.float16).cuda() 40 | 41 | 42 | ques_file = open(os.path.expanduser(questions_file), "r") 43 | ans_file = open(os.path.expanduser(answers_file), "w") 44 | for i, line in enumerate(tqdm(ques_file)): 45 | idx = json.loads(line)["question_id"] 46 | qs = json.loads(line)["text"] 47 | cat = json.loads(line)["category"] 48 | conv = default_conversation.copy() 49 | conv.append_message(conv.roles[0], qs) 50 | prompt = conv.get_prompt() 51 | inputs = tokenizer([prompt]) 52 | input_ids = torch.as_tensor(inputs.input_ids).cuda() 53 | stopping_criteria = KeywordsStoppingCriteria([conv.sep], tokenizer, input_ids) 54 | output_ids = model.generate( 55 | input_ids, 56 | do_sample=True, 57 | use_cache=True, 58 | temperature=0.7, 59 | max_new_tokens=1024, 60 | stopping_criteria=[stopping_criteria]) 61 | outputs = tokenizer.batch_decode(output_ids, skip_special_tokens=True)[0] 62 | try: 63 | index = outputs.index(conv.sep, len(prompt)) 64 | except ValueError: 65 | outputs += conv.sep 66 | index = outputs.index(conv.sep, len(prompt)) 67 | 68 | outputs = outputs[len(prompt) + len(conv.roles[1]) + 2:index].strip() 69 | ans_id = shortuuid.uuid() 70 | ans_file.write(json.dumps({"question_id": idx, 71 | "text": outputs, 72 | "answer_id": ans_id, 73 | "model_id": model_name, 74 | "metadata": {}}) + "\n") 75 | ans_file.flush() 76 | ans_file.close() 77 | 78 | if __name__ == "__main__": 79 | parser = argparse.ArgumentParser() 80 | parser.add_argument("--model-name", type=str, default="facebook/opt-350m") 81 | parser.add_argument("--question-file", type=str, default="tables/question.jsonl") 82 | parser.add_argument("--answers-file", type=str, default="answer.jsonl") 83 | args = parser.parse_args() 84 | 85 | eval_model(args.model_name, args.question_file, args.answers_file) 86 | -------------------------------------------------------------------------------- /llava/eval/model_vqa.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import torch 3 | import os 4 | import json 5 | from tqdm import tqdm 6 | import shortuuid 7 | 8 | from llava.constants import IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN 9 | from llava.conversation import conv_templates, SeparatorStyle 10 | from llava.model.builder import load_pretrained_model 11 | from llava.utils import disable_torch_init 12 | from llava.mm_utils import tokenizer_image_token, get_model_name_from_path, 
KeywordsStoppingCriteria 13 | 14 | from PIL import Image 15 | import math 16 | 17 | 18 | def split_list(lst, n): 19 | """Split a list into n (roughly) equal-sized chunks""" 20 | chunk_size = math.ceil(len(lst) / n) # integer division 21 | return [lst[i:i+chunk_size] for i in range(0, len(lst), chunk_size)] 22 | 23 | 24 | def get_chunk(lst, n, k): 25 | chunks = split_list(lst, n) 26 | return chunks[k] 27 | 28 | 29 | def eval_model(args): 30 | # Model 31 | disable_torch_init() 32 | model_path = os.path.expanduser(args.model_path) 33 | model_name = get_model_name_from_path(model_path) 34 | tokenizer, model, image_processor, context_len = load_pretrained_model(model_path, args.model_base, model_name) 35 | 36 | questions = [json.loads(q) for q in open(os.path.expanduser(args.question_file), "r")] 37 | questions = get_chunk(questions, args.num_chunks, args.chunk_idx) 38 | answers_file = os.path.expanduser(args.answers_file) 39 | os.makedirs(os.path.dirname(answers_file), exist_ok=True) 40 | ans_file = open(answers_file, "w") 41 | for line in tqdm(questions): 42 | idx = line["question_id"] 43 | image_file = line["image"] 44 | qs = line["text"] 45 | cur_prompt = qs 46 | if model.config.mm_use_im_start_end: 47 | qs = DEFAULT_IM_START_TOKEN + DEFAULT_IMAGE_TOKEN + DEFAULT_IM_END_TOKEN + '\n' + qs 48 | else: 49 | qs = DEFAULT_IMAGE_TOKEN + '\n' + qs 50 | 51 | conv = conv_templates[args.conv_mode].copy() 52 | conv.append_message(conv.roles[0], qs) 53 | conv.append_message(conv.roles[1], None) 54 | prompt = conv.get_prompt() 55 | 56 | input_ids = tokenizer_image_token(prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors='pt').unsqueeze(0).cuda() 57 | 58 | image = Image.open(os.path.join(args.image_folder, image_file)) 59 | image_tensor = image_processor.preprocess(image, return_tensors='pt')['pixel_values'][0] 60 | 61 | stop_str = conv.sep if conv.sep_style != SeparatorStyle.TWO else conv.sep2 62 | keywords = [stop_str] 63 | stopping_criteria = KeywordsStoppingCriteria(keywords, tokenizer, input_ids) 64 | 65 | with torch.inference_mode(): 66 | output_ids = model.generate( 67 | input_ids, 68 | images=image_tensor.unsqueeze(0).half().cuda(), 69 | do_sample=True, 70 | temperature=args.temperature, 71 | top_p=args.top_p, 72 | num_beams=args.num_beams, 73 | # no_repeat_ngram_size=3, 74 | max_new_tokens=1024, 75 | use_cache=True) 76 | 77 | input_token_len = input_ids.shape[1] 78 | n_diff_input_output = (input_ids != output_ids[:, :input_token_len]).sum().item() 79 | if n_diff_input_output > 0: 80 | print(f'[Warning] {n_diff_input_output} output_ids are not the same as the input_ids') 81 | outputs = tokenizer.batch_decode(output_ids[:, input_token_len:], skip_special_tokens=True)[0] 82 | outputs = outputs.strip() 83 | if outputs.endswith(stop_str): 84 | outputs = outputs[:-len(stop_str)] 85 | outputs = outputs.strip() 86 | 87 | ans_id = shortuuid.uuid() 88 | ans_file.write(json.dumps({"question_id": idx, 89 | "prompt": cur_prompt, 90 | "text": outputs, 91 | "answer_id": ans_id, 92 | "model_id": model_name, 93 | "metadata": {}}) + "\n") 94 | ans_file.flush() 95 | ans_file.close() 96 | 97 | if __name__ == "__main__": 98 | parser = argparse.ArgumentParser() 99 | parser.add_argument("--model-path", type=str, default="facebook/opt-350m") 100 | parser.add_argument("--model-base", type=str, default=None) 101 | parser.add_argument("--image-folder", type=str, default="") 102 | parser.add_argument("--question-file", type=str, default="tables/question.jsonl") 103 | parser.add_argument("--answers-file", type=str, 
default="answer.jsonl") 104 | parser.add_argument("--conv-mode", type=str, default="llava_v1") 105 | parser.add_argument("--num-chunks", type=int, default=1) 106 | parser.add_argument("--chunk-idx", type=int, default=0) 107 | parser.add_argument("--temperature", type=float, default=0.2) 108 | parser.add_argument("--top_p", type=float, default=None) 109 | parser.add_argument("--num_beams", type=int, default=1) 110 | args = parser.parse_args() 111 | 112 | eval_model(args) 113 | -------------------------------------------------------------------------------- /llava/eval/model_vqa_science.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import torch 3 | import os 4 | import json 5 | from tqdm import tqdm 6 | import shortuuid 7 | 8 | from llava.constants import IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN 9 | from llava.conversation import conv_templates, SeparatorStyle 10 | from llava.model.builder import load_pretrained_model 11 | from llava.utils import disable_torch_init 12 | from llava.mm_utils import tokenizer_image_token, get_model_name_from_path, KeywordsStoppingCriteria 13 | 14 | from PIL import Image 15 | import math 16 | 17 | 18 | def split_list(lst, n): 19 | """Split a list into n (roughly) equal-sized chunks""" 20 | chunk_size = math.ceil(len(lst) / n) # integer division 21 | return [lst[i:i+chunk_size] for i in range(0, len(lst), chunk_size)] 22 | 23 | 24 | def get_chunk(lst, n, k): 25 | chunks = split_list(lst, n) 26 | return chunks[k] 27 | 28 | 29 | def eval_model(args): 30 | # Model 31 | disable_torch_init() 32 | model_path = os.path.expanduser(args.model_path) 33 | model_name = get_model_name_from_path(model_path) 34 | tokenizer, model, image_processor, context_len = load_pretrained_model(model_path, args.model_base, model_name) 35 | 36 | questions = json.load(open(os.path.expanduser(args.question_file), "r")) 37 | questions = get_chunk(questions, args.num_chunks, args.chunk_idx) 38 | answers_file = os.path.expanduser(args.answers_file) 39 | os.makedirs(os.path.dirname(answers_file), exist_ok=True) 40 | ans_file = open(answers_file, "w") 41 | for i, line in enumerate(tqdm(questions)): 42 | idx = line["id"] 43 | question = line['conversations'][0] 44 | gt_ans = line["conversations"][1] 45 | qs = question['value'].replace('', '').strip() 46 | cur_prompt = qs 47 | 48 | if 'image' in line: 49 | image_file = line["image"] 50 | image = Image.open(os.path.join(args.image_folder, image_file)) 51 | image_tensor = image_processor.preprocess(image, return_tensors='pt')['pixel_values'][0] 52 | images = image_tensor.unsqueeze(0).half().cuda() 53 | if getattr(model.config, 'mm_use_im_start_end', False): 54 | qs = DEFAULT_IM_START_TOKEN + DEFAULT_IMAGE_TOKEN + DEFAULT_IM_END_TOKEN + '\n' + qs 55 | else: 56 | qs = DEFAULT_IMAGE_TOKEN + '\n' + qs 57 | cur_prompt = '' + '\n' + cur_prompt 58 | else: 59 | images = None 60 | 61 | conv = conv_templates[args.conv_mode].copy() 62 | conv.append_message(conv.roles[0], qs) 63 | conv.append_message(conv.roles[1], None) 64 | prompt = conv.get_prompt() 65 | 66 | input_ids = tokenizer_image_token(prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors='pt').unsqueeze(0).cuda() 67 | 68 | stop_str = conv.sep if conv.sep_style != SeparatorStyle.TWO else conv.sep2 69 | keywords = [stop_str] 70 | stopping_criteria = KeywordsStoppingCriteria(keywords, tokenizer, input_ids) 71 | 72 | with torch.inference_mode(): 73 | output_ids = model.generate( 74 | input_ids, 75 | 
images=images, 76 | do_sample=True, 77 | temperature=0.2, 78 | max_new_tokens=1024, 79 | use_cache=True, 80 | stopping_criteria=[stopping_criteria]) 81 | 82 | input_token_len = input_ids.shape[1] 83 | n_diff_input_output = (input_ids != output_ids[:, :input_token_len]).sum().item() 84 | if n_diff_input_output > 0: 85 | print(f'[Warning] {n_diff_input_output} output_ids are not the same as the input_ids') 86 | outputs = tokenizer.batch_decode(output_ids[:, input_token_len:], skip_special_tokens=True)[0] 87 | outputs = outputs.strip() 88 | if outputs.endswith(stop_str): 89 | outputs = outputs[:-len(stop_str)] 90 | outputs = outputs.strip() 91 | 92 | # prompt for answer 93 | if args.answer_prompter: 94 | outputs_reasoning = outputs 95 | input_ids = tokenizer_image_token(prompt + outputs_reasoning + ' ###\nANSWER:', tokenizer, IMAGE_TOKEN_INDEX, return_tensors='pt').unsqueeze(0).cuda() 96 | 97 | with torch.inference_mode(): 98 | output_ids = model.generate( 99 | input_ids, 100 | images=images, 101 | do_sample=True, 102 | temperature=0.2, 103 | max_new_tokens=64, 104 | use_cache=True, 105 | stopping_criteria=[stopping_criteria]) 106 | 107 | input_token_len = input_ids.shape[1] 108 | n_diff_input_output = (input_ids != output_ids[:, :input_token_len]).sum().item() 109 | if n_diff_input_output > 0: 110 | print(f'[Warning] {n_diff_input_output} output_ids are not the same as the input_ids') 111 | outputs = tokenizer.batch_decode(output_ids[:, input_token_len:], skip_special_tokens=True)[0] 112 | outputs = outputs.strip() 113 | if outputs.endswith(stop_str): 114 | outputs = outputs[:-len(stop_str)] 115 | outputs = outputs.strip() 116 | outputs = outputs_reasoning + '\n The answer is ' + outputs 117 | 118 | ans_id = shortuuid.uuid() 119 | ans_file.write(json.dumps({"question_id": idx, 120 | "prompt": cur_prompt, 121 | "text": outputs, 122 | "answer_id": ans_id, 123 | "model_id": model_name, 124 | "metadata": {}}) + "\n") 125 | ans_file.flush() 126 | ans_file.close() 127 | 128 | if __name__ == "__main__": 129 | parser = argparse.ArgumentParser() 130 | parser.add_argument("--model-path", type=str, default="facebook/opt-350m") 131 | parser.add_argument("--model-base", type=str, default=None) 132 | parser.add_argument("--image-folder", type=str, default="") 133 | parser.add_argument("--question-file", type=str, default="tables/question.json") 134 | parser.add_argument("--answers-file", type=str, default="answer.jsonl") 135 | parser.add_argument("--conv-mode", type=str, default="llava_v0") 136 | parser.add_argument("--num-chunks", type=int, default=1) 137 | parser.add_argument("--chunk-idx", type=int, default=0) 138 | parser.add_argument("--answer-prompter", action="store_true") 139 | args = parser.parse_args() 140 | 141 | eval_model(args) 142 | -------------------------------------------------------------------------------- /llava/eval/qa_baseline_gpt35.py: -------------------------------------------------------------------------------- 1 | """Generate answers with GPT-3.5""" 2 | # Note: you need to be using OpenAI Python v0.27.0 for the code below to work 3 | import argparse 4 | import json 5 | import os 6 | import time 7 | import concurrent.futures 8 | 9 | import openai 10 | import tqdm 11 | import shortuuid 12 | 13 | MODEL = 'gpt-3.5-turbo' 14 | MODEL_ID = 'gpt-3.5-turbo:20230327' 15 | 16 | def get_answer(question_id: int, question: str, max_tokens: int): 17 | ans = { 18 | 'answer_id': shortuuid.uuid(), 19 | 'question_id': question_id, 20 | 'model_id': MODEL_ID, 21 | } 22 | for _ in range(3): 23 | 
try: 24 | response = openai.ChatCompletion.create( 25 | model=MODEL, 26 | messages=[{ 27 | 'role': 'system', 28 | 'content': 'You are a helpful assistant.' 29 | }, { 30 | 'role': 'user', 31 | 'content': question, 32 | }], 33 | max_tokens=max_tokens, 34 | ) 35 | ans['text'] = response['choices'][0]['message']['content'] 36 | return ans 37 | except Exception as e: 38 | print('[ERROR]', e) 39 | ans['text'] = '#ERROR#' 40 | time.sleep(1) 41 | return ans 42 | 43 | 44 | if __name__ == '__main__': 45 | parser = argparse.ArgumentParser(description='ChatGPT answer generation.') 46 | parser.add_argument('-q', '--question') 47 | parser.add_argument('-o', '--output') 48 | parser.add_argument('--max-tokens', type=int, default=1024, help='maximum number of tokens produced in the output') 49 | args = parser.parse_args() 50 | 51 | questions_dict = {} 52 | with open(os.path.expanduser(args.question)) as f: 53 | for line in f: 54 | if not line: 55 | continue 56 | q = json.loads(line) 57 | questions_dict[q['question_id']] = q['text'] 58 | 59 | answers = [] 60 | 61 | with concurrent.futures.ThreadPoolExecutor(max_workers=32) as executor: 62 | futures = [] 63 | for qid, question in questions_dict.items(): 64 | future = executor.submit(get_answer, qid, question, args.max_tokens) 65 | futures.append(future) 66 | 67 | for future in tqdm.tqdm(concurrent.futures.as_completed(futures), total=len(futures)): 68 | answers.append(future.result()) 69 | 70 | answers.sort(key=lambda x: x['question_id']) 71 | 72 | with open(os.path.expanduser(args.output), 'w') as f: 73 | table = [json.dumps(ans) for ans in answers] 74 | f.write('\n'.join(table)) 75 | -------------------------------------------------------------------------------- /llava/eval/run_llava.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import torch 3 | 4 | from llava.constants import IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN 5 | from llava.conversation import conv_templates, SeparatorStyle 6 | from llava.model.builder import load_pretrained_model 7 | from llava.utils import disable_torch_init 8 | from llava.mm_utils import tokenizer_image_token, get_model_name_from_path, KeywordsStoppingCriteria 9 | 10 | from PIL import Image 11 | 12 | import requests 13 | from PIL import Image 14 | from io import BytesIO 15 | 16 | 17 | def load_image(image_file): 18 | if image_file.startswith('http') or image_file.startswith('https'): 19 | response = requests.get(image_file) 20 | image = Image.open(BytesIO(response.content)).convert('RGB') 21 | else: 22 | image = Image.open(image_file).convert('RGB') 23 | return image 24 | 25 | 26 | def eval_model(args): 27 | # Model 28 | disable_torch_init() 29 | 30 | model_name = get_model_name_from_path(args.model_path) 31 | tokenizer, model, image_processor, context_len = load_pretrained_model(args.model_path, args.model_base, model_name) 32 | 33 | qs = args.query 34 | if model.config.mm_use_im_start_end: 35 | qs = DEFAULT_IM_START_TOKEN + DEFAULT_IMAGE_TOKEN + DEFAULT_IM_END_TOKEN + '\n' + qs 36 | else: 37 | qs = DEFAULT_IMAGE_TOKEN + '\n' + qs 38 | 39 | if "v1" in model_name.lower(): 40 | conv_mode = "llava_v1" 41 | elif "mpt" in model_name.lower(): 42 | conv_mode = "mpt" 43 | else: 44 | conv_mode = "llava_v0" 45 | 46 | if args.conv_mode is not None and conv_mode != args.conv_mode: 47 | print('[WARNING] the auto inferred conversation mode is {}, while `--conv-mode` is {}, using {}'.format(conv_mode, args.conv_mode, args.conv_mode)) 48 
| else: 49 | args.conv_mode = conv_mode 50 | 51 | conv = conv_templates[args.conv_mode].copy() 52 | conv.append_message(conv.roles[0], qs) 53 | conv.append_message(conv.roles[1], None) 54 | prompt = conv.get_prompt() 55 | 56 | image = load_image(args.image_file) 57 | image_tensor = image_processor.preprocess(image, return_tensors='pt')['pixel_values'].half().cuda() 58 | 59 | input_ids = tokenizer_image_token(prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors='pt').unsqueeze(0).cuda() 60 | 61 | stop_str = conv.sep if conv.sep_style != SeparatorStyle.TWO else conv.sep2 62 | keywords = [stop_str] 63 | stopping_criteria = KeywordsStoppingCriteria(keywords, tokenizer, input_ids) 64 | 65 | with torch.inference_mode(): 66 | output_ids = model.generate( 67 | input_ids, 68 | images=image_tensor, 69 | do_sample=True, 70 | temperature=0.2, 71 | max_new_tokens=1024, 72 | use_cache=True, 73 | stopping_criteria=[stopping_criteria]) 74 | 75 | input_token_len = input_ids.shape[1] 76 | n_diff_input_output = (input_ids != output_ids[:, :input_token_len]).sum().item() 77 | if n_diff_input_output > 0: 78 | print(f'[Warning] {n_diff_input_output} output_ids are not the same as the input_ids') 79 | outputs = tokenizer.batch_decode(output_ids[:, input_token_len:], skip_special_tokens=True)[0] 80 | outputs = outputs.strip() 81 | if outputs.endswith(stop_str): 82 | outputs = outputs[:-len(stop_str)] 83 | outputs = outputs.strip() 84 | print(outputs) 85 | 86 | if __name__ == "__main__": 87 | parser = argparse.ArgumentParser() 88 | parser.add_argument("--model-path", type=str, default="facebook/opt-350m") 89 | parser.add_argument("--model-base", type=str, default=None) 90 | parser.add_argument("--image-file", type=str, required=True) 91 | parser.add_argument("--query", type=str, required=True) 92 | parser.add_argument("--conv-mode", type=str, default=None) 93 | args = parser.parse_args() 94 | 95 | eval_model(args) 96 | -------------------------------------------------------------------------------- /llava/eval/summarize_gpt_review.py: -------------------------------------------------------------------------------- 1 | import json 2 | import os 3 | from collections import defaultdict 4 | 5 | import numpy as np 6 | 7 | import argparse 8 | 9 | def parse_args(): 10 | parser = argparse.ArgumentParser(description='ChatGPT-based QA evaluation.') 11 | parser.add_argument('-d', '--dir', default=None) 12 | parser.add_argument('-f', '--files', nargs='*', default=None) 13 | parser.add_argument('-i', '--ignore', nargs='*', default=None) 14 | return parser.parse_args() 15 | 16 | 17 | if __name__ == '__main__': 18 | args = parse_args() 19 | 20 | if args.ignore is not None: 21 | args.ignore = [int(x) for x in args.ignore] 22 | 23 | if args.files is not None and len(args.files) > 0: 24 | review_files = args.files 25 | else: 26 | review_files = [x for x in os.listdir(args.dir) if x.endswith('.jsonl') and (x.startswith('gpt4_text') or x.startswith('reviews_') or x.startswith('review_'))] 27 | 28 | for review_file in sorted(review_files): 29 | config = os.path.basename(review_file).replace('gpt4_text_', '').replace('.jsonl', '') 30 | scores = defaultdict(list) 31 | print(config) 32 | with open(os.path.join(args.dir, review_file) if args.dir is not None else review_file) as f: 33 | for review_str in f: 34 | review = json.loads(review_str) 35 | if args.ignore is not None and review['question_id'] in args.ignore: 36 | continue 37 | if 'category' in review: 38 | scores[review['category']].append(review['tuple']) 39 | 
scores['all'].append(review['tuple']) 40 | else: 41 | if 'tuple' in review: 42 | scores['all'].append(review['tuple']) 43 | else: 44 | scores['all'].append(review['score']) 45 | for k, v in sorted(scores.items()): 46 | stats = np.asarray(v).mean(0).tolist() 47 | stats = [round(x, 3) for x in stats] 48 | # print(k, stats, round(stats[1]/stats[0]*100, 1)) 49 | print(k, round(stats[1]/stats[0]*100, 1)) 50 | print('=================================') 51 | -------------------------------------------------------------------------------- /llava/eval/table/model.jsonl: -------------------------------------------------------------------------------- 1 | {"model_id": "vicuna-13b:20230322-clean-lang", "model_name": "vicuna-13b", "model_version": "20230322-clean-lang", "model_metadata": "vicuna-13b-20230322-clean-lang"} 2 | {"model_id": "alpaca-13b:v1", "model_name": "alpaca-13b", "model_version": "v1", "model_metadata": "alpaca-13b"} 3 | {"model_id": "llama-13b:v1", "model_name": "llama-13b", "model_version": "v1", "model_metadata": "hf-llama-13b"} 4 | {"model_id": "bard:20230327", "model_name": "bard", "model_version": "20230327", "model_metadata": "Google Bard 20230327"} 5 | {"model_id": "gpt-3.5-turbo:20230327", "model_name": "gpt-3.5-turbo", "model_version": "20230327", "model_metadata": "OpenAI ChatGPT gpt-3.5-turbo Chat Completion"} 6 | -------------------------------------------------------------------------------- /llava/eval/table/prompt.jsonl: -------------------------------------------------------------------------------- 1 | {"prompt_id": 1, "system_prompt": "You are a helpful and precise assistant for checking the quality of the answer.", "prompt_template": "[Question]\n{question}\n\n[Assistant 1]\n{answer_1}\n\n[End of Assistant 1]\n\n[Assistant 2]\n{answer_2}\n\n[End of Assistant 2]\n\n[System]\n{prompt}\n\n", "defaults": {"prompt": "We would like to request your feedback on the performance of two AI assistants in response to the user question displayed above.\nPlease rate the helpfulness, relevance, accuracy, level of details of their responses. Each assistant receives an overall score on a scale of 1 to 10, where a higher score indicates better overall performance.\nPlease first output a single line containing only two values indicating the scores for Assistant 1 and 2, respectively. The two scores are separated by a space.\nIn the subsequent line, please provide a comprehensive explanation of your evaluation, avoiding any potential bias and ensuring that the order in which the responses were presented does not affect your judgment."}, "description": "Prompt for general questions"} 2 | {"prompt_id": 2, "system_prompt": "You are a helpful and precise assistant for checking the quality of the answer.", "prompt_template": "[Question]\n{question}\n\n[Assistant 1]\n{answer_1}\n\n[End of Assistant 1]\n\n[Assistant 2]\n{answer_2}\n\n[End of Assistant 2]\n\n[System]\n{prompt}\n\n", "defaults": {"prompt": "Your task is to evaluate the coding abilities of the above two assistants. They have been asked to implement a program to solve a given problem. Please review their code submissions, paying close attention to their problem-solving approach, code structure, readability, and the inclusion of helpful comments.\n\nPlease ensure that the assistants' submissions:\n\n1. Correctly implement the given problem statement.\n2. Contain accurate and efficient code.\n3. Include clear and concise comments that explain the code's logic and functionality.\n4. 
Adhere to proper coding standards and best practices.\n\nOnce you have carefully reviewed both submissions, provide detailed feedback on their strengths and weaknesses, along with any suggestions for improvement. You should first output a single line containing two scores on the scale of 1-10 (1: no code/no sense; 10: perfect) for Assistant 1 and 2, respectively. Then give extra comments starting from the next line."}, "description": "Prompt for coding questions"} 3 | {"prompt_id": 3, "system_prompt": "You are a helpful and precise assistant for checking the quality of the answer.", "prompt_template": "[Question]\n{question}\n\n[Assistant 1]\n{answer_1}\n\n[End of Assistant 1]\n\n[Assistant 2]\n{answer_2}\n\n[End of Assistant 2]\n\n[System]\n{prompt}\n\n", "defaults": {"prompt": "We would like to request your feedback on the mathematical proficiency of two AI assistants regarding the given user question.\nFirstly, please solve the problem independently, without referring to the answers provided by Assistant 1 and Assistant 2.\nAfterward, please examine the problem-solving process of Assistant 1 and Assistant 2 step-by-step to ensure their correctness, identifying any incorrect steps if present. Your evaluation should take into account not only the answer but also the problem-solving steps.\nFinally, please output a Python tuple containing two numerical scores for Assistant 1 and Assistant 2, ranging from 1 to 10, respectively. If applicable, explain the reasons for any variations in their scores and determine which assistant performed better."}, "description": "Prompt for math questions"} 4 | {"prompt_id": 4, "system_prompt": "You are a helpful and precise assistant for checking the quality of the answer.", "prompt_template": "[Visual Context]\n{context}\n[Question]\n{question}\n\n[Assistant 1]\n{answer_1}\n\n[End of Assistant 1]\n\n[Assistant 2]\n{answer_2}\n\n[End of Assistant 2]\n\n[System]\n{prompt}\n\n", "defaults": {"prompt": "We would like to request your feedback on the performance of two AI assistants in response to the user question displayed above. The user asks the question on observing an image. For your reference, the visual content in the image is represented with five descriptive sentences describing the same image and the bounding box coordinates of each object in the scene. These coordinates are in the form of bounding boxes, represented as (x1, y1, x2, y2) with floating numbers ranging from 0 to 1. These values correspond to the top left x, top left y, bottom right x, and bottom right y. \nPlease rate the helpfulness, relevance, accuracy, level of details of their responses. Each assistant receives an overall score on a scale of 1 to 10, where a higher score indicates better overall performance.\nPlease first output a single line containing only two values indicating the scores for Assistant 1 and 2, respectively. 
The two scores are separated by a space.\nIn the subsequent line, please provide a comprehensive explanation of your evaluation, avoiding any potential bias and ensuring that the order in which the responses were presented does not affect your judgment."}, "description": "Prompt for visual questions"} 5 | -------------------------------------------------------------------------------- /llava/eval/table/reviewer.jsonl: -------------------------------------------------------------------------------- 1 | {"reviewer_id": "gpt-4-0328-default", "prompt_id": 1, "metadata": {"temperature": 0.2, "max_tokens": 1024}, "description": "GPT-4 for general questions"} 2 | {"reviewer_id": "gpt-4-0328-coding", "prompt_id": 2, "metadata": {"temperature": 0.2, "max_tokens": 1024}, "description": "GPT-4 for coding questions"} 3 | {"reviewer_id": "gpt-4-0328-math", "prompt_id": 3, "metadata": {"temperature": 0.2, "max_tokens": 1024}, "description": "GPT-4 for math questions"} 4 | {"reviewer_id": "gpt-4-0417-visual", "prompt_id": 4, "metadata": {"temperature": 0.2, "max_tokens": 1024}, "description": "GPT-4 for math questions"} 5 | -------------------------------------------------------------------------------- /llava/eval/table/rule.json: -------------------------------------------------------------------------------- 1 | { 2 | "coding": {"role": "Assistant", "prompt": "Your task is to evaluate the coding abilities of the above two assistants. They have been asked to implement a program to solve a given problem. Please review their code submissions, paying close attention to their problem-solving approach, code structure, readability, and the inclusion of helpful comments.\n\nPlease ensure that the assistants' submissions:\n\n1. Correctly implement the given problem statement.\n2. Contain accurate and efficient code.\n3. Include clear and concise comments that explain the code's logic and functionality.\n4. Adhere to proper coding standards and best practices.\n\nOnce you have carefully reviewed both submissions, provide detailed feedback on their strengths and weaknesses, along with any suggestions for improvement. You should first output a single line containing two scores on the scale of 1-10 (1: no code/no sense; 10: perfect) for Assistant 1 and 2, respectively. Then give extra comments starting from the next line."}, 3 | "math": {"role": "Assistant", "prompt": "We would like to request your feedback on the mathematical proficiency of two AI assistants regarding the given user question.\nFirstly, please solve the problem independently, without referring to the answers provided by Assistant 1 and Assistant 2.\nAfterward, please examine the problem-solving process of Assistant 1 and Assistant 2 step-by-step to ensure their correctness, identifying any incorrect steps if present. Your evaluation should take into account not only the answer but also the problem-solving steps.\nFinally, please output a Python tuple containing two numerical scores for Assistant 1 and Assistant 2, ranging from 1 to 10, respectively. If applicable, explain the reasons for any variations in their scores and determine which assistant performed better."}, 4 | "default": {"role": "Assistant", "prompt": "We would like to request your feedback on the performance of two AI assistants in response to the user question displayed above.\nPlease rate the helpfulness, relevance, accuracy, level of details of their responses. 
Each assistant receives an overall score on a scale of 1 to 10, where a higher score indicates better overall performance.\nPlease first output a single line containing only two values indicating the scores for Assistant 1 and 2, respectively. The two scores are separated by a space.\nIn the subsequent line, please provide a comprehensive explanation of your evaluation, avoiding any potential bias and ensuring that the order in which the responses were presented does not affect your judgment."}, 5 | "conv": {"role": "Assistant", "prompt": "We would like to request your feedback on the performance of two AI assistants in response to the user question displayed above. The user asks the question on observing an image. For your reference, the visual content in the image is represented with five descriptive sentences describing the same image and the bounding box coordinates of each object in the scene. These coordinates are in the form of bounding boxes, represented as (x1, y1, x2, y2) with floating numbers ranging from 0 to 1. These values correspond to the top left x, top left y, bottom right x, and bottom right y. \nPlease rate the helpfulness, relevance, accuracy, level of details of their responses. Each assistant receives an overall score on a scale of 1 to 10, where a higher score indicates better overall performance.\nPlease first output a single line containing only two values indicating the scores for Assistant 1 and 2, respectively. The two scores are separated by a space.\nIn the subsequent line, please provide a comprehensive explanation of your evaluation, avoiding any potential bias and ensuring that the order in which the responses were presented does not affect your judgment."}, 6 | "detail": {"role": "Assistant", "prompt": "We would like to request your feedback on the performance of two AI assistants in response to the user question displayed above. The user asks the question on observing an image. For your reference, the visual content in the image is represented with five descriptive sentences describing the same image and the bounding box coordinates of each object in the scene. These coordinates are in the form of bounding boxes, represented as (x1, y1, x2, y2) with floating numbers ranging from 0 to 1. These values correspond to the top left x, top left y, bottom right x, and bottom right y. \nPlease rate the helpfulness, relevance, accuracy, level of details of their responses. Each assistant receives an overall score on a scale of 1 to 10, where a higher score indicates better overall performance.\nPlease first output a single line containing only two values indicating the scores for Assistant 1 and 2, respectively. The two scores are separated by a space.\nIn the subsequent line, please provide a comprehensive explanation of your evaluation, avoiding any potential bias and ensuring that the order in which the responses were presented does not affect your judgment."}, 7 | "complex": {"role": "Assistant", "prompt": "We would like to request your feedback on the performance of two AI assistants in response to the user question displayed above. The user asks the question on observing an image. For your reference, the visual content in the image is represented with five descriptive sentences describing the same image and the bounding box coordinates of each object in the scene. These coordinates are in the form of bounding boxes, represented as (x1, y1, x2, y2) with floating numbers ranging from 0 to 1. 
These values correspond to the top left x, top left y, bottom right x, and bottom right y. \nPlease rate the helpfulness, relevance, accuracy, level of details of their responses. Each assistant receives an overall score on a scale of 1 to 10, where a higher score indicates better overall performance.\nPlease first output a single line containing only two values indicating the scores for Assistant 1 and 2, respectively. The two scores are separated by a space.\nIn the subsequent line, please provide a comprehensive explanation of your evaluation, avoiding any potential bias and ensuring that the order in which the responses were presented does not affect your judgment."}, 8 | "llava_bench_conv": {"role": "Assistant", "prompt": "We would like to request your feedback on the performance of two AI assistants in response to the user question displayed above. The user asks the question on observing an image. For your reference, the visual content in the image is represented with a few sentences describing the image. \nPlease rate the helpfulness, relevance, accuracy, level of details of their responses. Each assistant receives an overall score on a scale of 1 to 10, where a higher score indicates better overall performance.\nPlease first output a single line containing only two values indicating the scores for Assistant 1 and 2, respectively. The two scores are separated by a space.\nIn the subsequent line, please provide a comprehensive explanation of your evaluation, avoiding any potential bias and ensuring that the order in which the responses were presented does not affect your judgment."}, 9 | "llava_bench_detail": {"role": "Assistant", "prompt": "We would like to request your feedback on the performance of two AI assistants in response to the user question displayed above. The user asks the question on observing an image. For your reference, the visual content in the image is represented with a few sentences describing the image. \nPlease rate the helpfulness, relevance, accuracy, level of details of their responses. Each assistant receives an overall score on a scale of 1 to 10, where a higher score indicates better overall performance.\nPlease first output a single line containing only two values indicating the scores for Assistant 1 and 2, respectively. The two scores are separated by a space.\nIn the subsequent line, please provide a comprehensive explanation of your evaluation, avoiding any potential bias and ensuring that the order in which the responses were presented does not affect your judgment."}, 10 | "llava_bench_complex": {"role": "Assistant", "prompt": "We would like to request your feedback on the performance of two AI assistants in response to the user question displayed above. The user asks the question on observing an image. For your reference, the visual content in the image is represented with a few sentences describing the image. \nPlease rate the helpfulness, relevance, accuracy, level of details of their responses. Each assistant receives an overall score on a scale of 1 to 10, where a higher score indicates better overall performance.\nPlease first output a single line containing only two values indicating the scores for Assistant 1 and 2, respectively. 
The two scores are separated by a space.\nIn the subsequent line, please provide a comprehensive explanation of your evaluation, avoiding any potential bias and ensuring that the order in which the responses were presented does not affect your judgment."} 11 | } -------------------------------------------------------------------------------- /llava/eval/webpage/figures/alpaca.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/TencentARC/SmartEdit/a65f1262dfcba68c138ea95fe9936df1bd2c111d/llava/eval/webpage/figures/alpaca.png -------------------------------------------------------------------------------- /llava/eval/webpage/figures/bard.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/TencentARC/SmartEdit/a65f1262dfcba68c138ea95fe9936df1bd2c111d/llava/eval/webpage/figures/bard.jpg -------------------------------------------------------------------------------- /llava/eval/webpage/figures/chatgpt.svg: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /llava/eval/webpage/figures/llama.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/TencentARC/SmartEdit/a65f1262dfcba68c138ea95fe9936df1bd2c111d/llava/eval/webpage/figures/llama.jpg -------------------------------------------------------------------------------- /llava/eval/webpage/figures/swords_FILL0_wght300_GRAD0_opsz48.svg: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /llava/eval/webpage/figures/vicuna.jpeg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/TencentARC/SmartEdit/a65f1262dfcba68c138ea95fe9936df1bd2c111d/llava/eval/webpage/figures/vicuna.jpeg -------------------------------------------------------------------------------- /llava/eval/webpage/index.html: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | Who's GPT-4's favorite? Battles between State-of-the-Art Chatbots 7 | 8 | 9 | 10 | 11 | 12 | 13 | 32 | 33 |
34 | <!-- index.html markup was stripped in this export (original file lines 34-139); only scattered text nodes survive: page title "Who's GPT-4's favorite? Battles between State-of-the-Art Chatbots", image alt text "other logo" / "vicuna logo", panel label "Assistant #2 (Vicuna, our model)", section heading "GPT-4 Evaluation", and the footer "This website is co-authored with GPT-4." -->
140 | 141 | 142 | 143 | 144 | 145 | 146 | 147 | 148 | 149 | 160 | 161 | 162 | 163 | -------------------------------------------------------------------------------- /llava/eval/webpage/styles.css: -------------------------------------------------------------------------------- 1 | body { 2 | font-family: 'Segoe UI', Tahoma, Geneva, Verdana, sans-serif; 3 | background-color: #f8f9fa; 4 | } 5 | 6 | .navbar-dark .navbar-nav .nav-link { 7 | color: #f1cf68; 8 | font-size: 1.1rem; 9 | padding: 0.5rem 0.6rem; 10 | } 11 | 12 | .card-header { 13 | font-weight: bold; 14 | } 15 | 16 | .card { 17 | box-shadow: 0 4px 8px rgba(0, 0, 0, 0.1); 18 | transition: 0.3s; 19 | } 20 | 21 | .card:hover { 22 | box-shadow: 0 8px 16px rgba(0, 0, 0, 0.2); 23 | } 24 | 25 | button { 26 | transition: background-color 0.3s; 27 | } 28 | 29 | button:hover { 30 | background-color: #007bff; 31 | } 32 | 33 | @media (max-width: 767px) { 34 | .form-row .form-group { 35 | margin-bottom: 10px; 36 | } 37 | } 38 | 39 | /* Extra styles */ 40 | 41 | .expandable-card .card-text-container { 42 | max-height: 200px; 43 | overflow-y: hidden; 44 | position: relative; 45 | } 46 | 47 | .expandable-card.expanded .card-text-container { 48 | max-height: none; 49 | } 50 | 51 | .expand-btn { 52 | position: relative; 53 | display: none; 54 | background-color: rgba(255, 255, 255, 0.8); 55 | color: #510c75; 56 | border-color: transparent; 57 | } 58 | 59 | .expand-btn:hover { 60 | background-color: rgba(200, 200, 200, 0.8); 61 | text-decoration: none; 62 | border-color: transparent; 63 | color: #510c75; 64 | } 65 | 66 | .expand-btn:focus { 67 | outline: none; 68 | text-decoration: none; 69 | } 70 | 71 | .expandable-card:not(.expanded) .card-text-container:after { 72 | content: ""; 73 | position: absolute; 74 | bottom: 0; 75 | left: 0; 76 | width: 100%; 77 | height: 90px; 78 | background: linear-gradient(rgba(255, 255, 255, 0.2), rgba(255, 255, 255, 1)); 79 | } 80 | 81 | .expandable-card:not(.expanded) .expand-btn { 82 | margin-top: -40px; 83 | } 84 | 85 | .card-body { 86 | padding-bottom: 5px; 87 | } 88 | 89 | .vertical-flex-layout { 90 | justify-content: center; 91 | align-items: center; 92 | height: 100%; 93 | display: flex; 94 | flex-direction: column; 95 | gap: 5px; 96 | } 97 | 98 | .figure-img { 99 | max-width: 100%; 100 | height: auto; 101 | } 102 | 103 | .adjustable-font-size { 104 | font-size: calc(0.5rem + 2vw); 105 | } 106 | -------------------------------------------------------------------------------- /llava/mm_utils.py: -------------------------------------------------------------------------------- 1 | from PIL import Image 2 | from io import BytesIO 3 | import base64 4 | 5 | import torch 6 | from transformers import StoppingCriteria 7 | from llava.constants import IMAGE_TOKEN_INDEX 8 | 9 | 10 | def load_image_from_base64(image): 11 | return Image.open(BytesIO(base64.b64decode(image))) 12 | 13 | 14 | def process_images(images, image_processor, model_cfg): 15 | return image_processor(images, return_tensors='pt')['pixel_values'] 16 | 17 | 18 | def tokenizer_image_token(prompt, tokenizer, image_token_index=IMAGE_TOKEN_INDEX, return_tensors=None): 19 | prompt_chunks = [tokenizer(chunk).input_ids for chunk in prompt.split('')] 20 | 21 | def insert_separator(X, sep): 22 | return [ele for sublist in zip(X, [sep]*len(X)) for ele in sublist][:-1] 23 | 24 | input_ids = [] 25 | offset = 0 26 | if len(prompt_chunks) > 0 and len(prompt_chunks[0]) > 0 and prompt_chunks[0][0] == tokenizer.bos_token_id: 27 | offset = 1 28 | 
input_ids.append(prompt_chunks[0][0]) 29 | 30 | for x in insert_separator(prompt_chunks, [image_token_index] * (offset + 1)): 31 | input_ids.extend(x[offset:]) 32 | 33 | if return_tensors is not None: 34 | if return_tensors == 'pt': 35 | return torch.tensor(input_ids, dtype=torch.long) 36 | raise ValueError(f'Unsupported tensor type: {return_tensors}') 37 | return input_ids 38 | 39 | 40 | def get_model_name_from_path(model_path): 41 | model_path = model_path.strip("/") 42 | model_paths = model_path.split("/") 43 | if model_paths[-1].startswith('checkpoint-'): 44 | return model_paths[-2] + "_" + model_paths[-1] 45 | else: 46 | return model_paths[-1] 47 | 48 | 49 | 50 | 51 | class KeywordsStoppingCriteria(StoppingCriteria): 52 | def __init__(self, keywords, tokenizer, input_ids): 53 | self.keywords = keywords 54 | self.keyword_ids = [] 55 | for keyword in keywords: 56 | cur_keyword_ids = tokenizer(keyword).input_ids 57 | if len(cur_keyword_ids) > 1 and cur_keyword_ids[0] == tokenizer.bos_token_id: 58 | cur_keyword_ids = cur_keyword_ids[1:] 59 | self.keyword_ids.append(torch.tensor(cur_keyword_ids)) 60 | self.tokenizer = tokenizer 61 | self.start_len = input_ids.shape[1] 62 | 63 | def __call__(self, output_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool: 64 | assert output_ids.shape[0] == 1, "Only support batch size 1 (yet)" # TODO 65 | offset = min(output_ids.shape[1] - self.start_len, 3) 66 | self.keyword_ids = [keyword_id.to(output_ids.device) for keyword_id in self.keyword_ids] 67 | for keyword_id in self.keyword_ids: 68 | if output_ids[0, -keyword_id.shape[0]:] == keyword_id: 69 | return True 70 | outputs = self.tokenizer.batch_decode(output_ids[:, -offset:], skip_special_tokens=True)[0] 71 | for keyword in self.keywords: 72 | if keyword in outputs: 73 | return True 74 | return False 75 | -------------------------------------------------------------------------------- /llava/model/__init__.py: -------------------------------------------------------------------------------- 1 | from .language_model.llava_llama import LlavaLlamaForCausalLM, LlavaConfig 2 | from .language_model.llava_mpt import LlavaMPTForCausalLM, LlavaMPTConfig 3 | -------------------------------------------------------------------------------- /llava/model/apply_delta.py: -------------------------------------------------------------------------------- 1 | """ 2 | Usage: 3 | python3 -m fastchat.model.apply_delta --base ~/model_weights/llama-7b --target ~/model_weights/vicuna-7b --delta lmsys/vicuna-7b-delta 4 | """ 5 | import argparse 6 | 7 | import torch 8 | from tqdm import tqdm 9 | from transformers import AutoTokenizer, AutoModelForCausalLM 10 | from llava import LlavaLlamaForCausalLM 11 | 12 | 13 | def apply_delta(base_model_path, target_model_path, delta_path): 14 | print("Loading base model") 15 | base = AutoModelForCausalLM.from_pretrained( 16 | base_model_path, torch_dtype=torch.float16, low_cpu_mem_usage=True) 17 | 18 | print("Loading delta") 19 | delta = LlavaLlamaForCausalLM.from_pretrained(delta_path, torch_dtype=torch.float16, low_cpu_mem_usage=True) 20 | delta_tokenizer = AutoTokenizer.from_pretrained(delta_path) 21 | 22 | print("Applying delta") 23 | for name, param in tqdm(delta.state_dict().items(), desc="Applying delta"): 24 | if name not in base.state_dict(): 25 | assert name in ['model.mm_projector.weight', 'model.mm_projector.bias'], f'{name} not in base model' 26 | continue 27 | if param.data.shape == base.state_dict()[name].shape: 28 | param.data += base.state_dict()[name] 29 | 
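# In short, apply_delta reconstructs the full weights as target = delta + base: tensors whose
# shape matches the base are summed in place above, while enlarged tensors (e.g. embeddings
# extended with extra tokens) are handled in the branch below, where only the overlapping
# slice receives the base weights.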
else: 30 | assert name in ['model.embed_tokens.weight', 'lm_head.weight'], \ 31 | f'{name} dimension mismatch: {param.data.shape} vs {base.state_dict()[name].shape}' 32 | bparam = base.state_dict()[name] 33 | param.data[:bparam.shape[0], :bparam.shape[1]] += bparam 34 | 35 | print("Saving target model") 36 | delta.save_pretrained(target_model_path) 37 | delta_tokenizer.save_pretrained(target_model_path) 38 | 39 | 40 | if __name__ == "__main__": 41 | parser = argparse.ArgumentParser() 42 | parser.add_argument("--base-model-path", type=str, required=True) 43 | parser.add_argument("--target-model-path", type=str, required=True) 44 | parser.add_argument("--delta-path", type=str, required=True) 45 | 46 | args = parser.parse_args() 47 | 48 | apply_delta(args.base_model_path, args.target_model_path, args.delta_path) 49 | -------------------------------------------------------------------------------- /llava/model/consolidate.py: -------------------------------------------------------------------------------- 1 | """ 2 | Usage: 3 | python3 -m llava.model.consolidate --src ~/model_weights/llava-7b --dst ~/model_weights/llava-7b_consolidate 4 | """ 5 | import argparse 6 | 7 | import torch 8 | from transformers import AutoTokenizer, AutoModelForCausalLM 9 | from llava.model import * 10 | from llava.model.utils import auto_upgrade 11 | 12 | 13 | def consolidate_ckpt(src_path, dst_path): 14 | print("Loading model") 15 | auto_upgrade(src_path) 16 | src_model = AutoModelForCausalLM.from_pretrained(src_path, torch_dtype=torch.float16, low_cpu_mem_usage=True) 17 | src_tokenizer = AutoTokenizer.from_pretrained(src_path, use_fast=False) 18 | src_model.save_pretrained(dst_path) 19 | src_tokenizer.save_pretrained(dst_path) 20 | 21 | 22 | if __name__ == "__main__": 23 | parser = argparse.ArgumentParser() 24 | parser.add_argument("--src", type=str, required=True) 25 | parser.add_argument("--dst", type=str, required=True) 26 | 27 | args = parser.parse_args() 28 | 29 | consolidate_ckpt(args.src, args.dst) 30 | -------------------------------------------------------------------------------- /llava/model/language_model/llava_llama.py: -------------------------------------------------------------------------------- 1 | # Copyright 2023 Haotian Liu 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | 15 | 16 | from typing import List, Optional, Tuple, Union 17 | 18 | import torch 19 | import torch.nn as nn 20 | from torch.nn import CrossEntropyLoss 21 | 22 | from transformers import AutoConfig, AutoModelForCausalLM, \ 23 | LlamaConfig, LlamaModel, LlamaForCausalLM 24 | 25 | from transformers.modeling_outputs import CausalLMOutputWithPast 26 | 27 | from ..llava_arch import LlavaMetaModel, LlavaMetaForCausalLM 28 | 29 | 30 | class LlavaConfig(LlamaConfig): 31 | model_type = "llava" 32 | 33 | 34 | class LlavaLlamaModel(LlavaMetaModel, LlamaModel): 35 | config_class = LlavaConfig 36 | 37 | def __init__(self, config: LlamaConfig): 38 | super(LlavaLlamaModel, self).__init__(config) 39 | 40 | 41 | class LlavaLlamaForCausalLM(LlamaForCausalLM, LlavaMetaForCausalLM): 42 | config_class = LlavaConfig 43 | 44 | def __init__(self, config): 45 | super(LlamaForCausalLM, self).__init__(config) 46 | self.model = LlavaLlamaModel(config) 47 | 48 | self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) 49 | 50 | # Initialize weights and apply final processing 51 | self.post_init() 52 | 53 | def get_model(self): 54 | return self.model 55 | 56 | def forward( 57 | self, 58 | input_ids: torch.LongTensor = None, 59 | attention_mask: Optional[torch.Tensor] = None, 60 | past_key_values: Optional[List[torch.FloatTensor]] = None, 61 | inputs_embeds: Optional[torch.FloatTensor] = None, 62 | labels: Optional[torch.LongTensor] = None, 63 | use_cache: Optional[bool] = None, 64 | output_attentions: Optional[bool] = None, 65 | output_hidden_states: Optional[bool] = None, 66 | images: Optional[torch.FloatTensor] = None, 67 | return_dict: Optional[bool] = None, 68 | ) -> Union[Tuple, CausalLMOutputWithPast]: 69 | output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions 70 | output_hidden_states = ( 71 | output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states 72 | ) 73 | return_dict = return_dict if return_dict is not None else self.config.use_return_dict 74 | 75 | input_ids, attention_mask, past_key_values, inputs_embeds, labels = self.prepare_inputs_labels_for_multimodal(input_ids, attention_mask, past_key_values, labels, images) 76 | 77 | # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn) 78 | outputs = self.model( 79 | input_ids=input_ids, 80 | attention_mask=attention_mask, 81 | past_key_values=past_key_values, 82 | inputs_embeds=inputs_embeds, 83 | use_cache=use_cache, 84 | output_attentions=output_attentions, 85 | output_hidden_states=output_hidden_states, 86 | return_dict=return_dict 87 | ) 88 | 89 | hidden_states = outputs[0] 90 | logits = self.lm_head(hidden_states) 91 | 92 | loss = None 93 | if labels is not None: 94 | # Shift so that tokens < n predict n 95 | shift_logits = logits[..., :-1, :].contiguous() 96 | shift_labels = labels[..., 1:].contiguous() 97 | # Flatten the tokens 98 | loss_fct = CrossEntropyLoss() 99 | shift_logits = shift_logits.view(-1, self.config.vocab_size) 100 | shift_labels = shift_labels.view(-1) 101 | # Enable model/pipeline parallelism 102 | shift_labels = shift_labels.to(shift_logits.device) 103 | loss = loss_fct(shift_logits, shift_labels) 104 | 105 | if not return_dict: 106 | output = (logits,) + outputs[1:] 107 | return (loss,) + output if loss is not None else output 108 | 109 | return CausalLMOutputWithPast( 110 | loss=loss, 111 | logits=logits, 112 | past_key_values=outputs.past_key_values, 113 | hidden_states=outputs.hidden_states, 114 
| attentions=outputs.attentions, 115 | ) 116 | 117 | def prepare_inputs_for_generation( 118 | self, input_ids, past_key_values=None, attention_mask=None, inputs_embeds=None, **kwargs 119 | ): 120 | if past_key_values: 121 | input_ids = input_ids[:, -1:] 122 | 123 | # if `inputs_embeds` are passed, we only want to use them in the 1st generation step 124 | if inputs_embeds is not None and past_key_values is None: 125 | model_inputs = {"inputs_embeds": inputs_embeds} 126 | else: 127 | model_inputs = {"input_ids": input_ids} 128 | 129 | model_inputs.update( 130 | { 131 | "past_key_values": past_key_values, 132 | "use_cache": kwargs.get("use_cache"), 133 | "attention_mask": attention_mask, 134 | "images": kwargs.get("images", None), 135 | } 136 | ) 137 | return model_inputs 138 | 139 | AutoConfig.register("llava", LlavaConfig) 140 | AutoModelForCausalLM.register(LlavaConfig, LlavaLlamaForCausalLM) 141 | -------------------------------------------------------------------------------- /llava/model/language_model/llava_mpt.py: -------------------------------------------------------------------------------- 1 | # Copyright 2023 Haotian Liu 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | 16 | from typing import List, Optional, Tuple 17 | import warnings 18 | 19 | import torch 20 | import torch.nn.functional as F 21 | import math 22 | 23 | from transformers import AutoConfig, AutoModelForCausalLM 24 | from transformers.modeling_outputs import CausalLMOutputWithPast 25 | 26 | from .mpt.modeling_mpt import MPTConfig, MPTForCausalLM, MPTModel 27 | from llava.model.llava_arch import LlavaMetaModel, LlavaMetaForCausalLM 28 | 29 | 30 | class LlavaMPTConfig(MPTConfig): 31 | model_type = "llava_mpt" 32 | 33 | 34 | class LlavaMPTModel(LlavaMetaModel, MPTModel): 35 | config_class = LlavaMPTConfig 36 | 37 | def __init__(self, config: MPTConfig): 38 | config.hidden_size = config.d_model 39 | super(LlavaMPTModel, self).__init__(config) 40 | 41 | def embed_tokens(self, x): 42 | return self.wte(x) 43 | 44 | 45 | class LlavaMPTForCausalLM(MPTForCausalLM, LlavaMetaForCausalLM): 46 | config_class = LlavaMPTConfig 47 | supports_gradient_checkpointing = True 48 | 49 | def __init__(self, config): 50 | super(MPTForCausalLM, self).__init__(config) 51 | 52 | if not config.tie_word_embeddings: 53 | raise ValueError('MPTForCausalLM only supports tied word embeddings') 54 | self.transformer = LlavaMPTModel(config) 55 | self.logit_scale = None 56 | if config.logit_scale is not None: 57 | logit_scale = config.logit_scale 58 | if isinstance(logit_scale, str): 59 | if logit_scale == 'inv_sqrt_d_model': 60 | logit_scale = 1 / math.sqrt(config.d_model) 61 | else: 62 | raise ValueError(f"logit_scale={logit_scale!r} is not recognized as an option; use numeric value or 'inv_sqrt_d_model'.") 63 | self.logit_scale = logit_scale 64 | 65 | def get_model(self): 66 | return self.transformer 67 | 68 | def _set_gradient_checkpointing(self, module, value=False): 69 | if isinstance(module, 
LlavaMPTModel): 70 | module.gradient_checkpointing = value 71 | 72 | def forward(self, input_ids: torch.LongTensor, past_key_values: Optional[List[Tuple[torch.FloatTensor]]]=None, attention_mask: Optional[torch.ByteTensor]=None, prefix_mask: Optional[torch.ByteTensor]=None, sequence_id: Optional[torch.LongTensor]=None, labels: Optional[torch.LongTensor]=None, return_dict: Optional[bool]=None, output_attentions: Optional[bool]=None, output_hidden_states: Optional[bool]=None, use_cache: Optional[bool]=None, images=None): 73 | return_dict = return_dict if return_dict is not None else self.config.return_dict 74 | use_cache = use_cache if use_cache is not None else self.config.use_cache 75 | 76 | input_ids, attention_mask, past_key_values, inputs_embeds, labels = self.prepare_inputs_labels_for_multimodal(input_ids, attention_mask, past_key_values, labels, images) 77 | outputs = self.transformer(input_ids=input_ids, inputs_embeds=inputs_embeds, past_key_values=past_key_values, attention_mask=attention_mask, prefix_mask=prefix_mask, sequence_id=sequence_id, return_dict=return_dict, output_attentions=output_attentions, output_hidden_states=output_hidden_states, use_cache=use_cache) 78 | logits = F.linear(outputs.last_hidden_state, self.transformer.wte.weight) 79 | if self.logit_scale is not None: 80 | if self.logit_scale == 0: 81 | warnings.warn(f'Multiplying logits by self.logit_scale={self.logit_scale!r}. This will produce uniform (uninformative) outputs.') 82 | logits *= self.logit_scale 83 | loss = None 84 | if labels is not None: 85 | labels = torch.roll(labels, shifts=-1) 86 | labels[:, -1] = -100 87 | loss = F.cross_entropy(logits.view(-1, logits.size(-1)), labels.to(logits.device).view(-1)) 88 | return CausalLMOutputWithPast(loss=loss, logits=logits, past_key_values=outputs.past_key_values, hidden_states=outputs.hidden_states) 89 | 90 | def prepare_inputs_for_generation(self, input_ids, past_key_values=None, inputs_embeds=None, **kwargs): 91 | if inputs_embeds is not None: 92 | raise NotImplementedError('inputs_embeds is not implemented for MPT yet') 93 | attention_mask = kwargs['attention_mask'].bool() 94 | if attention_mask[:, -1].sum() != attention_mask.shape[0]: 95 | raise NotImplementedError('MPT does not support generation with right padding.') 96 | if self.transformer.attn_uses_sequence_id and self.training: 97 | sequence_id = torch.zeros_like(input_ids[:1]) 98 | else: 99 | sequence_id = None 100 | if past_key_values is not None: 101 | input_ids = input_ids[:, -1].unsqueeze(-1) 102 | if self.transformer.prefix_lm: 103 | prefix_mask = torch.ones_like(attention_mask) 104 | if kwargs.get('use_cache') == False: 105 | raise NotImplementedError('MPT with prefix_lm=True does not support use_cache=False.') 106 | else: 107 | prefix_mask = None 108 | return {'input_ids': input_ids, 'attention_mask': attention_mask, 'prefix_mask': prefix_mask, 'sequence_id': sequence_id, 'past_key_values': past_key_values, 'use_cache': kwargs.get('use_cache', True), "images": kwargs.get("images", None)} 109 | 110 | 111 | AutoConfig.register("llava_mpt", LlavaMPTConfig) 112 | AutoModelForCausalLM.register(LlavaMPTConfig, LlavaMPTForCausalLM) 113 | -------------------------------------------------------------------------------- /llava/model/language_model/mpt/adapt_tokenizer.py: -------------------------------------------------------------------------------- 1 | from typing import Union 2 | from transformers import AutoTokenizer, PreTrainedTokenizer, PreTrainedTokenizerFast 3 | Tokenizer = 
Union[PreTrainedTokenizer, PreTrainedTokenizerFast] 4 | NUM_SENTINEL_TOKENS: int = 100 5 | 6 | def adapt_tokenizer_for_denoising(tokenizer: Tokenizer): 7 | """Adds sentinel tokens and padding token (if missing). 8 | 9 | Expands the tokenizer vocabulary to include sentinel tokens 10 | used in mixture-of-denoiser tasks as well as a padding token. 11 | 12 | All added tokens are added as special tokens. No tokens are 13 | added if sentinel tokens and padding token already exist. 14 | """ 15 | sentinels_to_add = [f'<extra_id_{i}>' for i in range(NUM_SENTINEL_TOKENS)] 16 | tokenizer.add_tokens(sentinels_to_add, special_tokens=True) 17 | if tokenizer.pad_token is None: 18 | tokenizer.add_tokens('<pad>', special_tokens=True) 19 | tokenizer.pad_token = '<pad>' 20 | assert tokenizer.pad_token_id is not None 21 | sentinels = ''.join([f'<extra_id_{i}>' for i in range(NUM_SENTINEL_TOKENS)]) 22 | _sentinel_token_ids = tokenizer(sentinels, add_special_tokens=False).input_ids 23 | tokenizer.sentinel_token_ids = _sentinel_token_ids 24 | 25 | class AutoTokenizerForMOD(AutoTokenizer): 26 | """AutoTokenizer + Adaptation for MOD. 27 | 28 | A simple wrapper around AutoTokenizer to make instantiating 29 | an MOD-adapted tokenizer a bit easier. 30 | 31 | MOD-adapted tokenizers have sentinel tokens (e.g., <extra_id_0>), 32 | a padding token, and a property to get the token ids of the 33 | sentinel tokens. 34 | """ 35 | 36 | @classmethod 37 | def from_pretrained(cls, *args, **kwargs): 38 | """See `AutoTokenizer.from_pretrained` docstring.""" 39 | tokenizer = super().from_pretrained(*args, **kwargs) 40 | adapt_tokenizer_for_denoising(tokenizer) 41 | return tokenizer -------------------------------------------------------------------------------- /llava/model/language_model/mpt/blocks.py: -------------------------------------------------------------------------------- 1 | """GPT Blocks used for the GPT Model.""" 2 | from typing import Dict, Optional, Tuple 3 | import torch 4 | import torch.nn as nn 5 | from .attention import ATTN_CLASS_REGISTRY 6 | from .norm import NORM_CLASS_REGISTRY 7 | 8 | class MPTMLP(nn.Module): 9 | 10 | def __init__(self, d_model: int, expansion_ratio: int, device: Optional[str]=None): 11 | super().__init__() 12 | self.up_proj = nn.Linear(d_model, expansion_ratio * d_model, device=device) 13 | self.act = nn.GELU(approximate='none') 14 | self.down_proj = nn.Linear(expansion_ratio * d_model, d_model, device=device) 15 | self.down_proj._is_residual = True 16 | 17 | def forward(self, x): 18 | return self.down_proj(self.act(self.up_proj(x))) 19 | 20 | class MPTBlock(nn.Module): 21 | 22 | def __init__(self, d_model: int, n_heads: int, expansion_ratio: int, attn_config: Dict={'attn_type': 'multihead_attention', 'attn_pdrop': 0.0, 'attn_impl': 'triton', 'qk_ln': False, 'clip_qkv': None, 'softmax_scale': None, 'prefix_lm': False, 'attn_uses_sequence_id': False, 'alibi': False, 'alibi_bias_max': 8}, resid_pdrop: float=0.0, norm_type: str='low_precision_layernorm', verbose: int=0, device: Optional[str]=None, **kwargs): 23 | del kwargs 24 | super().__init__() 25 | norm_class = NORM_CLASS_REGISTRY[norm_type.lower()] 26 | attn_class = ATTN_CLASS_REGISTRY[attn_config['attn_type']] 27 | self.norm_1 = norm_class(d_model, device=device) 28 | self.attn = attn_class(attn_impl=attn_config['attn_impl'], clip_qkv=attn_config['clip_qkv'], qk_ln=attn_config['qk_ln'], softmax_scale=attn_config['softmax_scale'], attn_pdrop=attn_config['attn_pdrop'], d_model=d_model, n_heads=n_heads, verbose=verbose, device=device) 29 | self.norm_2 = norm_class(d_model, device=device)
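# MPTBlock is a pre-norm transformer block: forward() below applies norm_1 -> attention ->
# residual dropout/add, then norm_2 -> MPTMLP -> residual dropout/add, and returns the
# updated hidden states together with the attention weights and the cached key/value pair.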
30 | self.ffn = MPTMLP(d_model=d_model, expansion_ratio=expansion_ratio, device=device) 31 | self.resid_attn_dropout = nn.Dropout(resid_pdrop) 32 | self.resid_ffn_dropout = nn.Dropout(resid_pdrop) 33 | 34 | def forward(self, x: torch.Tensor, past_key_value: Optional[Tuple[torch.Tensor]]=None, attn_bias: Optional[torch.Tensor]=None, attention_mask: Optional[torch.ByteTensor]=None, is_causal: bool=True) -> Tuple[torch.Tensor, Optional[Tuple[torch.Tensor]]]: 35 | a = self.norm_1(x) 36 | (b, attn_weights, past_key_value) = self.attn(a, past_key_value=past_key_value, attn_bias=attn_bias, attention_mask=attention_mask, is_causal=is_causal) 37 | x = x + self.resid_attn_dropout(b) 38 | m = self.norm_2(x) 39 | n = self.ffn(m) 40 | x = x + self.resid_ffn_dropout(n) 41 | return (x, attn_weights, past_key_value) -------------------------------------------------------------------------------- /llava/model/language_model/mpt/configuration_mpt.py: -------------------------------------------------------------------------------- 1 | """A HuggingFace-style model configuration.""" 2 | from typing import Dict, Optional, Union 3 | from transformers import PretrainedConfig 4 | attn_config_defaults: Dict = {'attn_type': 'multihead_attention', 'attn_pdrop': 0.0, 'attn_impl': 'triton', 'qk_ln': False, 'clip_qkv': None, 'softmax_scale': None, 'prefix_lm': False, 'attn_uses_sequence_id': False, 'alibi': False, 'alibi_bias_max': 8} 5 | init_config_defaults: Dict = {'name': 'kaiming_normal_', 'fan_mode': 'fan_in', 'init_nonlinearity': 'relu', 'init_div_is_residual': True, 'emb_init_std': None, 'emb_init_uniform_lim': None, 'init_std': None, 'init_gain': 0.0} 6 | 7 | class MPTConfig(PretrainedConfig): 8 | model_type = 'mpt' 9 | 10 | def __init__(self, d_model: int=2048, n_heads: int=16, n_layers: int=24, expansion_ratio: int=4, max_seq_len: int=2048, vocab_size: int=50368, resid_pdrop: float=0.0, emb_pdrop: float=0.0, learned_pos_emb: bool=True, attn_config: Dict=attn_config_defaults, init_device: str='cpu', logit_scale: Optional[Union[float, str]]=None, no_bias: bool=False, verbose: int=0, embedding_fraction: float=1.0, norm_type: str='low_precision_layernorm', use_cache: bool=False, init_config: Dict=init_config_defaults, **kwargs): 11 | """The MPT configuration class. 12 | 13 | Args: 14 | d_model (int): The size of the embedding dimension of the model. 15 | n_heads (int): The number of attention heads. 16 | n_layers (int): The number of layers in the model. 17 | expansion_ratio (int): The ratio of the up/down scale in the MLP. 18 | max_seq_len (int): The maximum sequence length of the model. 19 | vocab_size (int): The size of the vocabulary. 20 | resid_pdrop (float): The dropout probability applied to the attention output before combining with residual. 21 | emb_pdrop (float): The dropout probability for the embedding layer. 22 | learned_pos_emb (bool): Whether to use learned positional embeddings 23 | attn_config (Dict): A dictionary used to configure the model's attention module: 24 | attn_type (str): type of attention to use. Options: multihead_attention, multiquery_attention 25 | attn_pdrop (float): The dropout probability for the attention layers. 26 | attn_impl (str): The attention implementation to use. One of 'torch', 'flash', or 'triton'. 27 | qk_ln (bool): Whether to apply layer normalization to the queries and keys in the attention layer. 28 | clip_qkv (Optional[float]): If not None, clip the queries, keys, and values in the attention layer to 29 | this value. 
30 | softmax_scale (Optional[float]): If not None, scale the softmax in the attention layer by this value. If None, 31 | use the default scale of ``1/sqrt(d_keys)``. 32 | prefix_lm (Optional[bool]): Whether the model should operate as a Prefix LM. This requires passing an 33 | extra `prefix_mask` argument which indicates which tokens belong to the prefix. Tokens in the prefix 34 | can attend to one another bi-directionally. Tokens outside the prefix use causal attention. 35 | attn_uses_sequence_id (Optional[bool]): Whether to restrict attention to tokens that have the same sequence_id. 36 | When the model is in `train` mode, this requires passing an extra `sequence_id` argument which indicates 37 | which sub-sequence each token belongs to. 38 | Defaults to ``False`` meaning any provided `sequence_id` will be ignored. 39 | alibi (bool): Whether to use the alibi bias instead of position embeddings. 40 | alibi_bias_max (int): The maximum value of the alibi bias. 41 | init_device (str): The device to use for parameter initialization. 42 | logit_scale (Optional[Union[float, str]]): If not None, scale the logits by this value. 43 | no_bias (bool): Whether to use bias in all layers. 44 | verbose (int): The verbosity level. 0 is silent. 45 | embedding_fraction (float): The fraction to scale the gradients of the embedding layer by. 46 | norm_type (str): choose type of norm to use 47 | multiquery_attention (bool): Whether to use multiquery attention implementation. 48 | use_cache (bool): Whether or not the model should return the last key/values attentions 49 | init_config (Dict): A dictionary used to configure the model initialization: 50 | init_config.name: The parameter initialization scheme to use. Options: 'default_', 'baseline_', 51 | 'kaiming_uniform_', 'kaiming_normal_', 'neox_init_', 'small_init_', 'xavier_uniform_', or 52 | 'xavier_normal_'. These mimic the parameter initialization methods in PyTorch. 53 | init_div_is_residual (Union[int, float, str, bool]): Value to divide initial weights by if ``module._is_residual`` is True. 54 | emb_init_std (Optional[float]): The standard deviation of the normal distribution used to initialize the embedding layer. 55 | emb_init_uniform_lim (Optional[Union[Tuple[float, float], float]]): The lower and upper limits of the uniform distribution 56 | used to initialize the embedding layer. Mutually exclusive with ``emb_init_std``. 57 | init_std (float): The standard deviation of the normal distribution used to initialize the model, 58 | if using the baseline_ parameter initialization scheme. 59 | init_gain (float): The gain to use for parameter initialization with kaiming or xavier initialization schemes. 60 | fan_mode (str): The fan mode to use for parameter initialization with kaiming initialization schemes. 61 | init_nonlinearity (str): The nonlinearity to use for parameter initialization with kaiming initialization schemes. 
62 | --- 63 | See llmfoundry.models.utils.param_init_fns.py for info on other param init config options 64 | """ 65 | self.d_model = d_model 66 | self.n_heads = n_heads 67 | self.n_layers = n_layers 68 | self.expansion_ratio = expansion_ratio 69 | self.max_seq_len = max_seq_len 70 | self.vocab_size = vocab_size 71 | self.resid_pdrop = resid_pdrop 72 | self.emb_pdrop = emb_pdrop 73 | self.learned_pos_emb = learned_pos_emb 74 | self.attn_config = attn_config 75 | self.init_device = init_device 76 | self.logit_scale = logit_scale 77 | self.no_bias = no_bias 78 | self.verbose = verbose 79 | self.embedding_fraction = embedding_fraction 80 | self.norm_type = norm_type 81 | self.use_cache = use_cache 82 | self.init_config = init_config 83 | if 'name' in kwargs: 84 | del kwargs['name'] 85 | if 'loss_fn' in kwargs: 86 | del kwargs['loss_fn'] 87 | super().__init__(**kwargs) 88 | self._validate_config() 89 | 90 | def _set_config_defaults(self, config, config_defaults): 91 | for (k, v) in config_defaults.items(): 92 | if k not in config: 93 | config[k] = v 94 | return config 95 | 96 | def _validate_config(self): 97 | self.attn_config = self._set_config_defaults(self.attn_config, attn_config_defaults) 98 | self.init_config = self._set_config_defaults(self.init_config, init_config_defaults) 99 | if self.d_model % self.n_heads != 0: 100 | raise ValueError('d_model must be divisible by n_heads') 101 | if any((prob < 0 or prob > 1 for prob in [self.attn_config['attn_pdrop'], self.resid_pdrop, self.emb_pdrop])): 102 | raise ValueError("self.attn_config['attn_pdrop'], resid_pdrop, emb_pdrop are probabilities and must be between 0 and 1") 103 | if self.attn_config['attn_impl'] not in ['torch', 'flash', 'triton']: 104 | raise ValueError(f"Unknown attn_impl={self.attn_config['attn_impl']}") 105 | if self.attn_config['prefix_lm'] and self.attn_config['attn_impl'] not in ['torch', 'triton']: 106 | raise NotImplementedError('prefix_lm only implemented with torch and triton attention.') 107 | if self.attn_config['alibi'] and self.attn_config['attn_impl'] not in ['torch', 'triton']: 108 | raise NotImplementedError('alibi only implemented with torch and triton attention.') 109 | if self.attn_config['attn_uses_sequence_id'] and self.attn_config['attn_impl'] not in ['torch', 'triton']: 110 | raise NotImplementedError('attn_uses_sequence_id only implemented with torch and triton attention.') 111 | if self.embedding_fraction > 1 or self.embedding_fraction <= 0: 112 | raise ValueError('model.embedding_fraction must be between 0 (exclusive) and 1 (inclusive)!') 113 | if isinstance(self.logit_scale, str) and self.logit_scale != 'inv_sqrt_d_model': 114 | raise ValueError(f"self.logit_scale={self.logit_scale!r} is not recognized as an option; use numeric value or 'inv_sqrt_d_model'.") 115 | if self.init_config.get('name', None) is None: 116 | raise ValueError(f"self.init_config={self.init_config!r} 'name' needs to be set.") 117 | if not self.learned_pos_emb and (not self.attn_config['alibi']): 118 | raise ValueError(f'Positional information must be provided to the model using either learned_pos_emb or alibi.') -------------------------------------------------------------------------------- /llava/model/language_model/mpt/custom_embedding.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from torch import Tensor 5 | 6 | class SharedEmbedding(nn.Embedding): 7 | 8 | def forward(self, input: Tensor, unembed: 
bool=False) -> Tensor: 9 | if unembed: 10 | return F.linear(input, self.weight) 11 | return super().forward(input) -------------------------------------------------------------------------------- /llava/model/language_model/mpt/meta_init_context.py: -------------------------------------------------------------------------------- 1 | from contextlib import contextmanager 2 | import torch 3 | import torch.nn as nn 4 | 5 | @contextmanager 6 | def init_empty_weights(include_buffers: bool=False): 7 | """Meta initialization context manager. 8 | 9 | A context manager under which models are initialized with all parameters 10 | on the meta device, therefore creating an empty model. Useful when just 11 | initializing the model would blow the available RAM. 12 | 13 | Args: 14 | include_buffers (`bool`, *optional*, defaults to `False`): Whether or 15 | not to also put all buffers on the meta device while initializing. 16 | 17 | Example: 18 | ```python 19 | import torch.nn as nn 20 | 21 | # Initialize a model with 100 billions parameters in no time and without using any RAM. 22 | with init_empty_weights(): 23 | tst = nn.Sequential(*[nn.Linear(10000, 10000) for _ in range(1000)]) 24 | ``` 25 | 26 | 27 | 28 | Any model created under this context manager has no weights. As such you can't do something like 29 | `model.to(some_device)` with it. To load weights inside your empty model, see [`load_checkpoint_and_dispatch`]. 30 | 31 | 32 | """ 33 | with init_on_device(torch.device('meta'), include_buffers=include_buffers) as f: 34 | yield f 35 | 36 | @contextmanager 37 | def init_on_device(device: torch.device, include_buffers: bool=False): 38 | """Device initialization context manager. 39 | 40 | A context manager under which models are initialized with all parameters 41 | on the specified device. 42 | 43 | Args: 44 | device (`torch.device`): Device to initialize all parameters on. 45 | include_buffers (`bool`, *optional*, defaults to `False`): Whether or 46 | not to also put all buffers on the meta device while initializing. 
47 | 48 | Example: 49 | ```python 50 | import torch.nn as nn 51 | 52 | with init_on_device(device=torch.device("cuda")): 53 | tst = nn.Liner(100, 100) # on `cuda` device 54 | ``` 55 | """ 56 | old_register_parameter = nn.Module.register_parameter 57 | if include_buffers: 58 | old_register_buffer = nn.Module.register_buffer 59 | 60 | def register_empty_parameter(module, name, param): 61 | old_register_parameter(module, name, param) 62 | if param is not None: 63 | param_cls = type(module._parameters[name]) 64 | kwargs = module._parameters[name].__dict__ 65 | module._parameters[name] = param_cls(module._parameters[name].to(device), **kwargs) 66 | 67 | def register_empty_buffer(module, name, buffer): 68 | old_register_buffer(module, name, buffer) 69 | if buffer is not None: 70 | module._buffers[name] = module._buffers[name].to(device) 71 | if include_buffers: 72 | tensor_constructors_to_patch = {torch_function_name: getattr(torch, torch_function_name) for torch_function_name in ['empty', 'zeros', 'ones', 'full']} 73 | else: 74 | tensor_constructors_to_patch = {} 75 | 76 | def patch_tensor_constructor(fn): 77 | 78 | def wrapper(*args, **kwargs): 79 | kwargs['device'] = device 80 | return fn(*args, **kwargs) 81 | return wrapper 82 | try: 83 | nn.Module.register_parameter = register_empty_parameter 84 | if include_buffers: 85 | nn.Module.register_buffer = register_empty_buffer 86 | for torch_function_name in tensor_constructors_to_patch.keys(): 87 | setattr(torch, torch_function_name, patch_tensor_constructor(getattr(torch, torch_function_name))) 88 | yield 89 | finally: 90 | nn.Module.register_parameter = old_register_parameter 91 | if include_buffers: 92 | nn.Module.register_buffer = old_register_buffer 93 | for (torch_function_name, old_torch_function) in tensor_constructors_to_patch.items(): 94 | setattr(torch, torch_function_name, old_torch_function) -------------------------------------------------------------------------------- /llava/model/language_model/mpt/norm.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | def _cast_if_autocast_enabled(tensor): 4 | if torch.is_autocast_enabled(): 5 | if tensor.device.type == 'cuda': 6 | dtype = torch.get_autocast_gpu_dtype() 7 | elif tensor.device.type == 'cpu': 8 | dtype = torch.get_autocast_cpu_dtype() 9 | else: 10 | raise NotImplementedError() 11 | return tensor.to(dtype=dtype) 12 | return tensor 13 | 14 | class LPLayerNorm(torch.nn.LayerNorm): 15 | 16 | def __init__(self, normalized_shape, eps=1e-05, elementwise_affine=True, device=None, dtype=None): 17 | super().__init__(normalized_shape=normalized_shape, eps=eps, elementwise_affine=elementwise_affine, device=device, dtype=dtype) 18 | 19 | def forward(self, x): 20 | module_device = x.device 21 | downcast_x = _cast_if_autocast_enabled(x) 22 | downcast_weight = _cast_if_autocast_enabled(self.weight) if self.weight is not None else self.weight 23 | downcast_bias = _cast_if_autocast_enabled(self.bias) if self.bias is not None else self.bias 24 | with torch.autocast(enabled=False, device_type=module_device.type): 25 | return torch.nn.functional.layer_norm(downcast_x, self.normalized_shape, downcast_weight, downcast_bias, self.eps) 26 | 27 | def rms_norm(x, weight=None, eps=1e-05): 28 | output = x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + eps) 29 | if weight is not None: 30 | return output * weight 31 | return output 32 | 33 | class RMSNorm(torch.nn.Module): 34 | 35 | def __init__(self, normalized_shape, eps=1e-05, weight=True, 
dtype=None, device=None): 36 | super().__init__() 37 | self.eps = eps 38 | if weight: 39 | self.weight = torch.nn.Parameter(torch.ones(normalized_shape, dtype=dtype, device=device)) 40 | else: 41 | self.register_parameter('weight', None) 42 | 43 | def forward(self, x): 44 | return rms_norm(x.float(), self.weight, self.eps).to(dtype=x.dtype) 45 | 46 | class LPRMSNorm(RMSNorm): 47 | 48 | def __init__(self, normalized_shape, eps=1e-05, weight=True, dtype=None, device=None): 49 | super().__init__(normalized_shape=normalized_shape, eps=eps, weight=weight, dtype=dtype, device=device) 50 | 51 | def forward(self, x): 52 | downcast_x = _cast_if_autocast_enabled(x) 53 | downcast_weight = _cast_if_autocast_enabled(self.weight) if self.weight is not None else self.weight 54 | with torch.autocast(enabled=False, device_type=x.device.type): 55 | return rms_norm(downcast_x, downcast_weight, self.eps).to(dtype=x.dtype) 56 | NORM_CLASS_REGISTRY = {'layernorm': torch.nn.LayerNorm, 'low_precision_layernorm': LPLayerNorm, 'rmsnorm': RMSNorm, 'low_precision_rmsnorm': LPRMSNorm} -------------------------------------------------------------------------------- /llava/model/make_delta.py: -------------------------------------------------------------------------------- 1 | """ 2 | Usage: 3 | python3 -m llava.model.make_delta --base ~/model_weights/llama-7b --target ~/model_weights/llava-7b --delta ~/model_weights/llava-7b-delta --hub-repo-id liuhaotian/llava-7b-delta 4 | """ 5 | import argparse 6 | 7 | import torch 8 | from tqdm import tqdm 9 | from transformers import AutoTokenizer, AutoModelForCausalLM 10 | from llava.model.utils import auto_upgrade 11 | 12 | 13 | def make_delta(base_model_path, target_model_path, delta_path, hub_repo_id): 14 | print("Loading base model") 15 | base = AutoModelForCausalLM.from_pretrained( 16 | base_model_path, torch_dtype=torch.float16, low_cpu_mem_usage=True) 17 | 18 | print("Loading target model") 19 | auto_upgrade(target_model_path) 20 | target = AutoModelForCausalLM.from_pretrained(target_model_path, torch_dtype=torch.float16, low_cpu_mem_usage=True) 21 | 22 | print("Calculating delta") 23 | for name, param in tqdm(target.state_dict().items(), desc="Calculating delta"): 24 | if name not in base.state_dict(): 25 | assert name in ['model.mm_projector.weight', 'model.mm_projector.bias'], f'{name} not in base model' 26 | continue 27 | if param.data.shape == base.state_dict()[name].shape: 28 | param.data -= base.state_dict()[name] 29 | else: 30 | assert name in ['model.embed_tokens.weight', 'lm_head.weight'], f'{name} dimension mismatch: {param.data.shape} vs {base.state_dict()[name].shape}' 31 | bparam = base.state_dict()[name] 32 | param.data[:bparam.shape[0], :bparam.shape[1]] -= bparam 33 | 34 | print("Saving delta") 35 | if hub_repo_id: 36 | kwargs = {"push_to_hub": True, "repo_id": hub_repo_id} 37 | else: 38 | kwargs = {} 39 | target.save_pretrained(delta_path, **kwargs) 40 | target_tokenizer = AutoTokenizer.from_pretrained(target_model_path) 41 | target_tokenizer.save_pretrained(delta_path, **kwargs) 42 | 43 | 44 | if __name__ == "__main__": 45 | parser = argparse.ArgumentParser() 46 | parser.add_argument("--base-model-path", type=str, required=True) 47 | parser.add_argument("--target-model-path", type=str, required=True) 48 | parser.add_argument("--delta-path", type=str, required=True) 49 | parser.add_argument("--hub-repo-id", type=str, default=None) 50 | args = parser.parse_args() 51 | 52 | make_delta(args.base_model_path, args.target_model_path, args.delta_path, 
args.hub_repo_id) 53 | -------------------------------------------------------------------------------- /llava/model/multimodal_encoder/builder.py: -------------------------------------------------------------------------------- 1 | from .clip_encoder import CLIPVisionTower 2 | 3 | 4 | def build_vision_tower(vision_tower_cfg, **kwargs): 5 | vision_tower = getattr(vision_tower_cfg, 'mm_vision_tower', getattr(vision_tower_cfg, 'vision_tower', None)) 6 | if vision_tower.startswith("openai") or vision_tower.startswith("laion") or vision_tower.startswith("/group/30098/yuzhouhuang/X_Python_1/vilmedic-main/diffusion_priors-main/LLMSD_x1/LLaVA230730"): 7 | return CLIPVisionTower(vision_tower, args=vision_tower_cfg, **kwargs) 8 | 9 | raise ValueError(f'Unknown vision tower: {vision_tower}') 10 | -------------------------------------------------------------------------------- /llava/model/multimodal_encoder/clip_encoder.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | from transformers import CLIPVisionModel, CLIPImageProcessor, CLIPVisionConfig 5 | 6 | 7 | class CLIPVisionTower(nn.Module): 8 | def __init__(self, vision_tower, args, delay_load=False): 9 | super().__init__() 10 | 11 | self.is_loaded = False 12 | 13 | self.vision_tower_name = vision_tower 14 | self.select_layer = args.mm_vision_select_layer 15 | self.select_feature = getattr(args, 'mm_vision_select_feature', 'patch') 16 | 17 | if not delay_load: 18 | self.load_model() 19 | else: 20 | self.cfg_only = CLIPVisionConfig.from_pretrained(self.vision_tower_name) 21 | 22 | def load_model(self): 23 | self.image_processor = CLIPImageProcessor.from_pretrained(self.vision_tower_name) 24 | self.vision_tower = CLIPVisionModel.from_pretrained(self.vision_tower_name) 25 | self.vision_tower.requires_grad_(False) 26 | 27 | self.is_loaded = True 28 | 29 | def feature_select(self, image_forward_outs): 30 | image_features = image_forward_outs.hidden_states[self.select_layer] 31 | if self.select_feature == 'patch': 32 | image_features = image_features[:, 1:] 33 | elif self.select_feature == 'cls_patch': 34 | image_features = image_features 35 | else: 36 | raise ValueError(f'Unexpected select feature: {self.select_feature}') 37 | return image_features 38 | 39 | @torch.no_grad() 40 | def forward(self, images): 41 | if type(images) is list: 42 | image_features = [] 43 | for image in images: 44 | image_forward_out = self.vision_tower(image.to(device=self.device, dtype=self.dtype).unsqueeze(0), output_hidden_states=True) 45 | image_feature = self.feature_select(image_forward_out).to(image.dtype) 46 | image_features.append(image_feature) 47 | else: 48 | image_forward_outs = self.vision_tower(images.to(device=self.device, dtype=self.dtype), output_hidden_states=True) 49 | image_features = self.feature_select(image_forward_outs).to(images.dtype) 50 | 51 | return image_features 52 | 53 | @property 54 | def dummy_feature(self): 55 | return torch.zeros(1, self.hidden_size, device=self.device, dtype=self.dtype) 56 | 57 | @property 58 | def dtype(self): 59 | return self.vision_tower.dtype 60 | 61 | @property 62 | def device(self): 63 | return self.vision_tower.device 64 | 65 | @property 66 | def config(self): 67 | if self.is_loaded: 68 | return self.vision_tower.config 69 | else: 70 | return self.cfg_only 71 | 72 | @property 73 | def hidden_size(self): 74 | return self.config.hidden_size 75 | 76 | @property 77 | def num_patches(self): 78 | return (self.config.image_size // 
self.config.patch_size) ** 2 79 | -------------------------------------------------------------------------------- /llava/model/utils.py: -------------------------------------------------------------------------------- 1 | from transformers import AutoConfig 2 | 3 | 4 | def auto_upgrade(config): 5 | cfg = AutoConfig.from_pretrained(config) 6 | if 'llava' in config and 'llava' not in cfg.model_type: 7 | assert cfg.model_type == 'llama' 8 | print("You are using newer LLaVA code base, while the checkpoint of v0 is from older code base.") 9 | print("You must upgrade the checkpoint to the new code base (this can be done automatically).") 10 | confirm = input("Please confirm that you want to upgrade the checkpoint. [Y/N]") 11 | if confirm.lower() in ["y", "yes"]: 12 | print("Upgrading checkpoint...") 13 | assert len(cfg.architectures) == 1 14 | setattr(cfg.__class__, "model_type", "llava") 15 | cfg.architectures[0] = 'LlavaLlamaForCausalLM' 16 | cfg.save_pretrained(config) 17 | print("Checkpoint upgraded.") 18 | else: 19 | print("Checkpoint upgrade aborted.") 20 | exit(1) 21 | -------------------------------------------------------------------------------- /llava/serve/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/TencentARC/SmartEdit/a65f1262dfcba68c138ea95fe9936df1bd2c111d/llava/serve/__init__.py -------------------------------------------------------------------------------- /llava/serve/examples/extreme_ironing.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/TencentARC/SmartEdit/a65f1262dfcba68c138ea95fe9936df1bd2c111d/llava/serve/examples/extreme_ironing.jpg -------------------------------------------------------------------------------- /llava/serve/examples/waterview.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/TencentARC/SmartEdit/a65f1262dfcba68c138ea95fe9936df1bd2c111d/llava/serve/examples/waterview.jpg -------------------------------------------------------------------------------- /llava/serve/register_worker.py: -------------------------------------------------------------------------------- 1 | """ 2 | Manually register workers. 
3 | 4 | Usage: 5 | python3 -m fastchat.serve.register_worker --controller http://localhost:21001 --worker-name http://localhost:21002 6 | """ 7 | 8 | import argparse 9 | 10 | import requests 11 | 12 | if __name__ == "__main__": 13 | parser = argparse.ArgumentParser() 14 | parser.add_argument("--controller-address", type=str) 15 | parser.add_argument("--worker-name", type=str) 16 | parser.add_argument("--check-heart-beat", action="store_true") 17 | args = parser.parse_args() 18 | 19 | url = args.controller_address + "/register_worker" 20 | data = { 21 | "worker_name": args.worker_name, 22 | "check_heart_beat": args.check_heart_beat, 23 | "worker_status": None, 24 | } 25 | r = requests.post(url, json=data) 26 | assert r.status_code == 200 27 | -------------------------------------------------------------------------------- /llava/serve/test_message.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import json 3 | 4 | import requests 5 | 6 | from llava.conversation import default_conversation 7 | 8 | 9 | def main(): 10 | if args.worker_address: 11 | worker_addr = args.worker_address 12 | else: 13 | controller_addr = args.controller_address 14 | ret = requests.post(controller_addr + "/refresh_all_workers") 15 | ret = requests.post(controller_addr + "/list_models") 16 | models = ret.json()["models"] 17 | models.sort() 18 | print(f"Models: {models}") 19 | 20 | ret = requests.post(controller_addr + "/get_worker_address", 21 | json={"model": args.model_name}) 22 | worker_addr = ret.json()["address"] 23 | print(f"worker_addr: {worker_addr}") 24 | 25 | if worker_addr == "": 26 | return 27 | 28 | conv = default_conversation.copy() 29 | conv.append_message(conv.roles[0], args.message) 30 | prompt = conv.get_prompt() 31 | 32 | headers = {"User-Agent": "LLaVA Client"} 33 | pload = { 34 | "model": args.model_name, 35 | "prompt": prompt, 36 | "max_new_tokens": args.max_new_tokens, 37 | "temperature": 0.7, 38 | "stop": conv.sep, 39 | } 40 | response = requests.post(worker_addr + "/worker_generate_stream", headers=headers, 41 | json=pload, stream=True) 42 | 43 | print(prompt.replace(conv.sep, "\n"), end="") 44 | for chunk in response.iter_lines(chunk_size=8192, decode_unicode=False, delimiter=b"\0"): 45 | if chunk: 46 | data = json.loads(chunk.decode("utf-8")) 47 | output = data["text"].split(conv.sep)[-1] 48 | print(output, end="\r") 49 | print("") 50 | 51 | 52 | if __name__ == "__main__": 53 | parser = argparse.ArgumentParser() 54 | parser.add_argument("--controller-address", type=str, default="http://localhost:21001") 55 | parser.add_argument("--worker-address", type=str) 56 | parser.add_argument("--model-name", type=str, default="facebook/opt-350m") 57 | parser.add_argument("--max-new-tokens", type=int, default=32) 58 | parser.add_argument("--message", type=str, default= 59 | "Tell me a story with more than 1000 words.") 60 | args = parser.parse_args() 61 | 62 | main() 63 | -------------------------------------------------------------------------------- /llava/train/llama_flash_attn_monkey_patch.py: -------------------------------------------------------------------------------- 1 | from typing import List, Optional, Tuple 2 | import logging 3 | 4 | import torch 5 | from torch import nn 6 | 7 | import transformers 8 | from transformers.models.llama.modeling_llama import apply_rotary_pos_emb 9 | 10 | from einops import rearrange 11 | 12 | try: 13 | from flash_attn.flash_attn_interface import flash_attn_unpadded_qkvpacked_func 14 | except 
ImportError: 15 | from flash_attn.flash_attn_interface import flash_attn_varlen_qkvpacked_func as flash_attn_unpadded_qkvpacked_func 16 | from flash_attn.bert_padding import unpad_input, pad_input 17 | 18 | 19 | def forward( 20 | self, 21 | hidden_states: torch.Tensor, 22 | attention_mask: Optional[torch.Tensor] = None, 23 | position_ids: Optional[torch.Tensor] = None, 24 | past_key_value: Optional[Tuple[torch.Tensor]] = None, 25 | output_attentions: bool = False, 26 | use_cache: bool = False, 27 | ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: 28 | """Input shape: Batch x Time x Channel 29 | 30 | attention_mask: [bsz, q_len] 31 | """ 32 | bsz, q_len, _ = hidden_states.size() 33 | 34 | query_states = ( 35 | self.q_proj(hidden_states) 36 | .view(bsz, q_len, self.num_heads, self.head_dim) 37 | .transpose(1, 2) 38 | ) 39 | key_states = ( 40 | self.k_proj(hidden_states) 41 | .view(bsz, q_len, self.num_heads, self.head_dim) 42 | .transpose(1, 2) 43 | ) 44 | value_states = ( 45 | self.v_proj(hidden_states) 46 | .view(bsz, q_len, self.num_heads, self.head_dim) 47 | .transpose(1, 2) 48 | ) 49 | # [bsz, q_len, nh, hd] 50 | # [bsz, nh, q_len, hd] 51 | 52 | kv_seq_len = key_states.shape[-2] 53 | assert past_key_value is None, "past_key_value is not supported" 54 | 55 | cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len) 56 | query_states, key_states = apply_rotary_pos_emb( 57 | query_states, key_states, cos, sin, position_ids 58 | ) 59 | # [bsz, nh, t, hd] 60 | assert not output_attentions, "output_attentions is not supported" 61 | assert not use_cache, "use_cache is not supported" 62 | 63 | # Flash attention codes from 64 | # https://github.com/HazyResearch/flash-attention/blob/main/flash_attn/flash_attention.py 65 | 66 | # transform the data into the format required by flash attention 67 | qkv = torch.stack( 68 | [query_states, key_states, value_states], dim=2 69 | ) # [bsz, nh, 3, q_len, hd] 70 | qkv = qkv.transpose(1, 3) # [bsz, q_len, 3, nh, hd] 71 | # We have disabled _prepare_decoder_attention_mask in LlamaModel 72 | # the attention_mask should be the same as the key_padding_mask 73 | key_padding_mask = attention_mask 74 | 75 | if key_padding_mask is None: 76 | qkv = rearrange(qkv, "b s ... -> (b s) ...") 77 | max_s = q_len 78 | cu_q_lens = torch.arange( 79 | 0, (bsz + 1) * q_len, step=q_len, dtype=torch.int32, device=qkv.device 80 | ) 81 | output = flash_attn_unpadded_qkvpacked_func( 82 | qkv, cu_q_lens, max_s, 0.0, softmax_scale=None, causal=True 83 | ) 84 | output = rearrange(output, "(b s) ... 
-> b s ...", b=bsz) 85 | else: 86 | nheads = qkv.shape[-2] 87 | x = rearrange(qkv, "b s three h d -> b s (three h d)") 88 | x_unpad, indices, cu_q_lens, max_s = unpad_input(x, key_padding_mask) 89 | x_unpad = rearrange( 90 | x_unpad, "nnz (three h d) -> nnz three h d", three=3, h=nheads 91 | ) 92 | output_unpad = flash_attn_unpadded_qkvpacked_func( 93 | x_unpad, cu_q_lens, max_s, 0.0, softmax_scale=None, causal=True 94 | ) 95 | output = rearrange( 96 | pad_input( 97 | rearrange(output_unpad, "nnz h d -> nnz (h d)"), indices, bsz, q_len 98 | ), 99 | "b s (h d) -> b s h d", 100 | h=nheads, 101 | ) 102 | return self.o_proj(rearrange(output, "b s h d -> b s (h d)")), None, None 103 | 104 | 105 | # Disable the transformation of the attention mask in LlamaModel as the flash attention 106 | # requires the attention mask to be the same as the key_padding_mask 107 | def _prepare_decoder_attention_mask( 108 | self, attention_mask, input_shape, inputs_embeds, past_key_values_length 109 | ): 110 | # [bsz, seq_len] 111 | return attention_mask 112 | 113 | 114 | def replace_llama_attn_with_flash_attn(): 115 | cuda_major, cuda_minor = torch.cuda.get_device_capability() 116 | if cuda_major < 8: 117 | logging.warning( 118 | "Flash attention is only supported on A100 or H100 GPU during training due to head dim > 64 backward." 119 | "ref: https://github.com/HazyResearch/flash-attention/issues/190#issuecomment-1523359593" 120 | ) 121 | transformers.models.llama.modeling_llama.LlamaModel._prepare_decoder_attention_mask = ( 122 | _prepare_decoder_attention_mask 123 | ) 124 | transformers.models.llama.modeling_llama.LlamaAttention.forward = forward 125 | -------------------------------------------------------------------------------- /llava/train/llava_trainer.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | 4 | from transformers import Trainer 5 | from typing import Optional 6 | 7 | 8 | def maybe_zero_3(param, ignore_status=False, name=None): 9 | from deepspeed import zero 10 | from deepspeed.runtime.zero.partition_parameters import ZeroParamStatus 11 | if hasattr(param, "ds_id"): 12 | if param.ds_status == ZeroParamStatus.NOT_AVAILABLE: 13 | if not ignore_status: 14 | print(name, 'no ignore status') 15 | with zero.GatheredParameters([param]): 16 | param = param.data.detach().cpu().clone() 17 | else: 18 | param = param.detach().cpu().clone() 19 | return param 20 | 21 | 22 | def get_mm_adapter_state_maybe_zero_3(named_params, keys_to_match): 23 | to_return = {k: t for k, t in named_params if any(key_match in k for key_match in keys_to_match)} 24 | to_return = {k: maybe_zero_3(v, ignore_status=True, name=k).cpu() for k, v in to_return.items()} 25 | return to_return 26 | 27 | 28 | class LLaVATrainer(Trainer): 29 | 30 | def _save_checkpoint(self, model, trial, metrics=None): 31 | if getattr(self.args, 'tune_mm_mlp_adapter', False): 32 | from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR 33 | checkpoint_folder = f"{PREFIX_CHECKPOINT_DIR}-{self.state.global_step}" 34 | 35 | run_dir = self._get_output_dir(trial=trial) 36 | output_dir = os.path.join(run_dir, checkpoint_folder) 37 | 38 | # Only save Adapter 39 | keys_to_match = ['mm_projector'] 40 | if getattr(self.args, "use_im_start_end", False): 41 | keys_to_match.extend(['embed_tokens', 'embed_in']) 42 | 43 | weight_to_save = get_mm_adapter_state_maybe_zero_3(self.model.named_parameters(), keys_to_match) 44 | 45 | if self.args.local_rank == 0 or self.args.local_rank == -1: 46 | 
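# When tune_mm_mlp_adapter is set, only the multimodal projector weights (plus the token
# embeddings if use_im_start_end is enabled) are gathered and written to mm_projector.bin
# on rank 0, instead of a full model checkpoint; every other case falls through to the
# default HuggingFace Trainer checkpointing in the else-branch.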
self.model.config.save_pretrained(output_dir) 47 | torch.save(weight_to_save, os.path.join(output_dir, f'mm_projector.bin')) 48 | else: 49 | super(LLaVATrainer, self)._save_checkpoint(model, trial, metrics) 50 | 51 | def _save(self, output_dir: Optional[str] = None, state_dict=None): 52 | if getattr(self.args, 'tune_mm_mlp_adapter', False): 53 | pass 54 | else: 55 | super(LLaVATrainer, self)._save(output_dir, state_dict) 56 | -------------------------------------------------------------------------------- /llava/train/train_mem.py: -------------------------------------------------------------------------------- 1 | # Adopted from https://github.com/lm-sys/FastChat. Below is the original copyright: 2 | # Adopted from tatsu-lab@stanford_alpaca. Below is the original copyright: 3 | # Make it more memory efficient by monkey patching the LLaMA model with FlashAttn. 4 | 5 | # Need to call this before importing transformers. 6 | from llava.train.llama_flash_attn_monkey_patch import replace_llama_attn_with_flash_attn 7 | 8 | replace_llama_attn_with_flash_attn() 9 | 10 | from llava.train.train import train 11 | 12 | if __name__ == "__main__": 13 | train() 14 | -------------------------------------------------------------------------------- /llava/utils.py: -------------------------------------------------------------------------------- 1 | import datetime 2 | import logging 3 | import logging.handlers 4 | import os 5 | import sys 6 | 7 | import requests 8 | 9 | from llava.constants import LOGDIR 10 | 11 | server_error_msg = "**NETWORK ERROR DUE TO HIGH TRAFFIC. PLEASE REGENERATE OR REFRESH THIS PAGE.**" 12 | moderation_msg = "YOUR INPUT VIOLATES OUR CONTENT MODERATION GUIDELINES. PLEASE TRY AGAIN." 13 | 14 | handler = None 15 | 16 | 17 | def build_logger(logger_name, logger_filename): 18 | global handler 19 | 20 | formatter = logging.Formatter( 21 | fmt="%(asctime)s | %(levelname)s | %(name)s | %(message)s", 22 | datefmt="%Y-%m-%d %H:%M:%S", 23 | ) 24 | 25 | # Set the format of root handlers 26 | if not logging.getLogger().handlers: 27 | logging.basicConfig(level=logging.INFO) 28 | logging.getLogger().handlers[0].setFormatter(formatter) 29 | 30 | # Redirect stdout and stderr to loggers 31 | stdout_logger = logging.getLogger("stdout") 32 | stdout_logger.setLevel(logging.INFO) 33 | sl = StreamToLogger(stdout_logger, logging.INFO) 34 | sys.stdout = sl 35 | 36 | stderr_logger = logging.getLogger("stderr") 37 | stderr_logger.setLevel(logging.ERROR) 38 | sl = StreamToLogger(stderr_logger, logging.ERROR) 39 | sys.stderr = sl 40 | 41 | # Get logger 42 | logger = logging.getLogger(logger_name) 43 | logger.setLevel(logging.INFO) 44 | 45 | # Add a file handler for all loggers 46 | if handler is None: 47 | os.makedirs(LOGDIR, exist_ok=True) 48 | filename = os.path.join(LOGDIR, logger_filename) 49 | handler = logging.handlers.TimedRotatingFileHandler( 50 | filename, when='D', utc=True) 51 | handler.setFormatter(formatter) 52 | 53 | for name, item in logging.root.manager.loggerDict.items(): 54 | if isinstance(item, logging.Logger): 55 | item.addHandler(handler) 56 | 57 | return logger 58 | 59 | 60 | class StreamToLogger(object): 61 | """ 62 | Fake file-like stream object that redirects writes to a logger instance. 
63 | """ 64 | def __init__(self, logger, log_level=logging.INFO): 65 | self.terminal = sys.stdout 66 | self.logger = logger 67 | self.log_level = log_level 68 | self.linebuf = '' 69 | 70 | def __getattr__(self, attr): 71 | return getattr(self.terminal, attr) 72 | 73 | def write(self, buf): 74 | temp_linebuf = self.linebuf + buf 75 | self.linebuf = '' 76 | for line in temp_linebuf.splitlines(True): 77 | # From the io.TextIOWrapper docs: 78 | # On output, if newline is None, any '\n' characters written 79 | # are translated to the system default line separator. 80 | # By default sys.stdout.write() expects '\n' newlines and then 81 | # translates them so this is still cross platform. 82 | if line[-1] == '\n': 83 | self.logger.log(self.log_level, line.rstrip()) 84 | else: 85 | self.linebuf += line 86 | 87 | def flush(self): 88 | if self.linebuf != '': 89 | self.logger.log(self.log_level, self.linebuf.rstrip()) 90 | self.linebuf = '' 91 | 92 | 93 | def disable_torch_init(): 94 | """ 95 | Disable the redundant torch default initialization to accelerate model creation. 96 | """ 97 | import torch 98 | setattr(torch.nn.Linear, "reset_parameters", lambda self: None) 99 | setattr(torch.nn.LayerNorm, "reset_parameters", lambda self: None) 100 | 101 | 102 | def violates_moderation(text): 103 | """ 104 | Check whether the text violates OpenAI moderation API. 105 | """ 106 | url = "https://api.openai.com/v1/moderations" 107 | headers = {"Content-Type": "application/json", 108 | "Authorization": "Bearer " + os.environ["OPENAI_API_KEY"]} 109 | text = text.replace("\n", "") 110 | data = "{" + '"input": ' + f'"{text}"' + "}" 111 | data = data.encode("utf-8") 112 | try: 113 | ret = requests.post(url, headers=headers, data=data, timeout=5) 114 | flagged = ret.json()["results"][0]["flagged"] 115 | except requests.exceptions.RequestException as e: 116 | flagged = False 117 | except KeyError as e: 118 | flagged = False 119 | 120 | return flagged 121 | 122 | 123 | def pretty_print_semaphore(semaphore): 124 | if semaphore is None: 125 | return "None" 126 | return f"Semaphore(value={semaphore._value}, locked={semaphore.locked()})" 127 | -------------------------------------------------------------------------------- /model/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/TencentARC/SmartEdit/a65f1262dfcba68c138ea95fe9936df1bd2c111d/model/__init__.py -------------------------------------------------------------------------------- /process_HF.py: -------------------------------------------------------------------------------- 1 | """ 2 | # https://huggingface.co/datasets/timbrooks/instructpix2pix-clip-filtered 3 | # https://huggingface.co/datasets/osunlp/MagicBrush 4 | python process_HF.py 5 | """ 6 | 7 | # change the original dataset file format into .arrow file -> InstructPix2Pix + MagicBrush 8 | import pandas as pd 9 | from datasets import Dataset, concatenate_datasets, load_from_disk 10 | import glob 11 | import os 12 | 13 | # Define a generator function that loads Parquet files one by one and converts them into a dataset 14 | def parquet_to_dataset_generator(file_paths): 15 | index = 0 16 | for file_path in file_paths: 17 | print('Number:', index) 18 | df = pd.read_parquet(file_path) 19 | dataset = Dataset.from_pandas(df) 20 | index = index + 1 21 | yield dataset 22 | 23 | # InstructPix2Pix 24 | InstructPix2Pix_file_pattern = './Datasets/InstructPix2PixCLIPFiltered_HF/*.parquet' 25 | InstructPix2Pix_file_paths = 
glob.glob(InstructPix2Pix_file_pattern)
26 |
27 | # Load Parquet files one by one using a generator function and convert them into a dataset
28 | InstructPix2Pix_parquet_datasets = list(parquet_to_dataset_generator(InstructPix2Pix_file_paths))
29 |
30 | # Concatenate multiple datasets using the concatenate_datasets function
31 | InstructPix2Pix_merged_dataset = concatenate_datasets(InstructPix2Pix_parquet_datasets)
32 | print(InstructPix2Pix_merged_dataset)
33 | # Dataset({features: ['original_prompt', 'original_image', 'edit_prompt', 'edited_prompt', 'edited_image'], num_rows: 313010})
34 |
35 | # Save the dataset to disk using the save_to_disk method
36 | InstructPix2Pix_HF_path = './Datasets/InstructPix2PixCLIPFiltered_HF'
37 | InstructPix2Pix_merged_dataset.save_to_disk(InstructPix2Pix_HF_path)
38 |
39 | # load_from_disk
40 | InstructPix2Pix_HF_path = load_from_disk(InstructPix2Pix_HF_path)
41 | print(InstructPix2Pix_HF_path)
42 |
43 | # same for MagicBrush
44 | MagicBrush_file_pattern = './Datasets/MagicBrush_HF/train-*.parquet'
45 | MagicBrush_file_paths = glob.glob(MagicBrush_file_pattern)
46 | MagicBrush_parquet_datasets = list(parquet_to_dataset_generator(MagicBrush_file_paths))
47 | MagicBrush_merged_dataset = concatenate_datasets(MagicBrush_parquet_datasets)
48 | print(MagicBrush_merged_dataset)
49 |
50 | # save and reload the MagicBrush dataset
51 | MagicBrush_HF_path = './Datasets/MagicBrush_HF'
52 | MagicBrush_merged_dataset.save_to_disk(MagicBrush_HF_path)
53 | MagicBrush_HF_path = load_from_disk(MagicBrush_HF_path)
54 | # Dataset({features: ['img_id', 'turn_index', 'source_img', 'mask_img', 'instruction', 'target_img'], num_rows: 8807})
55 |
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | accelerate
2 | fastapi
3 | gradio==3.23
4 | httpx
5 | markdown2[all]
6 | nh3
7 | prompt_toolkit>=3.0.0
8 | pydantic
9 | requests
10 | rich>=10.0.0
11 | sentencepiece
12 | tokenizers>=0.12.1
13 | transformers==4.28.1
14 | uvicorn
15 | datasets
16 | apache_beam
17 | omegaconf
18 | pytorch-lightning==1.8.4.post0
19 | webdataset
20 | einops
21 | ninja
22 | diffusers==0.20.2
23 | numpy==1.23.4
--------------------------------------------------------------------------------
/scripts/MLLMSD_13b.sh:
--------------------------------------------------------------------------------
1 | """ bash scripts/MLLMSD_13b.sh """
2 |
3 | # train MLLM-13b + SD
4 | wandb disabled
5 | export WANDB_DISABLED=true
6 | deepspeed --include localhost:0,1,2,3,4,5,6,7 --master_addr 127.0.0.1 --master_port 28457 train/DS_MLLMSD11_train.py \
7 | --max_steps 5000 \
8 | --model_name_or_path ./checkpoints/vicuna-13b-v1-1 \
9 | --LLaVA_00001 "./checkpoints/LLaVA-13B-v1/pytorch_model-00001-of-00003.bin" \
10 | --LLaVA_00002 "./checkpoints/LLaVA-13B-v1/pytorch_model-00003-of-00003.bin" \
11 | --LLaVA_model_path "./checkpoints/LLaVA-13B-v1" \
12 | --sd_qformer_version "v1.1-13b" \
13 | --unet_ckpt "./checkpoints/InstructDiffusion_diffusers/unet/diffusion_pytorch_model.bin" \
14 | --bf16 True \
15 | --tf32 True \
16 | --output_dir ./checkpoints/stage2_MLLMSD_13b \
17 | --num_train_epochs 20 \
18 | --per_device_train_batch_size 4 \
19 | --per_device_eval_batch_size 4 \
20 | --gradient_accumulation_steps 4 \
21 | --evaluation_strategy 'no' \
22 | --save_strategy 'steps' \
23 | --save_steps 5000 \
24 | --save_total_limit 3 \
25 | --learning_rate 1e-5 \
26 | --lr_scheduler_type 'cosine' \
27 | --weight_decay 0.
\ 28 | --warmup_ratio 0.001 \ 29 | --logging_steps 1 \ 30 | --model_max_length 2048 \ 31 | --gradient_checkpointing True \ 32 | --dataloader_num_workers 16 \ 33 | --ddp_find_unused_parameters True \ 34 | --SD_QFormer_conversation_33tokens "./checkpoints/stage1_CC12M_alignment_13b/embeddings_qformer/checkpoint-150000_embeddings_qformer.bin" \ 35 | --InstructPix2PixDataset_path "./dataset/InstructPix2PixCLIPFiltered_HF" \ 36 | --MagicBrushDataset_path "./dataset/MagicBrush_HF" \ 37 | --LLaVADataset_data_path "./dataset/LLaVA/llava_instruct_150k.json" \ 38 | --LLaVADataset_image_folder "./dataset/coco/train2017" \ 39 | --refcoco_path "./dataset/refcoco" \ 40 | --grefcoco_path "./dataset/grefcoco" \ 41 | --coco_image_path "./dataset/coco" \ 42 | --COCOStuff_mask_path "./dataset/cocostuff" \ 43 | --ReasoningEditingDataset_path "./dataset/SyntheticData/SyntheticData_info_new.json" \ 44 | --ReasoningSegmentationDataset_json_path "./dataset/reason_seg/train" \ 45 | --ReasoningSegmentationDataset_image_path "./dataset/reason_seg/train" \ 46 | --ReasoningSegmentationDataset_binary_mask_path "./dataset/reason_seg/train_binary_mask" \ 47 | --deepspeed scripts/zero2_mixed.json \ 48 | -------------------------------------------------------------------------------- /scripts/MLLMSD_7b.sh: -------------------------------------------------------------------------------- 1 | """ bash scripts/MLLMSD_7b.sh """ 2 | 3 | # train MLLM-7b + SD 4 | wandb disabled 5 | export WANDB_DISABLED=true 6 | deepspeed --include localhost:0,1,2,3,4,5,6,7 --master_addr 127.0.0.1 --master_port 28457 train/DS_MLLMSD11_train.py \ 7 | --max_steps 5000 \ 8 | --model_name_or_path ./checkpoints/vicuna-7b-v1-1 \ 9 | --LLaVA_00001 "./checkpoints/LLaVA-7B-v1/pytorch_model-00001-of-00002.bin" \ 10 | --LLaVA_00002 "./checkpoints/LLaVA-7B-v1/pytorch_model-00002-of-00002.bin" \ 11 | --LLaVA_model_path "./checkpoints/LLaVA-7B-v1" \ 12 | --sd_qformer_version "v1.1-7b" \ 13 | --unet_ckpt "./checkpoints/InstructDiffusion_diffusers/unet/diffusion_pytorch_model.bin" \ 14 | --bf16 True \ 15 | --tf32 True \ 16 | --output_dir ./checkpoints/stage2_MLLMSD_7b \ 17 | --num_train_epochs 20 \ 18 | --per_device_train_batch_size 4 \ 19 | --per_device_eval_batch_size 4 \ 20 | --gradient_accumulation_steps 4 \ 21 | --evaluation_strategy 'no' \ 22 | --save_strategy 'steps' \ 23 | --save_steps 5000 \ 24 | --save_total_limit 3 \ 25 | --learning_rate 1e-5 \ 26 | --lr_scheduler_type 'cosine' \ 27 | --weight_decay 0. 
\ 28 | --warmup_ratio 0.001 \ 29 | --logging_steps 1 \ 30 | --model_max_length 2048 \ 31 | --gradient_checkpointing True \ 32 | --dataloader_num_workers 16 \ 33 | --ddp_find_unused_parameters True \ 34 | --SD_QFormer_conversation_33tokens "./checkpoints/stage1_CC12M_alignment_7b/embeddings_qformer/checkpoint-150000_embeddings_qformer.bin" \ 35 | --InstructPix2PixDataset_path "./dataset/InstructPix2PixCLIPFiltered_HF" \ 36 | --MagicBrushDataset_path "./dataset/MagicBrush_HF" \ 37 | --LLaVADataset_data_path "./dataset/LLaVA/llava_instruct_150k.json" \ 38 | --LLaVADataset_image_folder "./dataset/coco/train2017" \ 39 | --refcoco_path "./dataset/refcoco" \ 40 | --grefcoco_path "./dataset/grefcoco" \ 41 | --coco_image_path "./dataset/coco" \ 42 | --COCOStuff_mask_path "./dataset/cocostuff" \ 43 | --ReasoningEditingDataset_path "./dataset/SyntheticData/SyntheticData_info_new.json" \ 44 | --ReasoningSegmentationDataset_json_path "./dataset/reason_seg/train" \ 45 | --ReasoningSegmentationDataset_image_path "./dataset/reason_seg/train" \ 46 | --ReasoningSegmentationDataset_binary_mask_path "./dataset/reason_seg/train_binary_mask" \ 47 | --deepspeed scripts/zero2_mixed.json \ 48 | -------------------------------------------------------------------------------- /scripts/SmartEdit_13b.sh: -------------------------------------------------------------------------------- 1 | """ bash scripts/SmartEdit_13b.sh """ 2 | 3 | # train SmartEdit-13b 4 | wandb disabled 5 | export WANDB_DISABLED=true 6 | deepspeed --include localhost:0,1,2,3,4,5,6,7 --master_addr 127.0.0.1 --master_port 28458 train/DS_SmartEdit_train.py \ 7 | --max_steps 15000 \ 8 | --model_name_or_path ./checkpoints/vicuna-13b-v1-1 \ 9 | --LLaVA_00001 "./checkpoints/LLaVA-13B-v1/pytorch_model-00001-of-00003.bin" \ 10 | --LLaVA_00002 "./checkpoints/LLaVA-13B-v1/pytorch_model-00003-of-00003.bin" \ 11 | --LLaVA_model_path "./checkpoints/LLaVA-13B-v1" \ 12 | --sd_qformer_version "v1.1-13b" \ 13 | --pretrained_LLaMA "./checkpoints/stage2_MLLMSD_13b/LLM-5000/adapter_model.bin" \ 14 | --pretrained_model "./checkpoints/stage2_MLLMSD_13b/embeddings_qformer/checkpoint-5000_embeddings_qformer.bin" \ 15 | --pretrained_unet "./checkpoints/stage2_MLLMSD_13b/unet-5000/adapter_model.bin" \ 16 | --bf16 True \ 17 | --tf32 True \ 18 | --output_dir "./checkpoints/SmartEdit_13b_ckpt" \ 19 | --num_train_epochs 20 \ 20 | --per_device_train_batch_size 2 \ 21 | --per_device_eval_batch_size 4 \ 22 | --gradient_accumulation_steps 8 \ 23 | --evaluation_strategy 'no' \ 24 | --save_strategy 'steps' \ 25 | --save_steps 5000 \ 26 | --save_total_limit 10 \ 27 | --learning_rate 1e-5 \ 28 | --weight_decay 0. \ 29 | --warmup_ratio 0. 
\ 30 | --lr_scheduler_type 'cosine' \ 31 | --logging_steps 1 \ 32 | --model_max_length 2048 \ 33 | --gradient_checkpointing True \ 34 | --dataloader_num_workers 16 \ 35 | --ddp_find_unused_parameters True \ 36 | --InstructPix2PixDataset_path "./dataset/InstructPix2PixCLIPFiltered_HF" \ 37 | --MagicBrushDataset_path "./dataset/MagicBrush_HF" \ 38 | --LLaVADataset_data_path "./dataset/LLaVA/llava_instruct_150k.json" \ 39 | --LLaVADataset_image_folder "./dataset/coco/train2017" \ 40 | --refcoco_path "./dataset/refcoco" \ 41 | --grefcoco_path "./dataset/grefcoco" \ 42 | --coco_image_path "./dataset/coco" \ 43 | --COCOStuff_mask_path "./dataset/cocostuff" \ 44 | --ReasoningEditingDataset_path "./dataset/SyntheticData/SyntheticData_info_new.json" \ 45 | --ReasoningSegmentationDataset_json_path "./dataset/reason_seg/train" \ 46 | --ReasoningSegmentationDataset_image_path "./dataset/reason_seg/train" \ 47 | --ReasoningSegmentationDataset_binary_mask_path "./dataset/reason_seg/train_binary_mask" \ 48 | --deepspeed scripts/zero2_mixed.json \ 49 | -------------------------------------------------------------------------------- /scripts/SmartEdit_7b.sh: -------------------------------------------------------------------------------- 1 | """ bash scripts/SmartEdit_7b.sh """ 2 | 3 | # train SmartEdit-7b 4 | wandb disabled 5 | export WANDB_DISABLED=true 6 | deepspeed --include localhost:0,1,2,3,4,5,6,7 --master_addr 127.0.0.1 --master_port 28458 train/DS_SmartEdit_train.py \ 7 | --max_steps 15000 \ 8 | --model_name_or_path ./checkpoints/vicuna-7b-v1-1 \ 9 | --LLaVA_00001 "./checkpoints/LLaVA-7B-v1/pytorch_model-00001-of-00002.bin" \ 10 | --LLaVA_00002 "./checkpoints/LLaVA-7B-v1/pytorch_model-00002-of-00002.bin" \ 11 | --LLaVA_model_path "./checkpoints/LLaVA-7B-v1" \ 12 | --sd_qformer_version "v1.1-7b" \ 13 | --pretrained_LLaMA "./checkpoints/stage2_MLLMSD_7b/LLM-5000/adapter_model.bin" \ 14 | --pretrained_model "./checkpoints/stage2_MLLMSD_7b/embeddings_qformer/checkpoint-5000_embeddings_qformer.bin" \ 15 | --pretrained_unet "./checkpoints/stage2_MLLMSD_7b/unet-5000/adapter_model.bin" \ 16 | --bf16 True \ 17 | --tf32 True \ 18 | --output_dir "./checkpoints/SmartEdit_7b_ckpt" \ 19 | --num_train_epochs 20 \ 20 | --per_device_train_batch_size 2 \ 21 | --per_device_eval_batch_size 4 \ 22 | --gradient_accumulation_steps 8 \ 23 | --evaluation_strategy 'no' \ 24 | --save_strategy 'steps' \ 25 | --save_steps 5000 \ 26 | --save_total_limit 10 \ 27 | --learning_rate 1e-5 \ 28 | --weight_decay 0. \ 29 | --warmup_ratio 0. 
\
30 | --lr_scheduler_type 'cosine' \
31 | --logging_steps 1 \
32 | --model_max_length 2048 \
33 | --gradient_checkpointing True \
34 | --dataloader_num_workers 16 \
35 | --ddp_find_unused_parameters True \
36 | --InstructPix2PixDataset_path "./dataset/InstructPix2PixCLIPFiltered_HF" \
37 | --MagicBrushDataset_path "./dataset/MagicBrush_HF" \
38 | --LLaVADataset_data_path "./dataset/LLaVA/llava_instruct_150k.json" \
39 | --LLaVADataset_image_folder "./dataset/coco/train2017" \
40 | --refcoco_path "./dataset/refcoco" \
41 | --grefcoco_path "./dataset/grefcoco" \
42 | --coco_image_path "./dataset/coco" \
43 | --COCOStuff_mask_path "./dataset/cocostuff" \
44 | --ReasoningEditingDataset_path "./dataset/SyntheticData/SyntheticData_info_new.json" \
45 | --ReasoningSegmentationDataset_json_path "./dataset/reason_seg/train" \
46 | --ReasoningSegmentationDataset_image_path "./dataset/reason_seg/train" \
47 | --ReasoningSegmentationDataset_binary_mask_path "./dataset/reason_seg/train_binary_mask" \
48 | --deepspeed scripts/zero2_mixed.json \
49 |
--------------------------------------------------------------------------------
/scripts/TrainStage1_13b.sh:
--------------------------------------------------------------------------------
1 | """
2 | wget https://huggingface.co/lmsys/vicuna-13b-v1.1/resolve/main/*
3 | bash scripts/TrainStage1_13b.sh
4 | """
5 |
6 | # CC12M + llava1.1-13b
7 | torchrun --nproc_per_node=8 --master_port=20001 train/TrainStage1.py \
8 | --max_steps 150000 \
9 | --model_name_or_path ./checkpoints/vicuna-13b-v1-1 \
10 | --LLaVA_model_path_v1_1_13b ./checkpoints/LLaVA-13B-v1 \
11 | --data_path ./dataset/cc12m.tsv \
12 | --template_data_path ./data/conv_template_cap_to_img.txt \
13 | --bf16 True \
14 | --output_dir ./checkpoints/stage1_CC12M_alignment_13b \
15 | --num_new_tokens 32 \
16 | --num_train_epochs 10 \
17 | --per_device_train_batch_size 16 \
18 | --per_device_eval_batch_size 4 \
19 | --gradient_accumulation_steps 1 \
20 | --evaluation_strategy "no" \
21 | --save_strategy "steps" \
22 | --save_steps 10000 \
23 | --save_total_limit 3 \
24 | --learning_rate 2e-4 \
25 | --weight_decay 0. \
26 | --warmup_ratio 0.04 \
27 | --lr_scheduler_type "cosine" \
28 | --logging_steps 1 \
29 | --tf32 True \
30 | --model_max_length 256 \
31 | --gradient_checkpointing True \
32 | --dataloader_num_workers 8 \
33 | --lazy_preprocess True \
34 | --LLaVA_version "v1.1-13b" \
35 | --report_to "none"
36 |
--------------------------------------------------------------------------------
/scripts/TrainStage1_7b.sh:
--------------------------------------------------------------------------------
1 | """
2 | wget https://huggingface.co/lmsys/vicuna-7b-v1.1/resolve/main/*
3 | bash scripts/TrainStage1_7b.sh
4 | """
5 |
6 | # CC12M + llava1.1-7b
7 | torchrun --nproc_per_node=8 --master_port=20001 train/TrainStage1.py \
8 | --max_steps 150000 \
9 | --model_name_or_path ./checkpoints/vicuna-7b-v1-1 \
10 | --LLaVA_model_v1_1_7b_path ./checkpoints/LLaVA-7B-v1 \
11 | --data_path ./dataset/cc12m.tsv \
12 | --template_data_path ./data/conv_template_cap_to_img.txt \
13 | --bf16 True \
14 | --output_dir ./checkpoints/stage1_CC12M_alignment_7b \
15 | --num_new_tokens 32 \
16 | --num_train_epochs 10 \
17 | --per_device_train_batch_size 16 \
18 | --per_device_eval_batch_size 4 \
19 | --gradient_accumulation_steps 1 \
20 | --evaluation_strategy "no" \
21 | --save_strategy "steps" \
22 | --save_steps 10000 \
23 | --save_total_limit 3 \
24 | --learning_rate 2e-4 \
25 | --weight_decay 0.
\ 26 | --warmup_ratio 0.04 \ 27 | --lr_scheduler_type "cosine" \ 28 | --logging_steps 1 \ 29 | --tf32 True \ 30 | --model_max_length 256 \ 31 | --gradient_checkpointing True \ 32 | --dataloader_num_workers 8 \ 33 | --lazy_preprocess True \ 34 | --LLaVA_version "v1.1-7b" \ 35 | --report_to "none" 36 | -------------------------------------------------------------------------------- /scripts/zero2_mixed.json: -------------------------------------------------------------------------------- 1 | { 2 | "fp16": { 3 | "enabled": "auto", 4 | "loss_scale": 0, 5 | "loss_scale_window": 1000, 6 | "initial_scale_power": 16, 7 | "hysteresis": 2, 8 | "min_loss_scale": 1 9 | }, 10 | 11 | "optimizer": { 12 | "type": "AdamW", 13 | "params": { 14 | "lr": "auto", 15 | "betas": "auto", 16 | "eps": "auto", 17 | "weight_decay": "auto" 18 | } 19 | }, 20 | 21 | "scheduler": { 22 | "type": "WarmupLR", 23 | "params": { 24 | "warmup_min_lr": "auto", 25 | "warmup_max_lr": "auto", 26 | "warmup_num_steps": "auto" 27 | } 28 | }, 29 | 30 | "zero_optimization": { 31 | "stage": 2, 32 | "allgather_partitions": true, 33 | "allgather_bucket_size": 2e8, 34 | "overlap_comm": true, 35 | "reduce_scatter": true, 36 | "reduce_bucket_size": 2e8, 37 | "contiguous_gradients": true 38 | }, 39 | 40 | "gradient_accumulation_steps": "auto", 41 | "gradient_clipping": "auto", 42 | "steps_per_print": 2000, 43 | "train_batch_size": "auto", 44 | "train_micro_batch_size_per_gpu": "auto", 45 | "wall_clock_breakdown": false 46 | } -------------------------------------------------------------------------------- /scripts/zero2_offload_mixed.json: -------------------------------------------------------------------------------- 1 | { 2 | "fp16": { 3 | "enabled": "auto", 4 | "loss_scale": 0, 5 | "loss_scale_window": 1000, 6 | "initial_scale_power": 16, 7 | "hysteresis": 2, 8 | "min_loss_scale": 1 9 | }, 10 | 11 | "optimizer": { 12 | "type": "AdamW", 13 | "params": { 14 | "lr": "auto", 15 | "betas": "auto", 16 | "eps": "auto", 17 | "weight_decay": "auto" 18 | } 19 | }, 20 | 21 | "scheduler": { 22 | "type": "WarmupLR", 23 | "params": { 24 | "warmup_min_lr": "auto", 25 | "warmup_max_lr": "auto", 26 | "warmup_num_steps": "auto" 27 | } 28 | }, 29 | 30 | "zero_optimization": { 31 | "stage": 2, 32 | "offload_optimizer": { 33 | "device": "cpu", 34 | "pin_memory": true 35 | }, 36 | "allgather_partitions": true, 37 | "allgather_bucket_size": 2e8, 38 | "overlap_comm": true, 39 | "reduce_scatter": true, 40 | "reduce_bucket_size": 2e8, 41 | "contiguous_gradients": true 42 | }, 43 | 44 | "gradient_accumulation_steps": "auto", 45 | "gradient_clipping": "auto", 46 | "steps_per_print": 2000, 47 | "train_batch_size": "auto", 48 | "train_micro_batch_size_per_gpu": "auto", 49 | "wall_clock_breakdown": false 50 | } -------------------------------------------------------------------------------- /train/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/TencentARC/SmartEdit/a65f1262dfcba68c138ea95fe9936df1bd2c111d/train/__init__.py -------------------------------------------------------------------------------- /train/llama_flash_attn_monkey_patch.py: -------------------------------------------------------------------------------- 1 | from typing import List, Optional, Tuple 2 | 3 | import torch 4 | from torch import nn 5 | 6 | import transformers 7 | from transformers.models.llama.modeling_llama import apply_rotary_pos_emb 8 | 9 | from einops import rearrange 10 | 11 | from 
flash_attn.flash_attn_interface import flash_attn_unpadded_qkvpacked_func 12 | from flash_attn.bert_padding import unpad_input, pad_input 13 | 14 | 15 | def forward( 16 | self, 17 | hidden_states: torch.Tensor, 18 | attention_mask: Optional[torch.Tensor] = None, 19 | position_ids: Optional[torch.Tensor] = None, 20 | past_key_value: Optional[Tuple[torch.Tensor]] = None, 21 | output_attentions: bool = False, 22 | use_cache: bool = False, 23 | ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: 24 | """Input shape: Batch x Time x Channel 25 | 26 | attention_mask: [bsz, q_len] 27 | """ 28 | bsz, q_len, _ = hidden_states.size() 29 | 30 | query_states = ( 31 | self.q_proj(hidden_states) 32 | .view(bsz, q_len, self.num_heads, self.head_dim) 33 | .transpose(1, 2) 34 | ) 35 | key_states = ( 36 | self.k_proj(hidden_states) 37 | .view(bsz, q_len, self.num_heads, self.head_dim) 38 | .transpose(1, 2) 39 | ) 40 | value_states = ( 41 | self.v_proj(hidden_states) 42 | .view(bsz, q_len, self.num_heads, self.head_dim) 43 | .transpose(1, 2) 44 | ) 45 | # [bsz, q_len, nh, hd] 46 | # [bsz, nh, q_len, hd] 47 | 48 | kv_seq_len = key_states.shape[-2] 49 | assert past_key_value is None, "past_key_value is not supported" 50 | 51 | cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len) 52 | query_states, key_states = apply_rotary_pos_emb( 53 | query_states, key_states, cos, sin, position_ids 54 | ) 55 | # [bsz, nh, t, hd] 56 | assert not output_attentions, "output_attentions is not supported" 57 | assert not use_cache, "use_cache is not supported" 58 | 59 | # Flash attention codes from 60 | # https://github.com/HazyResearch/flash-attention/blob/main/flash_attn/flash_attention.py 61 | 62 | # transform the data into the format required by flash attention 63 | qkv = torch.stack( 64 | [query_states, key_states, value_states], dim=2 65 | ) # [bsz, nh, 3, q_len, hd] 66 | qkv = qkv.transpose(1, 3) # [bsz, q_len, 3, nh, hd] 67 | # We have disabled _prepare_decoder_attention_mask in LlamaModel 68 | # the attention_mask should be the same as the key_padding_mask 69 | key_padding_mask = attention_mask 70 | 71 | if key_padding_mask is None: 72 | qkv = rearrange(qkv, "b s ... -> (b s) ...") 73 | max_s = q_len 74 | cu_q_lens = torch.arange( 75 | 0, (bsz + 1) * q_len, step=q_len, dtype=torch.int32, device=qkv.device 76 | ) 77 | output = flash_attn_unpadded_qkvpacked_func( 78 | qkv, cu_q_lens, max_s, 0.0, softmax_scale=None, causal=True 79 | ) 80 | output = rearrange(output, "(b s) ... 
-> b s ...", b=bsz) 81 | else: 82 | nheads = qkv.shape[-2] 83 | x = rearrange(qkv, "b s three h d -> b s (three h d)") 84 | x_unpad, indices, cu_q_lens, max_s = unpad_input(x, key_padding_mask) 85 | x_unpad = rearrange( 86 | x_unpad, "nnz (three h d) -> nnz three h d", three=3, h=nheads 87 | ) 88 | output_unpad = flash_attn_unpadded_qkvpacked_func( 89 | x_unpad, cu_q_lens, max_s, 0.0, softmax_scale=None, causal=True 90 | ) 91 | output = rearrange( 92 | pad_input( 93 | rearrange(output_unpad, "nnz h d -> nnz (h d)"), indices, bsz, q_len 94 | ), 95 | "b s (h d) -> b s h d", 96 | h=nheads, 97 | ) 98 | return self.o_proj(rearrange(output, "b s h d -> b s (h d)")), None, None 99 | 100 | 101 | # Disable the transformation of the attention mask in LlamaModel as the flash attention 102 | # requires the attention mask to be the same as the key_padding_mask 103 | def _prepare_decoder_attention_mask( 104 | self, attention_mask, input_shape, inputs_embeds, past_key_values_length 105 | ): 106 | # [bsz, seq_len] 107 | return attention_mask 108 | 109 | 110 | def replace_llama_attn_with_flash_attn(): 111 | transformers.models.llama.modeling_llama.LlamaModel._prepare_decoder_attention_mask = ( 112 | _prepare_decoder_attention_mask 113 | ) 114 | transformers.models.llama.modeling_llama.LlamaAttention.forward = forward 115 | -------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | from asyncio import AbstractEventLoop 2 | import json 3 | import logging 4 | import logging.handlers 5 | import os 6 | import platform 7 | import sys 8 | from typing import AsyncGenerator, Generator 9 | import warnings 10 | 11 | import requests 12 | import torch 13 | 14 | from fastchat.constants import LOGDIR 15 | 16 | 17 | handler = None 18 | 19 | 20 | def build_logger(logger_name, logger_filename): 21 | global handler 22 | 23 | formatter = logging.Formatter( 24 | fmt="%(asctime)s | %(levelname)s | %(name)s | %(message)s", 25 | datefmt="%Y-%m-%d %H:%M:%S", 26 | ) 27 | 28 | # Set the format of root handlers 29 | if not logging.getLogger().handlers: 30 | if sys.version_info[1] >= 9: 31 | # This is for windows 32 | logging.basicConfig(level=logging.INFO, encoding="utf-8") 33 | else: 34 | if platform.system() == "Windows": 35 | warnings.warn( 36 | "If you are running on Windows, " 37 | "we recommend you use Python >= 3.9 for UTF-8 encoding." 
38 | ) 39 | logging.basicConfig(level=logging.INFO) 40 | logging.getLogger().handlers[0].setFormatter(formatter) 41 | 42 | # Redirect stdout and stderr to loggers 43 | stdout_logger = logging.getLogger("stdout") 44 | stdout_logger.setLevel(logging.INFO) 45 | sl = StreamToLogger(stdout_logger, logging.INFO) 46 | sys.stdout = sl 47 | 48 | stderr_logger = logging.getLogger("stderr") 49 | stderr_logger.setLevel(logging.ERROR) 50 | sl = StreamToLogger(stderr_logger, logging.ERROR) 51 | sys.stderr = sl 52 | 53 | # Get logger 54 | logger = logging.getLogger(logger_name) 55 | logger.setLevel(logging.INFO) 56 | 57 | # Add a file handler for all loggers 58 | if handler is None: 59 | os.makedirs(LOGDIR, exist_ok=True) 60 | filename = os.path.join(LOGDIR, logger_filename) 61 | handler = logging.handlers.TimedRotatingFileHandler( 62 | filename, when="D", utc=True, encoding="utf-8" 63 | ) 64 | handler.setFormatter(formatter) 65 | 66 | for name, item in logging.root.manager.loggerDict.items(): 67 | if isinstance(item, logging.Logger): 68 | item.addHandler(handler) 69 | 70 | return logger 71 | 72 | 73 | class StreamToLogger(object): 74 | """ 75 | Fake file-like stream object that redirects writes to a logger instance. 76 | """ 77 | 78 | def __init__(self, logger, log_level=logging.INFO): 79 | self.terminal = sys.stdout 80 | self.logger = logger 81 | self.log_level = log_level 82 | self.linebuf = "" 83 | 84 | def __getattr__(self, attr): 85 | return getattr(self.terminal, attr) 86 | 87 | def write(self, buf): 88 | temp_linebuf = self.linebuf + buf 89 | self.linebuf = "" 90 | for line in temp_linebuf.splitlines(True): 91 | # From the io.TextIOWrapper docs: 92 | # On output, if newline is None, any '\n' characters written 93 | # are translated to the system default line separator. 94 | # By default sys.stdout.write() expects '\n' newlines and then 95 | # translates them so this is still cross platform. 96 | if line[-1] == "\n": 97 | encoded_message = line.encode("utf-8", "ignore").decode("utf-8") 98 | self.logger.log(self.log_level, encoded_message.rstrip()) 99 | else: 100 | self.linebuf += line 101 | 102 | def flush(self): 103 | if self.linebuf != "": 104 | encoded_message = self.linebuf.encode("utf-8", "ignore").decode("utf-8") 105 | self.logger.log(self.log_level, encoded_message.rstrip()) 106 | self.linebuf = "" 107 | 108 | 109 | def disable_torch_init(): 110 | """ 111 | Disable the redundant torch default initialization to accelerate model creation. 112 | """ 113 | import torch 114 | 115 | setattr(torch.nn.Linear, "reset_parameters", lambda self: None) 116 | setattr(torch.nn.LayerNorm, "reset_parameters", lambda self: None) 117 | 118 | 119 | def get_gpu_memory(max_gpus=None): 120 | """Get available memory for each GPU.""" 121 | gpu_memory = [] 122 | num_gpus = ( 123 | torch.cuda.device_count() 124 | if max_gpus is None 125 | else min(max_gpus, torch.cuda.device_count()) 126 | ) 127 | 128 | for gpu_id in range(num_gpus): 129 | with torch.cuda.device(gpu_id): 130 | device = torch.cuda.current_device() 131 | gpu_properties = torch.cuda.get_device_properties(device) 132 | total_memory = gpu_properties.total_memory / (1024**3) 133 | allocated_memory = torch.cuda.memory_allocated() / (1024**3) 134 | available_memory = total_memory - allocated_memory 135 | gpu_memory.append(available_memory) 136 | return gpu_memory 137 | 138 | 139 | def violates_moderation(text): 140 | """ 141 | Check whether the text violates OpenAI moderation API. 
142 | """ 143 | url = "https://api.openai.com/v1/moderations" 144 | headers = { 145 | "Content-Type": "application/json", 146 | "Authorization": "Bearer " + os.environ["OPENAI_API_KEY"], 147 | } 148 | text = text.replace("\n", "") 149 | data = "{" + '"input": ' + f'"{text}"' + "}" 150 | data = data.encode("utf-8") 151 | try: 152 | ret = requests.post(url, headers=headers, data=data, timeout=5) 153 | flagged = ret.json()["results"][0]["flagged"] 154 | except requests.exceptions.RequestException as e: 155 | flagged = False 156 | except KeyError as e: 157 | flagged = False 158 | 159 | return flagged 160 | 161 | 162 | # Flan-t5 trained with HF+FSDP saves corrupted weights for shared embeddings, 163 | # Use this function to make sure it can be correctly loaded. 164 | def clean_flant5_ckpt(ckpt_path): 165 | index_file = os.path.join(ckpt_path, "pytorch_model.bin.index.json") 166 | index_json = json.load(open(index_file, "r")) 167 | 168 | weightmap = index_json["weight_map"] 169 | 170 | share_weight_file = weightmap["shared.weight"] 171 | share_weight = torch.load(os.path.join(ckpt_path, share_weight_file))[ 172 | "shared.weight" 173 | ] 174 | 175 | for weight_name in ["decoder.embed_tokens.weight", "encoder.embed_tokens.weight"]: 176 | weight_file = weightmap[weight_name] 177 | weight = torch.load(os.path.join(ckpt_path, weight_file)) 178 | weight[weight_name] = share_weight 179 | torch.save(weight, os.path.join(ckpt_path, weight_file)) 180 | 181 | 182 | def pretty_print_semaphore(semaphore): 183 | """Print a semaphore in better format.""" 184 | if semaphore is None: 185 | return "None" 186 | return f"Semaphore(value={semaphore._value}, locked={semaphore.locked()})" 187 | 188 | 189 | """A javascript function to get url parameters for the gradio web server.""" 190 | get_window_url_params_js = """ 191 | function() { 192 | const params = new URLSearchParams(window.location.search); 193 | url_params = Object.fromEntries(params); 194 | console.log("url_params", url_params); 195 | return url_params; 196 | } 197 | """ 198 | 199 | 200 | def iter_over_async( 201 | async_gen: AsyncGenerator, event_loop: AbstractEventLoop 202 | ) -> Generator: 203 | """ 204 | Convert async generator to sync generator 205 | 206 | :param async_gen: the AsyncGenerator to convert 207 | :param event_loop: the event loop to run on 208 | :returns: Sync generator 209 | """ 210 | ait = async_gen.__aiter__() 211 | 212 | async def get_next(): 213 | try: 214 | obj = await ait.__anext__() 215 | return False, obj 216 | except StopAsyncIteration: 217 | return True, None 218 | 219 | while True: 220 | done, obj = event_loop.run_until_complete(get_next()) 221 | if done: 222 | break 223 | yield obj 224 | 225 | 226 | def detect_language(text: str) -> str: 227 | """Detect the langauge of a string.""" 228 | import polyglot # pip3 install polyglot pyicu pycld2 229 | from polyglot.detect import Detector 230 | from polyglot.detect.base import logger as polyglot_logger 231 | import pycld2 232 | 233 | polyglot_logger.setLevel("ERROR") 234 | 235 | try: 236 | lang_code = Detector(text).language.name 237 | except (pycld2.error, polyglot.detect.base.UnknownLanguage): 238 | lang_code = "unknown" 239 | return lang_code 240 | --------------------------------------------------------------------------------