├── EVAL.md ├── LICENSE ├── README.md ├── cache_generate.py ├── cache_generate_qwen.py ├── convert_rouge_llava.py ├── convert_rouge_qwen.py ├── eval_generate.py ├── eval_latency.py ├── eval_ppl.py ├── eval_ppl_qwen.py ├── eval_rouge.py ├── eval_rouge_qwen.py ├── kv_cache.py ├── kv_cache_qwen.py ├── llava ├── __init__.py ├── constants.py ├── conversation.py ├── mm_utils.py ├── model │ ├── __init__.py │ ├── apply_delta.py │ ├── builder.py │ ├── consolidate.py │ ├── language_model │ │ ├── llava_llama.py │ │ ├── llava_mistral.py │ │ ├── llava_mpt.py │ │ └── mpt │ │ │ ├── adapt_tokenizer.py │ │ │ ├── attention.py │ │ │ ├── blocks.py │ │ │ ├── configuration_mpt.py │ │ │ ├── custom_embedding.py │ │ │ ├── flash_attn_triton.py │ │ │ ├── hf_prefixlm_converter.py │ │ │ ├── meta_init_context.py │ │ │ ├── modeling_mpt.py │ │ │ ├── norm.py │ │ │ └── param_init_fns.py │ ├── llava_arch.py │ ├── make_delta.py │ ├── multimodal_encoder │ │ ├── builder.py │ │ └── clip_encoder.py │ ├── multimodal_projector │ │ └── builder.py │ └── utils.py ├── train │ ├── llama_flash_attn_monkey_patch.py │ ├── llama_xformers_attn_monkey_patch.py │ ├── llava_trainer.py │ ├── train.py │ ├── train_mem.py │ └── train_xformers.py └── utils.py ├── qwen_generation_utils.py ├── requirements.txt └── run.sh /EVAL.md: -------------------------------------------------------------------------------- 1 | # Evaluation 2 | 3 | In Elastic Cache, we evaluate our method on three models (LLaVA-1.5-7B/13B and Qwen-VL-Chat) and two datasets (llava_detail_1k and MM-Vet), using two metrics: PPL and ROUGE. Results are saved to `./logs_<metric>_<model>/<exp-name>.txt`. 4 | 5 | ## LLaVA-7B 6 | 7 | ### MM-Vet 8 | 9 | To evaluate PPL, run the following script. `--method` should be one of `elastic`, `h2o`, or `local`. 10 | 11 | ```bash 12 | python3 eval_ppl.py \ 13 | --model-path ./models/llava-v1.5-7b \ 14 | --data-path ./playground/data/mm-vet/mm-vet.json \ 15 | --image-path ./playground/data/mm-vet/images \ 16 | --eval-samples 218 \ 17 | --method "elastic" \ 18 | --ratio 0.2 \ 19 | --exp-name "llava-7b-ppl-mmvet" 20 | ``` 21 | 22 | To evaluate ROUGE, run the following scripts. `--method` should be one of `elastic`, `h2o`, or `local`. 23 | 24 | ```bash 25 | # first generate reference texts with the full cache, then run the evaluation 26 | python3 convert_rouge_llava.py \ 27 | --model-path ./models/llava-v1.5-7b \ 28 | --data-path ./playground/data/mm-vet/mm-vet.json \ 29 | --image-path ./playground/data/mm-vet/images 30 | 31 | python3 eval_rouge.py \ 32 | --model-path ./models/llava-v1.5-7b \ 33 | --data-path ./playground/data/mm-vet/rouge-llava-v1.5-7b-mm-vet.json \ 34 | --image-path ./playground/data/mm-vet/images \ 35 | --eval-samples 218 \ 36 | --ratio 0.2 \ 37 | --exp-name "llava-7b-rouge-mmvet" 38 | ``` 39 | 40 | ## LLaVA-13B 41 | 42 | ### MM-Vet 43 | 44 | To evaluate PPL, run the following script. `--method` should be one of `elastic`, `h2o`, or `local`. 45 | 46 | ```bash 47 | python3 eval_ppl.py \ 48 | --model-path ./models/llava-v1.5-13b \ 49 | --data-path ./playground/data/mm-vet/mm-vet.json \ 50 | --image-path ./playground/data/mm-vet/images \ 51 | --eval-samples 218 \ 52 | --method "elastic" \ 53 | --ratio 0.2 \ 54 | --exp-name "llava-13b-ppl-mmvet" 55 | ``` 56 | 57 | To evaluate ROUGE, run the following scripts.
`--method` should be one of `elastic`, `h2o`, or `local`. 58 | 59 | ```bash 60 | # first generate reference texts with the full cache, then run the evaluation 61 | python3 convert_rouge_llava.py \ 62 | --model-path ./models/llava-v1.5-13b \ 63 | --data-path ./playground/data/mm-vet/mm-vet.json \ 64 | --image-path ./playground/data/mm-vet/images 65 | 66 | python3 eval_rouge.py \ 67 | --model-path ./models/llava-v1.5-13b \ 68 | --data-path ./playground/data/mm-vet/rouge-llava-v1.5-13b-mm-vet.json \ 69 | --image-path ./playground/data/mm-vet/images \ 70 | --eval-samples 218 \ 71 | --ratio 0.2 \ 72 | --exp-name "llava-13b-rouge-mmvet" 73 | ``` 74 | 75 | 76 | ## Qwen-VL-Chat 77 | 78 | ### MM-Vet 79 | 80 | To evaluate PPL, run the following script. `--method` should be one of `elastic`, `h2o`, or `local`. 81 | 82 | ```bash 83 | python3 eval_ppl_qwen.py \ 84 | --model-path ./models/qwen-vl-chat \ 85 | --data-path ./playground/data/mm-vet/mm-vet.json \ 86 | --image-path ./playground/data/mm-vet/images \ 87 | --eval-samples 218 \ 88 | --method "elastic" \ 89 | --ratio 0.2 \ 90 | --exp-name "qwen-ppl-mmvet" 91 | ``` 92 | 93 | To evaluate ROUGE, run the following scripts. `--method` should be one of `elastic`, `h2o`, or `local`. 94 | 95 | ```bash 96 | # first generate reference texts with the full cache, then run the evaluation 97 | python3 convert_rouge_qwen.py \ 98 | --model-path ./models/qwen-vl-chat \ 99 | --data-path ./playground/data/mm-vet/mm-vet.json \ 100 | --image-path ./playground/data/mm-vet/images 101 | 102 | python3 eval_rouge_qwen.py \ 103 | --model-path ./models/qwen-vl-chat \ 104 | --data-path ./playground/data/mm-vet/rouge-qwen-vl-chat-mm-vet.json \ 105 | --image-path ./playground/data/mm-vet/images \ 106 | --eval-samples 218 \ 107 | --ratio 0.2 \ 108 | --exp-name "qwen-rouge-mmvet" 109 | ``` 110 | 111 | # Generation 112 | 113 | To generate responses, run 114 | 115 | ```bash 116 | python3 eval_generate.py \ 117 | --model-path ./models/llava-v1.5-13b \ 118 | --data-path ./playground/data/mm-vet/mm-vet.json \ 119 | --image-path ./playground/data/mm-vet/images \ 120 | --method "elastic" \ 121 | --ratio 0.2 122 | ``` 123 | 124 | # Latency 125 | 126 | To evaluate latency, run 127 | 128 | ```bash 129 | python3 eval_latency.py \ 130 | --model-path ./models/llava-v1.5-13b \ 131 | --batch-size 8 \ 132 | --method "elastic" \ 133 | --ratio 0.2 134 | ``` 135 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2024 Zuyan Liu 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | 2 | # Efficient Inference of Vision Instruction-Following Models with Elastic Cache 3 | 4 | This repository contains the PyTorch implementation of Elastic Cache (ECCV 2024). 5 | 6 | [Project Page](https://sites.google.com/view/elastic-cache) | [arXiv Paper](https://arxiv.org/pdf/2407.18121) 7 | 8 | ## Elastic Cache 9 | 10 |

11 | <img src="8a94931d4958c0665a9bde8e35f1a974.png" alt="Elastic Cache overview"> 12 |

13 | 14 | Instruction encoding accounts for most of the theoretical computation cost, yet its actual latency is small. This underscores that it is not just the model weights but also the **KV cache** used during output generation that can become a significant bottleneck. 15 | 16 | We propose **Elastic Cache**, which performs **Cache Merging** guided by the importance scores of instruction tokens, complemented by a **fixed-point elimination** strategy in the output generation phase. These designs yield significant inference acceleration while maintaining generation quality. 17 | 18 | ## Get Started 19 | 20 | 1. **Environmental Setup**: 21 | 22 | We choose LLaVA-1.5 and Qwen-VL as our base models. You can install the following dependencies for Elastic Cache evaluation: 23 | 24 | ``` 25 | pip install -r requirements.txt 26 | ``` 27 | 28 | 2. **Initial Weights**: 29 | 30 | We use [LLaVA-1.5-7B](https://huggingface.co/liuhaotian/llava-v1.5-7b), [LLaVA-1.5-13B](https://huggingface.co/liuhaotian/llava-v1.5-13b) and [Qwen-VL](https://huggingface.co/Qwen/Qwen-VL) in our experiments; you may download these models and put them at /path/to/model. 31 | 32 | 3. **Download Eval Data**: 33 | 34 | You can download our pre-processed MM-Vet dataset [here](https://drive.google.com/file/d/1MLB7Pr_zo2Nu5iihuXRXE38nHzY-TnRN/view?usp=sharing) and put it at `./playground/data/mm-vet`. Our LLaVA-Description dataset will be released soon. 35 | 36 | You can also prepare your own conversations for testing, following the format in the JSON file. 37 | 38 | 4. **Eval**: 39 | 40 | Please refer to [EVAL.md](https://github.com/liuzuyan/ElasticCache/blob/main/EVAL.md) for detailed instructions on evaluation, including generation, PPL evaluation, ROUGE evaluation, and the latency test. 41 | 42 | ## Quantitative and Qualitative Results 43 | 44 | We evaluate **Elastic Cache** against the baselines (H2O and StreamingLLM) on PPL (lower is better) and ROUGE (higher is better). We test LLaVA-1.5 at different sizes (a), (b) and Qwen-VL-7B (c) on visual tasks. Elastic Cache consistently outperforms the baselines. 45 | 46 |

47 | <img src="30dcc0713f9c3dc40600846aa2037509.png" alt="Quantitative results of Elastic Cache and baselines"> 48 |

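## Usage Sketch

For reference, the sketch below shows how the eviction policies in this repo are driven; it mirrors the evaluation loop in `eval_ppl.py` and is only a minimal illustration. It assumes a loaded LLaVA `model`, a tokenized prompt `input_ids`, and a preprocessed `image_tensor`; the constructor values are the defaults used by the evaluation scripts (`layer_num=32` for the 7B model, 40 for 13B).

```python
import torch
from kv_cache import ElasticCache  # LocalCache / H2OCache expose the same interface

# Keep `start_size` leading tokens, protect the most recent token, and merge/prune
# roughly `ratio` of the instruction cache based on accumulated attention scores.
kv_cache = ElasticCache(start_size=1, recent_size=2047,
                        k_seq_dim=2, v_seq_dim=2,
                        ratio=0.2, layer_num=32)

past_key_values, num_of_token = None, 0
with torch.no_grad():
    outputs = model(input_ids, images=image_tensor,
                    past_key_values=past_key_values,
                    use_cache=True, output_attentions=True)
    num_of_token += outputs.logits.shape[1]
    # The policy scores tokens from the returned attention maps and hands back a
    # merged/pruned cache for the next decoding step.
    past_key_values = kv_cache(outputs.past_key_values, num_of_token, outputs.attentions)
```

The same pattern repeats for every generated token; `eval_generate.py` and `eval_rouge.py` instead patch `model.generate` with the functions in `cache_generate.py`, so the policy is applied inside the Hugging Face generation loop via the `kv_cache_criteria` argument.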
49 | 50 | ## Citation 51 | 52 | If you found this repository useful, please consider citing: 53 | 54 | ``` 55 | @article{liu2024elastic, 56 | title={Efficient Inference of Vision Instruction-Following Models with Elastic Cache}, 57 | author={Liu, Zuyan and Liu, Benlin and Wang, Jiahui and Dong, Yuhao and Chen, Guangyi and Rao, Yongming and Krishna, Ranjay and Lu, Jiwen}, 58 | journal={arXiv preprint arXiv:2407.18121}, 59 | year={2024} 60 | } 61 | ``` -------------------------------------------------------------------------------- /convert_rouge_llava.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from tqdm import tqdm 3 | import os 4 | import json 5 | device = "cuda" 6 | 7 | import argparse 8 | import torch 9 | import types 10 | from llava.constants import IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN 11 | from llava.conversation import conv_templates, SeparatorStyle 12 | from llava.model.builder import load_pretrained_model 13 | from llava.utils import disable_torch_init 14 | from llava.mm_utils import process_images, tokenizer_image_token, get_model_name_from_path, KeywordsStoppingCriteria 15 | 16 | from PIL import Image 17 | 18 | import requests 19 | from PIL import Image 20 | from io import BytesIO 21 | from transformers import TextStreamer 22 | 23 | def load_image(image_file): 24 | if image_file.startswith('http://') or image_file.startswith('https://'): 25 | response = requests.get(image_file) 26 | image = Image.open(BytesIO(response.content)).convert('RGB') 27 | else: 28 | image = Image.open(image_file).convert('RGB') 29 | return image 30 | 31 | def main(args): 32 | with open(args.data_path, "r") as f: 33 | data = json.load(f) 34 | 35 | outputs_data_json = [] 36 | dataset_name = args.data_path.split('/')[-1].split('.')[0] 37 | model_name = get_model_name_from_path(args.model_path) 38 | tokenizer, model, image_processor, context_len = load_pretrained_model(args.model_path, args.model_base, model_name, args.load_8bit, args.load_4bit, device=args.device) 39 | 40 | if 'llama-2' in model_name.lower(): 41 | conv_mode = "llava_llama_2" 42 | elif "v1" in model_name.lower(): 43 | conv_mode = "llava_v1" 44 | elif "mpt" in model_name.lower(): 45 | conv_mode = "mpt" 46 | else: 47 | conv_mode = "llava_v0" 48 | 49 | if args.conv_mode is not None and conv_mode != args.conv_mode: 50 | print('[WARNING] the auto inferred conversation mode is {}, while `--conv-mode` is {}, using {}'.format(conv_mode, args.conv_mode, args.conv_mode)) 51 | else: 52 | args.conv_mode = conv_mode 53 | 54 | data = data[:args.eval_samples] 55 | 56 | for item in tqdm(data): 57 | conv = conv_templates[args.conv_mode].copy() 58 | if "mpt" in model_name.lower(): 59 | roles = ('user', 'assistant') 60 | else: 61 | roles = conv.roles 62 | image_path = os.path.join(args.image_path, item["image"]) 63 | question = item['question'] 64 | if "mm-vet" in args.data_path: 65 | question = question + '\n' + DEFAULT_IMAGE_TOKEN 66 | answer = item['answer'] 67 | 68 | output_data_json = {} 69 | output_data_json['image'] = item["image"] 70 | output_data_json['question'] = question 71 | 72 | 73 | image = load_image(image_path) 74 | image_tensor = process_images([image], image_processor, args) 75 | image_tensor = image_tensor.to(model.device, dtype=torch.bfloat16) 76 | 77 | conv.append_message(conv.roles[0], question) 78 | conv.append_message(conv.roles[1], None) 79 | prompt = conv.get_prompt() 80 | 81 | input_ids = tokenizer_image_token(prompt, 
tokenizer, IMAGE_TOKEN_INDEX, return_tensors='pt').unsqueeze(0).cuda() 82 | stop_str = conv.sep if conv.sep_style != SeparatorStyle.TWO else conv.sep2 83 | keywords = [stop_str] 84 | stopping_criteria = KeywordsStoppingCriteria(keywords, tokenizer, input_ids) 85 | 86 | answer_ids = tokenizer.encode(answer, return_tensors='pt').cuda()[:, 1:] 87 | past_key_values = None 88 | 89 | num_of_token = 0 90 | output_ids = model.generate( 91 | input_ids, 92 | images=image_tensor, 93 | do_sample=True if args.temperature > 0 else False, 94 | temperature=args.temperature, 95 | top_p=args.top_p, 96 | num_beams=args.num_beams, 97 | max_new_tokens=1024, 98 | use_cache=True, 99 | stopping_criteria=[stopping_criteria]) 100 | 101 | outputs_generate = tokenizer.decode(output_ids[0, input_ids.shape[1]:]).strip() 102 | output_data_json['answer'] = outputs_generate 103 | outputs_data_json.append(output_data_json) 104 | 105 | with open('./playground/data/' + dataset_name + '/rouge-' + model_name + '-' + dataset_name+'.json', 'w') as f: 106 | json.dump(outputs_data_json, f, indent=4) 107 | 108 | if __name__ == "__main__": 109 | parser = argparse.ArgumentParser() 110 | parser.add_argument("--model-path", type=str, default="./models/llava-v1.5-7b") 111 | parser.add_argument("--model-base", type=str, default=None) 112 | parser.add_argument("--data-path", type=str, default="./playground/data/detail_1k.json") 113 | parser.add_argument("--image-path", type=str, default="./playground/data") 114 | parser.add_argument("--device", type=str, default="cuda") 115 | parser.add_argument("--conv-mode", type=str, default=None) 116 | parser.add_argument("--temperature", type=float, default=0) 117 | parser.add_argument("--max-new-tokens", type=int, default=512) 118 | parser.add_argument("--num-chunks", type=int, default=1) 119 | parser.add_argument("--chunk-idx", type=int, default=0) 120 | parser.add_argument("--top_p", type=float, default=None) 121 | parser.add_argument("--num_beams", type=int, default=1) 122 | parser.add_argument("--load-8bit", action="store_true") 123 | parser.add_argument("--load-4bit", action="store_true") 124 | parser.add_argument("--debug", action="store_true") 125 | parser.add_argument("--image-aspect-ratio", type=str, default='pad') 126 | parser.add_argument("--start-size", type=int, default=1) 127 | parser.add_argument("--recent-size", type=int, default=2047) 128 | parser.add_argument("--eval-samples", type=int, default=218) 129 | parser.add_argument("--exp-name", type=str, default='') 130 | args = parser.parse_args() 131 | main(args) 132 | -------------------------------------------------------------------------------- /convert_rouge_qwen.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from tqdm import tqdm 3 | import os 4 | from torch.nn import CrossEntropyLoss 5 | import json 6 | device = "cuda" 7 | 8 | import argparse 9 | import torch 10 | from cache_generate import generate, sample, greedy_search 11 | import types 12 | from qwen_generation_utils import make_context 13 | from rouge import Rouge 14 | from PIL import Image 15 | 16 | import requests 17 | from PIL import Image 18 | from io import BytesIO 19 | from transformers import TextStreamer 20 | from transformers import AutoModelForCausalLM, AutoTokenizer 21 | 22 | def load_image(image_file): 23 | if image_file.startswith('http://') or image_file.startswith('https://'): 24 | response = requests.get(image_file) 25 | image = Image.open(BytesIO(response.content)).convert('RGB') 26 | else: 27 | 
image = Image.open(image_file).convert('RGB') 28 | return image 29 | 30 | def main(args): 31 | with open(args.data_path, "r") as f: 32 | data = json.load(f) 33 | 34 | outputs_data_json = [] 35 | model_name = args.model_path.split('/')[-1] 36 | dataset_name = args.data_path.split('/')[-1].split('.')[0] 37 | tokenizer = AutoTokenizer.from_pretrained(args.model_path, trust_remote_code=True) 38 | model = AutoModelForCausalLM.from_pretrained(args.model_path, device_map="cuda", trust_remote_code=True, bf16=True).eval() 39 | 40 | os.makedirs('logs_temp/', exist_ok=True) 41 | 42 | data = data[:args.eval_samples] 43 | 44 | for item in tqdm(data): 45 | image_path = os.path.join(args.image_path, item["image"]) 46 | question = item['question'] 47 | answer = item['answer'] 48 | 49 | output_data_json = {} 50 | output_data_json['image'] = item['image'] 51 | output_data_json['question'] = question 52 | 53 | if "detail_1k" in args.data_path: 54 | question = question.replace('', '') 55 | 56 | query = tokenizer.from_list_format([ 57 | {'image': image_path}, 58 | {'text': question} 59 | ]) 60 | 61 | raw_text, context_tokens = make_context( 62 | tokenizer, 63 | query, 64 | history=None, 65 | system="You are a helpful assistant.", 66 | max_window_size=None, 67 | chat_format='chatml', 68 | ) 69 | 70 | input_ids = torch.tensor([context_tokens]).cuda() 71 | answer_ids = tokenizer.encode(answer, return_tensors='pt').cuda()[:, 1:] 72 | past_key_values = None 73 | 74 | num_of_token = 0 75 | output_ids = model.generate( 76 | input_ids, 77 | do_sample=True if args.temperature > 0 else False, 78 | temperature=args.temperature, 79 | top_p=args.top_p, 80 | num_beams=args.num_beams, 81 | max_new_tokens=1024, 82 | use_cache=True) 83 | 84 | outputs_generate = tokenizer.decode(output_ids[0, input_ids.shape[1]:]).strip() 85 | output_data_json['answer'] = outputs_generate 86 | outputs_data_json.append(output_data_json) 87 | 88 | with open('./playground/data/' + dataset_name + '/rouge-' + model_name + '-' + dataset_name + '.json', 'w') as f: 89 | json.dump(outputs_data_json, f, indent=4) 90 | 91 | if __name__ == "__main__": 92 | parser = argparse.ArgumentParser() 93 | parser.add_argument("--model-path", type=str, default="./models/qwen-vl-chat") 94 | parser.add_argument("--model-base", type=str, default=None) 95 | parser.add_argument("--data-path", type=str, default="./playground/data/detail_1k/detail_1k.json") 96 | parser.add_argument("--image-path", type=str, default="./playground/data/detail_1k/") 97 | parser.add_argument("--device", type=str, default="cuda") 98 | parser.add_argument("--conv-mode", type=str, default=None) 99 | parser.add_argument("--temperature", type=float, default=0) 100 | parser.add_argument("--max-new-tokens", type=int, default=512) 101 | parser.add_argument("--num-chunks", type=int, default=1) 102 | parser.add_argument("--chunk-idx", type=int, default=0) 103 | parser.add_argument("--top_p", type=float, default=None) 104 | parser.add_argument("--num_beams", type=int, default=1) 105 | parser.add_argument("--load-8bit", action="store_true") 106 | parser.add_argument("--load-4bit", action="store_true") 107 | parser.add_argument("--debug", action="store_true") 108 | parser.add_argument("--image-aspect-ratio", type=str, default='pad') 109 | parser.add_argument("--start-size", type=int, default=1) 110 | parser.add_argument("--recent-size", type=int, default=2047) 111 | parser.add_argument("--eval-samples", type=int, default=218) 112 | parser.add_argument("--exp-name", type=str, default='') 113 | 
parser.add_argument("--method", type=str, default='elastic') 114 | args = parser.parse_args() 115 | main(args) 116 | -------------------------------------------------------------------------------- /eval_generate.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from tqdm import tqdm 3 | import os 4 | from torch.nn import CrossEntropyLoss 5 | from kv_cache import ElasticCache, LocalCache, H2OCache 6 | import json 7 | device = "cuda" 8 | 9 | import argparse 10 | import torch 11 | 12 | from cache_generate import generate, sample, greedy_search 13 | import types 14 | 15 | from llava.constants import IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN 16 | from llava.conversation import conv_templates, SeparatorStyle 17 | from llava.model.builder import load_pretrained_model 18 | from llava.utils import disable_torch_init 19 | from llava.mm_utils import process_images, tokenizer_image_token, get_model_name_from_path, KeywordsStoppingCriteria 20 | 21 | 22 | from PIL import Image 23 | 24 | import requests 25 | from PIL import Image 26 | from io import BytesIO 27 | from transformers import TextStreamer 28 | import numpy as np 29 | 30 | def load_image(image_file): 31 | if image_file.startswith('http://') or image_file.startswith('https://'): 32 | response = requests.get(image_file) 33 | image = Image.open(BytesIO(response.content)).convert('RGB') 34 | else: 35 | image = Image.open(image_file).convert('RGB') 36 | return image 37 | 38 | def main(args): 39 | print(args.method) 40 | with open(args.data_path, "r") as f: 41 | data = json.load(f) 42 | 43 | model_name = get_model_name_from_path(args.model_path) 44 | tokenizer, model, image_processor, context_len = load_pretrained_model(args.model_path, args.model_base, model_name, args.load_8bit, args.load_4bit, device=args.device) 45 | 46 | model.generate = types.MethodType(generate, model) 47 | model.sample = types.MethodType(sample, model) 48 | model.greedy_search = types.MethodType(greedy_search, model) 49 | 50 | if 'llama-2' in model_name.lower(): 51 | conv_mode = "llava_llama_2" 52 | elif "v1" in model_name.lower(): 53 | conv_mode = "llava_v1" 54 | elif "mpt" in model_name.lower(): 55 | conv_mode = "mpt" 56 | else: 57 | conv_mode = "llava_v0" 58 | 59 | if args.conv_mode is not None and conv_mode != args.conv_mode: 60 | print('[WARNING] the auto inferred conversation mode is {}, while `--conv-mode` is {}, using {}'.format(conv_mode, args.conv_mode, args.conv_mode)) 61 | else: 62 | args.conv_mode = conv_mode 63 | 64 | k_seq_dim = v_seq_dim = 2 65 | data = data[:args.eval_samples] 66 | 67 | for item in (data): 68 | if args.method == "elastic": 69 | kv_cache = ElasticCache( 70 | start_size=args.start_size, 71 | recent_size=args.recent_size, 72 | k_seq_dim=k_seq_dim, 73 | v_seq_dim=v_seq_dim, 74 | ratio=args.ratio, 75 | layer_num=32 if "7b" in model_name else 40 76 | ) 77 | elif args.method == "local": 78 | kv_cache = LocalCache( 79 | start_size=args.start_size, 80 | recent_size=args.recent_size, 81 | k_seq_dim=k_seq_dim, 82 | v_seq_dim=v_seq_dim, 83 | ratio=args.ratio 84 | ) 85 | elif args.method == "h2o": 86 | kv_cache = H2OCache( 87 | start_size=args.start_size, 88 | recent_size=args.recent_size, 89 | k_seq_dim=k_seq_dim, 90 | v_seq_dim=v_seq_dim, 91 | ratio=args.ratio 92 | ) 93 | conv = conv_templates[args.conv_mode].copy() 94 | if "mpt" in model_name.lower(): 95 | roles = ('user', 'assistant') 96 | else: 97 | roles = conv.roles 98 | 99 | image_path = 
os.path.join(args.image_path, item['image']) 100 | question = item['question'] 101 | answer = item['answer'] 102 | if "mm-vet" in args.data_path: 103 | question = question + '\n' + DEFAULT_IMAGE_TOKEN 104 | 105 | image = load_image(image_path) 106 | 107 | image_tensor = process_images([image], image_processor, args) 108 | image_tensor = image_tensor.to(model.device, dtype=torch.bfloat16) 109 | 110 | conv.append_message(conv.roles[0], question) 111 | conv.append_message(conv.roles[1], None) 112 | prompt = conv.get_prompt() 113 | input_ids = tokenizer_image_token(prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors='pt').unsqueeze(0).cuda() 114 | stop_str = conv.sep if conv.sep_style != SeparatorStyle.TWO else conv.sep2 115 | keywords = [stop_str] 116 | stopping_criteria = KeywordsStoppingCriteria(keywords, tokenizer, input_ids) 117 | try: 118 | kv_cache.score_sum = torch.zeros_like(kv_cache.score_sum).cuda() 119 | kv_cache.flag = True 120 | except: 121 | print('cannot reset kv_cache') 122 | pass 123 | 124 | with torch.inference_mode(): 125 | output_ids = model.generate( 126 | input_ids, 127 | images=image_tensor, 128 | do_sample=True, 129 | temperature=args.temperature, 130 | top_p=args.top_p, 131 | num_beams=args.num_beams, 132 | max_new_tokens=512, 133 | use_cache=True, 134 | stopping_criteria=[stopping_criteria], 135 | kv_cache_criteria=kv_cache) 136 | 137 | outputs_generate = tokenizer.decode(output_ids[0, input_ids.shape[1]:]).strip() 138 | print("output:", outputs_generate) 139 | 140 | 141 | 142 | if __name__ == "__main__": 143 | parser = argparse.ArgumentParser() 144 | parser.add_argument("--model-path", type=str, default="./models/llava-v1.5-7b") 145 | parser.add_argument("--model-base", type=str, default=None) 146 | parser.add_argument("--data-path", type=str, default="./playground/data/detail_1k/detail_1k.json") 147 | parser.add_argument("--image-path", type=str, default="./playground/data/detail_1k") 148 | parser.add_argument("--device", type=str, default="cuda") 149 | parser.add_argument("--conv-mode", type=str, default=None) 150 | parser.add_argument("--temperature", type=float, default=0.7) 151 | parser.add_argument("--max-new-tokens", type=int, default=512) 152 | parser.add_argument("--num-chunks", type=int, default=1) 153 | parser.add_argument("--chunk-idx", type=int, default=0) 154 | parser.add_argument("--top_p", type=float, default=None) 155 | parser.add_argument("--num_beams", type=int, default=1) 156 | parser.add_argument("--load-8bit", action="store_true") 157 | parser.add_argument("--load-4bit", action="store_true") 158 | parser.add_argument("--debug", action="store_true") 159 | parser.add_argument("--image-aspect-ratio", type=str, default='pad') 160 | parser.add_argument("--start-size", type=int, default=1) 161 | parser.add_argument("--recent-size", type=int, default=2047) 162 | parser.add_argument("--eval-samples", type=int, default=1) 163 | parser.add_argument("--exp-name", type=str, default='') 164 | parser.add_argument("--method", type=str, default="elastic") 165 | parser.add_argument("--ratio", type=float, default=0.2) 166 | args = parser.parse_args() 167 | main(args) 168 | -------------------------------------------------------------------------------- /eval_latency.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from tqdm import tqdm 3 | import os 4 | from torch.nn import CrossEntropyLoss 5 | from kv_cache import ElasticCache, LocalCache, H2OCache 6 | import json 7 | device = 
"cuda" 8 | import time 9 | import argparse 10 | import torch 11 | 12 | from llava.constants import IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN 13 | from llava.conversation import conv_templates, SeparatorStyle 14 | from llava.model.builder import load_pretrained_model 15 | from llava.utils import disable_torch_init 16 | from llava.mm_utils import process_images, tokenizer_image_token, get_model_name_from_path, KeywordsStoppingCriteria 17 | 18 | 19 | from PIL import Image 20 | 21 | import requests 22 | from PIL import Image 23 | from io import BytesIO 24 | from transformers import TextStreamer 25 | 26 | def load_image(image_file): 27 | if image_file.startswith('http://') or image_file.startswith('https://'): 28 | response = requests.get(image_file) 29 | image = Image.open(BytesIO(response.content)).convert('RGB') 30 | else: 31 | image = Image.open(image_file).convert('RGB') 32 | return image 33 | 34 | def main(args): 35 | model_name = get_model_name_from_path(args.model_path) 36 | tokenizer, model, image_processor, context_len = load_pretrained_model(args.model_path, args.model_base, model_name, args.load_8bit, args.load_4bit, device=args.device) 37 | if 'llama-2' in model_name.lower(): 38 | conv_mode = "llava_llama_2" 39 | elif "v1" in model_name.lower(): 40 | conv_mode = "llava_v1" 41 | elif "mpt" in model_name.lower(): 42 | conv_mode = "mpt" 43 | else: 44 | conv_mode = "llava_v0" 45 | 46 | if args.conv_mode is not None and conv_mode != args.conv_mode: 47 | print('[WARNING] the auto inferred conversation mode is {}, while `--conv-mode` is {}, using {}'.format(conv_mode, args.conv_mode, args.conv_mode)) 48 | else: 49 | args.conv_mode = conv_mode 50 | 51 | past_key_values = None 52 | 53 | k_seq_dim = v_seq_dim = 2 54 | os.makedirs('logs/', exist_ok=True) 55 | 56 | if args.method == "elastic": 57 | kv_cache = ElasticCache( 58 | start_size=args.start_size, 59 | recent_size=args.recent_size, 60 | k_seq_dim=k_seq_dim, 61 | v_seq_dim=v_seq_dim, 62 | ratio=args.ratio, 63 | layer_num=(32 if "7b" in model_name else 40) * args.batch_size 64 | ) 65 | elif args.method == "local": 66 | kv_cache = LocalCache( 67 | start_size=args.start_size, 68 | recent_size=args.recent_size, 69 | k_seq_dim=k_seq_dim, 70 | v_seq_dim=v_seq_dim, 71 | ratio=args.ratio 72 | ) 73 | elif args.method == "h2o": 74 | kv_cache = H2OCache( 75 | start_size=args.start_size, 76 | recent_size=args.recent_size, 77 | k_seq_dim=k_seq_dim, 78 | v_seq_dim=v_seq_dim, 79 | ratio=args.ratio 80 | ) 81 | 82 | input_ids = torch.ones([1, 900], dtype=int).cuda() 83 | answer_ids = torch.ones([1, 512], dtype=int).cuda() 84 | past_key_values = None 85 | try: 86 | kv_cache.score_sum = torch.zeros_like(kv_cache.score_sum).cuda() 87 | kv_cache.flag = True 88 | except: 89 | print('cannot reset kv_cache') 90 | pass 91 | num_of_token = 0 92 | start_time = time.time() 93 | print("start here") 94 | for idx in range(0, answer_ids.shape[-1] - 1): 95 | with torch.no_grad(): 96 | if past_key_values is None: 97 | time2 = time.time() 98 | outputs = model( 99 | input_ids.repeat(args.batch_size, 1), 100 | images=None, 101 | past_key_values=past_key_values, 102 | use_cache=True, 103 | output_attentions=True, 104 | ) 105 | logits = outputs.logits.view(args.batch_size, -1, model.config.vocab_size) 106 | num_of_token += logits.shape[1] 107 | past_key_values = outputs.past_key_values 108 | attentions = outputs.attentions 109 | 110 | if kv_cache is not None: 111 | past_key_values = kv_cache(past_key_values, num_of_token, 
attentions) 112 | time3 = time.time() 113 | print('time: ', time3 - time2) 114 | else: 115 | 116 | cur_input_ids = answer_ids[:, idx - 1: idx] 117 | outputs = model( 118 | cur_input_ids.repeat(args.batch_size, 1), 119 | past_key_values=past_key_values, 120 | use_cache=True, 121 | output_attentions=True, 122 | ) 123 | logits = outputs.logits.view(args.batch_size, -1, model.config.vocab_size) 124 | num_of_token += logits.shape[1] 125 | past_key_values = outputs.past_key_values 126 | attentions = outputs.attentions 127 | 128 | if kv_cache is not None: 129 | past_key_values = kv_cache(past_key_values, num_of_token, attentions) 130 | end_time = time.time() 131 | print('time: ', end_time - start_time) 132 | 133 | 134 | if __name__ == "__main__": 135 | parser = argparse.ArgumentParser() 136 | parser.add_argument("--model-path", type=str, default="/home/stu6/models/llava-v1.5-7b") 137 | parser.add_argument("--model-base", type=str, default=None) 138 | parser.add_argument("--device", type=str, default="cuda") 139 | parser.add_argument("--conv-mode", type=str, default=None) 140 | parser.add_argument("--temperature", type=float, default=0.2) 141 | parser.add_argument("--max-new-tokens", type=int, default=512) 142 | parser.add_argument("--load-8bit", action="store_true") 143 | parser.add_argument("--load-4bit", action="store_true") 144 | parser.add_argument("--debug", action="store_true") 145 | parser.add_argument("--image-aspect-ratio", type=str, default='pad') 146 | parser.add_argument("--start-size", type=int, default=1) 147 | parser.add_argument("--recent-size", type=int, default=2047) 148 | parser.add_argument("--exp-name", type=str, default='') 149 | parser.add_argument("--batch-size", type=int, default=8) 150 | parser.add_argument("--method", type=str, default="elastic") 151 | parser.add_argument("--ratio", type=float, default=0.2) 152 | args = parser.parse_args() 153 | main(args) 154 | -------------------------------------------------------------------------------- /eval_ppl.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from tqdm import tqdm 3 | import os 4 | from torch.nn import CrossEntropyLoss 5 | from kv_cache import ElasticCache, LocalCache, H2OCache 6 | import json 7 | device = "cuda" 8 | 9 | import argparse 10 | import torch 11 | 12 | from llava.constants import IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN 13 | from llava.conversation import conv_templates, SeparatorStyle 14 | from llava.model.builder import load_pretrained_model 15 | from llava.utils import disable_torch_init 16 | from llava.mm_utils import process_images, tokenizer_image_token, get_model_name_from_path, KeywordsStoppingCriteria 17 | from PIL import Image 18 | 19 | import requests 20 | from PIL import Image 21 | from io import BytesIO 22 | from transformers import TextStreamer 23 | 24 | def load_image(image_file): 25 | if image_file.startswith('http://') or image_file.startswith('https://'): 26 | response = requests.get(image_file) 27 | image = Image.open(BytesIO(response.content)).convert('RGB') 28 | else: 29 | image = Image.open(image_file).convert('RGB') 30 | return image 31 | 32 | def main(args): 33 | 34 | with open(args.data_path, "r") as f: 35 | data = json.load(f) 36 | 37 | model_name = get_model_name_from_path(args.model_path) 38 | tokenizer, model, image_processor, context_len = load_pretrained_model(args.model_path, args.model_base, model_name, args.load_8bit, args.load_4bit, device=args.device) 39 | 40 | if 
'llama-2' in model_name.lower(): 41 | conv_mode = "llava_llama_2" 42 | elif "v1" in model_name.lower(): 43 | conv_mode = "llava_v1" 44 | elif "mpt" in model_name.lower(): 45 | conv_mode = "mpt" 46 | else: 47 | conv_mode = "llava_v0" 48 | 49 | if args.conv_mode is not None and conv_mode != args.conv_mode: 50 | print('[WARNING] the auto inferred conversation mode is {}, while `--conv-mode` is {}, using {}'.format(conv_mode, args.conv_mode, args.conv_mode)) 51 | else: 52 | args.conv_mode = conv_mode 53 | 54 | loss_fn = CrossEntropyLoss(reduction="none") 55 | past_key_values = None 56 | 57 | k_seq_dim = v_seq_dim = 2 58 | 59 | nlls = [] 60 | os.makedirs('logs_ppl_llava/', exist_ok=True) 61 | data = data[:args.eval_samples] 62 | 63 | for item in tqdm(data): 64 | if args.method == "elastic": 65 | kv_cache = ElasticCache( 66 | start_size=args.start_size, 67 | recent_size=args.recent_size, 68 | k_seq_dim=k_seq_dim, 69 | v_seq_dim=v_seq_dim, 70 | ratio=args.ratio, 71 | layer_num=32 if "7b" in model_name else 40, 72 | ) 73 | elif args.method == "local": 74 | kv_cache = LocalCache( 75 | start_size=args.start_size, 76 | recent_size=args.recent_size, 77 | k_seq_dim=k_seq_dim, 78 | v_seq_dim=v_seq_dim, 79 | ratio=args.ratio, 80 | ) 81 | elif args.method == "h2o": 82 | kv_cache = H2OCache( 83 | start_size=args.start_size, 84 | recent_size=args.recent_size, 85 | k_seq_dim=k_seq_dim, 86 | v_seq_dim=v_seq_dim, 87 | ratio=args.ratio, 88 | ) 89 | else: 90 | raise ValueError("Invalid method") 91 | 92 | conv = conv_templates[args.conv_mode].copy() 93 | if "mpt" in model_name.lower(): 94 | roles = ('user', 'assistant') 95 | else: 96 | roles = conv.roles 97 | image_path = os.path.join(args.image_path, item["image"]) 98 | question = item['question'] 99 | if "mm-vet" in args.data_path: 100 | question = question + '\n' + DEFAULT_IMAGE_TOKEN 101 | answer = item['answer'] 102 | 103 | image = load_image(image_path) 104 | image_tensor = process_images([image], image_processor, args) 105 | image_tensor = image_tensor.to(model.device, dtype=torch.bfloat16) 106 | 107 | conv.append_message(conv.roles[0], question) 108 | conv.append_message(conv.roles[1], None) 109 | prompt = conv.get_prompt() 110 | 111 | input_ids = tokenizer_image_token(prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors='pt').unsqueeze(0).cuda() 112 | 113 | answer_ids = tokenizer.encode(answer, return_tensors='pt').cuda()[:, 1:] 114 | past_key_values = None 115 | try: 116 | kv_cache.score_sum = torch.zeros_like(kv_cache.score_sum).cuda() 117 | kv_cache.flag = True 118 | except: 119 | print('cannot reset kv_cache') 120 | pass 121 | num_of_token = 0 122 | 123 | for idx in range(0, answer_ids.shape[-1] - 1): 124 | 125 | with torch.no_grad(): 126 | if past_key_values is None: 127 | 128 | outputs = model( 129 | input_ids, 130 | images=image_tensor, 131 | past_key_values=past_key_values, 132 | use_cache=True, 133 | output_attentions=True, 134 | ) 135 | logits = outputs.logits.view(-1, model.config.vocab_size) 136 | num_of_token += logits.shape[0] 137 | past_key_values = outputs.past_key_values 138 | attentions = outputs.attentions 139 | 140 | logits = logits[-1].view(-1, model.config.vocab_size) 141 | label = answer_ids[:, idx : idx + 1].to(logits.device).view(-1) 142 | neg_log_likelihood = loss_fn(logits, label) 143 | if kv_cache is not None: 144 | past_key_values = kv_cache(past_key_values, num_of_token, attentions) 145 | else: 146 | 147 | cur_input_ids = answer_ids[:, idx - 1: idx] 148 | outputs = model( 149 | cur_input_ids, 150 | 
past_key_values=past_key_values, 151 | use_cache=True, 152 | output_attentions=True, 153 | ) 154 | logits = outputs.logits.view(-1, model.config.vocab_size) 155 | num_of_token += logits.shape[0] 156 | past_key_values = outputs.past_key_values 157 | attentions = outputs.attentions 158 | 159 | label = answer_ids[:, idx : idx + 1].to(logits.device).view(-1) 160 | neg_log_likelihood = loss_fn(logits, label) 161 | if kv_cache is not None: 162 | past_key_values = kv_cache(past_key_values, num_of_token, attentions) 163 | 164 | nlls.append(neg_log_likelihood) 165 | 166 | ppl = torch.exp(torch.stack(nlls).mean()) 167 | with open(f"logs_ppl_llava/{args.exp_name}.txt", "a") as f: 168 | f.write(f"{ppl.item()}\n") 169 | 170 | if __name__ == "__main__": 171 | parser = argparse.ArgumentParser() 172 | parser.add_argument("--model-path", type=str, default="./models/llava-v1.5-7b") 173 | parser.add_argument("--model-base", type=str, default=None) 174 | parser.add_argument("--data-path", type=str, default="./playground/data/mm-vet/mm-vet.json") 175 | parser.add_argument("--image-path", type=str, default="./playground/data/mm-vet/images") 176 | parser.add_argument("--device", type=str, default="cuda") 177 | parser.add_argument("--conv-mode", type=str, default=None) 178 | parser.add_argument("--temperature", type=float, default=0.2) 179 | parser.add_argument("--max-new-tokens", type=int, default=512) 180 | parser.add_argument("--load-8bit", action="store_true") 181 | parser.add_argument("--load-4bit", action="store_true") 182 | parser.add_argument("--debug", action="store_true") 183 | parser.add_argument("--image-aspect-ratio", type=str, default='pad') 184 | parser.add_argument("--start-size", type=int, default=1) 185 | parser.add_argument("--recent-size", type=int, default=2047) 186 | parser.add_argument("--eval-samples", type=int, default=100) 187 | parser.add_argument("--exp-name", type=str, default='llava-7b') 188 | parser.add_argument("--method", type=str, default="elastic") 189 | parser.add_argument("--ratio", type=float, default=0.2) 190 | args = parser.parse_args() 191 | main(args) 192 | -------------------------------------------------------------------------------- /eval_ppl_qwen.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from tqdm import tqdm 3 | import os 4 | from torch.nn import CrossEntropyLoss 5 | from kv_cache_qwen import ElasticCache, LocalCache, H2OCache 6 | import json 7 | device = "cuda" 8 | import argparse 9 | import torch 10 | from qwen_generation_utils import make_context 11 | from PIL import Image 12 | import requests 13 | from PIL import Image 14 | from io import BytesIO 15 | from transformers import TextStreamer 16 | from transformers import AutoModelForCausalLM, AutoTokenizer 17 | 18 | def load_image(image_file): 19 | if image_file.startswith('http://') or image_file.startswith('https://'): 20 | response = requests.get(image_file) 21 | image = Image.open(BytesIO(response.content)).convert('RGB') 22 | else: 23 | image = Image.open(image_file).convert('RGB') 24 | return image 25 | 26 | def main(args): 27 | 28 | 29 | with open(args.data_path, "r") as f: 30 | data = json.load(f) 31 | 32 | tokenizer = AutoTokenizer.from_pretrained(args.model_path, trust_remote_code=True) 33 | model = AutoModelForCausalLM.from_pretrained(args.model_path, device_map="cuda", trust_remote_code=True, bf16=True).eval() 34 | 35 | loss_fn = CrossEntropyLoss(reduction="none") 36 | past_key_values = None 37 | 38 | k_seq_dim = v_seq_dim = 1 39 | 40 | 
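    # Note: the Qwen-VL past_key_values keep the sequence dimension at index 1
    # (the LLaVA scripts set k_seq_dim = v_seq_dim = 2), so the slice helpers in
    # kv_cache_qwen.py are configured with k_seq_dim = v_seq_dim = 1 above.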
os.makedirs('logs_ppl_qwen/', exist_ok=True) 41 | 42 | data = data[:args.eval_samples] 43 | 44 | nlls = [] 45 | 46 | for item in tqdm(data): 47 | if args.method == "elastic": 48 | kv_cache = ElasticCache( 49 | start_size=args.start_size, 50 | recent_size=args.recent_size, 51 | k_seq_dim=k_seq_dim, 52 | v_seq_dim=v_seq_dim, 53 | ratio=args.ratio, 54 | layer_num=32, 55 | ) 56 | elif args.method == "local": 57 | kv_cache = LocalCache( 58 | start_size=args.start_size, 59 | recent_size=args.recent_size, 60 | k_seq_dim=k_seq_dim, 61 | v_seq_dim=v_seq_dim, 62 | ratio=args.ratio, 63 | ) 64 | elif args.method == "h2o": 65 | kv_cache = H2OCache( 66 | start_size=args.start_size, 67 | recent_size=args.recent_size, 68 | k_seq_dim=k_seq_dim, 69 | v_seq_dim=v_seq_dim, 70 | ratio=args.ratio, 71 | ) 72 | image_path = os.path.join(args.image_path, item["image"]) 73 | question = item['question'] 74 | answer = item['answer'] 75 | if "detail_1k" in args.data_path: 76 | question = question.replace('', '') 77 | 78 | query = tokenizer.from_list_format([ 79 | {'image': image_path}, 80 | {'text': question} 81 | ]) 82 | 83 | raw_text, context_tokens = make_context( 84 | tokenizer, 85 | query, 86 | history=None, 87 | system="You are a helpful assistant.", 88 | max_window_size=None, 89 | chat_format='chatml', 90 | ) 91 | input_ids = torch.tensor([context_tokens]).cuda() 92 | 93 | answer_ids = tokenizer.encode(answer, return_tensors='pt').cuda()[:, 1:] 94 | past_key_values = None 95 | num_of_token = 0 96 | 97 | for idx in range(0, answer_ids.shape[-1] - 1): 98 | 99 | with torch.no_grad(): 100 | if past_key_values is None: 101 | outputs = model( 102 | input_ids, 103 | past_key_values=past_key_values, 104 | use_cache=True, 105 | output_attentions=True, 106 | ) 107 | logits = outputs.logits.view(-1, model.config.vocab_size) 108 | num_of_token += logits.shape[0] 109 | past_key_values = outputs.past_key_values 110 | attentions = outputs.attentions 111 | 112 | logits = logits[-1].view(-1, model.config.vocab_size) 113 | label = answer_ids[:, idx : idx + 1].to(logits.device).view(-1) 114 | neg_log_likelihood = loss_fn(logits, label) 115 | if kv_cache is not None: 116 | past_key_values = kv_cache(past_key_values, num_of_token, attentions) 117 | else: 118 | cur_input_ids = answer_ids[:, idx - 1: idx] 119 | outputs = model( 120 | cur_input_ids, 121 | past_key_values=past_key_values, 122 | use_cache=True, 123 | output_attentions=True, 124 | attention_mask=(cur_input_ids != 0), 125 | ) 126 | logits = outputs.logits.view(-1, model.config.vocab_size) 127 | num_of_token += logits.shape[0] 128 | past_key_values = outputs.past_key_values 129 | attentions = outputs.attentions 130 | 131 | label = answer_ids[:, idx : idx + 1].to(logits.device).view(-1) 132 | neg_log_likelihood = loss_fn(logits, label) 133 | if kv_cache is not None: 134 | past_key_values = kv_cache(past_key_values, num_of_token, attentions) 135 | 136 | nlls.append(neg_log_likelihood) 137 | 138 | ppl = torch.exp(torch.stack(nlls).mean()) 139 | with open(f"logs_ppl_qwen/{args.exp_name}.txt", "a") as f: 140 | f.write(f"{ppl.item()}\n") 141 | 142 | if __name__ == "__main__": 143 | parser = argparse.ArgumentParser() 144 | parser.add_argument("--model-path", type=str, default="./models/qwen-vl-chat") 145 | parser.add_argument("--model-base", type=str, default=None) 146 | parser.add_argument("--data-path", type=str, default="./playground/data/mm-vet/mm-vet.json") 147 | parser.add_argument("--image-path", type=str, default="./playground/data/mm-vet/images/") 148 | 
parser.add_argument("--device", type=str, default="cuda") 149 | parser.add_argument("--conv-mode", type=str, default=None) 150 | parser.add_argument("--temperature", type=float, default=0.2) 151 | parser.add_argument("--max-new-tokens", type=int, default=512) 152 | parser.add_argument("--load-8bit", action="store_true") 153 | parser.add_argument("--load-4bit", action="store_true") 154 | parser.add_argument("--debug", action="store_true") 155 | parser.add_argument("--image-aspect-ratio", type=str, default='pad') 156 | parser.add_argument("--start-size", type=int, default=1) 157 | parser.add_argument("--recent-size", type=int, default=2047) 158 | parser.add_argument("--eval-samples", type=int, default=218) 159 | parser.add_argument("--exp-name", type=str, default='') 160 | parser.add_argument("--method", type=str, default="elastic") 161 | parser.add_argument("--ratio", type=float, default=0.2) 162 | args = parser.parse_args() 163 | main(args) 164 | -------------------------------------------------------------------------------- /eval_rouge.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from tqdm import tqdm 3 | import os 4 | from kv_cache import ElasticCache, LocalCache, H2OCache 5 | import json 6 | device = "cuda" 7 | 8 | import argparse 9 | import torch 10 | 11 | from cache_generate import generate, sample, greedy_search 12 | import types 13 | 14 | from llava.constants import IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN 15 | from llava.conversation import conv_templates, SeparatorStyle 16 | from llava.model.builder import load_pretrained_model 17 | from llava.utils import disable_torch_init 18 | from llava.mm_utils import process_images, tokenizer_image_token, get_model_name_from_path, KeywordsStoppingCriteria 19 | 20 | from rouge import Rouge 21 | from PIL import Image 22 | 23 | import requests 24 | from PIL import Image 25 | from io import BytesIO 26 | from transformers import TextStreamer 27 | 28 | def load_image(image_file): 29 | if image_file.startswith('http://') or image_file.startswith('https://'): 30 | response = requests.get(image_file) 31 | image = Image.open(BytesIO(response.content)).convert('RGB') 32 | else: 33 | image = Image.open(image_file).convert('RGB') 34 | return image 35 | 36 | def main(args): 37 | with open(args.data_path, "r") as f: 38 | data = json.load(f) 39 | 40 | 41 | model_name = get_model_name_from_path(args.model_path) 42 | tokenizer, model, image_processor, context_len = load_pretrained_model(args.model_path, args.model_base, model_name, args.load_8bit, args.load_4bit, device=args.device) 43 | 44 | model.generate = types.MethodType(generate, model) 45 | model.sample = types.MethodType(sample, model) 46 | model.greedy_search = types.MethodType(greedy_search, model) 47 | 48 | if 'llama-2' in model_name.lower(): 49 | conv_mode = "llava_llama_2" 50 | elif "v1" in model_name.lower(): 51 | conv_mode = "llava_v1" 52 | elif "mpt" in model_name.lower(): 53 | conv_mode = "mpt" 54 | else: 55 | conv_mode = "llava_v0" 56 | 57 | if args.conv_mode is not None and conv_mode != args.conv_mode: 58 | print('[WARNING] the auto inferred conversation mode is {}, while `--conv-mode` is {}, using {}'.format(conv_mode, args.conv_mode, args.conv_mode)) 59 | else: 60 | args.conv_mode = conv_mode 61 | 62 | k_seq_dim = v_seq_dim = 2 63 | 64 | os.makedirs('logs_rouge_llava/', exist_ok=True) 65 | 66 | data = data[:args.eval_samples] 67 | 68 | score_all = [] 69 | 70 | for item in tqdm(data): 
71 | if args.method == "elastic": 72 | kv_cache = ElasticCache( 73 | start_size=args.start_size, 74 | recent_size=args.recent_size, 75 | k_seq_dim=k_seq_dim, 76 | v_seq_dim=v_seq_dim, 77 | ratio=args.ratio, 78 | layer_num=32 if "7b" in model_name else 40, 79 | ) 80 | elif args.method == "local": 81 | kv_cache = LocalCache( 82 | start_size=args.start_size, 83 | recent_size=args.recent_size, 84 | k_seq_dim=k_seq_dim, 85 | v_seq_dim=v_seq_dim, 86 | ratio=args.ratio, 87 | ) 88 | elif args.method == "h2o": 89 | kv_cache = H2OCache( 90 | start_size=args.start_size, 91 | recent_size=args.recent_size, 92 | k_seq_dim=k_seq_dim, 93 | v_seq_dim=v_seq_dim, 94 | ratio=args.ratio, 95 | ) 96 | 97 | conv = conv_templates[args.conv_mode].copy() 98 | if "mpt" in model_name.lower(): 99 | roles = ('user', 'assistant') 100 | else: 101 | roles = conv.roles 102 | image_path = os.path.join(args.image_path, item["image"]) 103 | question = item['question'] 104 | answer = item['answer'] 105 | 106 | image = load_image(image_path) 107 | image_tensor = process_images([image], image_processor, args) 108 | image_tensor = image_tensor.to(model.device, dtype=torch.bfloat16) 109 | 110 | conv.append_message(conv.roles[0], question) 111 | conv.append_message(conv.roles[1], None) 112 | prompt = conv.get_prompt() 113 | 114 | input_ids = tokenizer_image_token(prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors='pt').unsqueeze(0).cuda() 115 | stop_str = conv.sep if conv.sep_style != SeparatorStyle.TWO else conv.sep2 116 | keywords = [stop_str] 117 | stopping_criteria = KeywordsStoppingCriteria(keywords, tokenizer, input_ids) 118 | 119 | answer_ids = tokenizer.encode(answer, return_tensors='pt').cuda()[:, 1:] 120 | past_key_values = None 121 | 122 | num_of_token = 0 123 | output_ids = model.generate( 124 | input_ids, 125 | images=image_tensor, 126 | do_sample=True if (args.temperature > 0 and args.ratio == 0) else False, 127 | temperature=args.temperature if args.ratio == 0 else 0, 128 | top_p=args.top_p, 129 | num_beams=args.num_beams, 130 | max_new_tokens=1024, 131 | use_cache=True, 132 | stopping_criteria=[stopping_criteria], 133 | kv_cache_criteria=kv_cache) 134 | 135 | outputs_generate = tokenizer.decode(output_ids[0, input_ids.shape[1]:]).strip() 136 | rouge = Rouge() 137 | scores = rouge.get_scores(outputs_generate, answer) 138 | score_all.append(scores[0]['rouge-l']['f']) 139 | 140 | rouge = sum(score_all) / len(score_all) 141 | with open(f"logs_rouge_llava/{args.exp_name}.txt", "a") as f: 142 | f.write(f"{rouge}\n") 143 | 144 | if __name__ == "__main__": 145 | parser = argparse.ArgumentParser() 146 | parser.add_argument("--model-path", type=str, default="./models/llava-v1.5-7b") 147 | parser.add_argument("--model-base", type=str, default=None) 148 | parser.add_argument("--data-path", type=str, default="./playground/data/mm-vet/mm-vet.json") 149 | parser.add_argument("--image-path", type=str, default="./playground/data/mm-vet/images") 150 | parser.add_argument("--device", type=str, default="cuda") 151 | parser.add_argument("--conv-mode", type=str, default=None) 152 | parser.add_argument("--temperature", type=float, default=0.2) 153 | parser.add_argument("--max-new-tokens", type=int, default=512) 154 | parser.add_argument("--num-chunks", type=int, default=1) 155 | parser.add_argument("--chunk-idx", type=int, default=0) 156 | parser.add_argument("--top_p", type=float, default=None) 157 | parser.add_argument("--num_beams", type=int, default=1) 158 | parser.add_argument("--load-8bit", action="store_true") 159 | 
parser.add_argument("--load-4bit", action="store_true") 160 | parser.add_argument("--debug", action="store_true") 161 | parser.add_argument("--image-aspect-ratio", type=str, default='pad') 162 | parser.add_argument("--start-size", type=int, default=1) 163 | parser.add_argument("--recent-size", type=int, default=2047) 164 | parser.add_argument("--eval-samples", type=int, default=218) 165 | parser.add_argument("--exp-name", type=str, default='') 166 | parser.add_argument("--method", type=str, default="elastic") 167 | parser.add_argument("--ratio", type=float, default=0.2) 168 | args = parser.parse_args() 169 | main(args) 170 | -------------------------------------------------------------------------------- /eval_rouge_qwen.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from tqdm import tqdm 3 | import os 4 | from torch.nn import CrossEntropyLoss 5 | from kv_cache_qwen import ElasticCache, LocalCache, H2OCache 6 | import json 7 | device = "cuda" 8 | 9 | import argparse 10 | import torch 11 | 12 | from cache_generate_qwen import generate, sample, greedy_search 13 | import types 14 | 15 | from qwen_generation_utils import make_context 16 | from rouge import Rouge 17 | from PIL import Image 18 | 19 | import requests 20 | from PIL import Image 21 | from io import BytesIO 22 | from transformers import TextStreamer 23 | from transformers import AutoModelForCausalLM, AutoTokenizer 24 | 25 | def load_image(image_file): 26 | if image_file.startswith('http://') or image_file.startswith('https://'): 27 | response = requests.get(image_file) 28 | image = Image.open(BytesIO(response.content)).convert('RGB') 29 | else: 30 | image = Image.open(image_file).convert('RGB') 31 | return image 32 | 33 | def main(args): 34 | 35 | 36 | with open(args.data_path, "r") as f: 37 | data = json.load(f) 38 | 39 | tokenizer = AutoTokenizer.from_pretrained(args.model_path, trust_remote_code=True) 40 | model = AutoModelForCausalLM.from_pretrained(args.model_path, device_map="cuda", trust_remote_code=True, bf16=True).eval() 41 | 42 | parent_class = model.__class__.__bases__[0] 43 | model.generate = types.MethodType(generate, model) 44 | model.sample = types.MethodType(sample, model) 45 | model.greedy_search = types.MethodType(greedy_search, model) 46 | 47 | k_seq_dim = v_seq_dim = 1 48 | 49 | os.makedirs('logs_detail/', exist_ok=True) 50 | 51 | data = data[:args.eval_samples] 52 | score_all = [] 53 | for item in tqdm(data): 54 | if args.method == "elastic": 55 | kv_cache = ElasticCache( 56 | start_size=args.start_size, 57 | recent_size=args.recent_size, 58 | k_seq_dim=k_seq_dim, 59 | v_seq_dim=v_seq_dim, 60 | ratio=args.ratio, 61 | layer_num=32 62 | ) 63 | elif args.method == "local": 64 | kv_cache = LocalCache( 65 | start_size=args.start_size, 66 | recent_size=args.recent_size, 67 | k_seq_dim=k_seq_dim, 68 | v_seq_dim=v_seq_dim, 69 | ratio=args.ratio 70 | ) 71 | elif args.method == "h2o": 72 | kv_cache = H2OCache( 73 | start_size=args.start_size, 74 | recent_size=args.recent_size, 75 | k_seq_dim=k_seq_dim, 76 | v_seq_dim=v_seq_dim, 77 | ratio=args.ratio 78 | ) 79 | 80 | image_path = os.path.join(args.image_path, item["image"]) 81 | question = item['question'] 82 | answer = item['answer'] 83 | 84 | question = question.replace('', '') 85 | 86 | query = tokenizer.from_list_format([ 87 | {'image': image_path}, 88 | {'text': question} 89 | ]) 90 | 91 | raw_text, context_tokens = make_context( 92 | tokenizer, 93 | query, 94 | history=None, 95 | system="You are a helpful 
assistant.", 96 | max_window_size=None, 97 | chat_format='chatml', 98 | ) 99 | 100 | input_ids = torch.tensor([context_tokens]).cuda() 101 | answer_ids = tokenizer.encode(answer, return_tensors='pt').cuda()[:, 1:] 102 | past_key_values = None 103 | 104 | num_of_token = 0 105 | output_ids = model.generate( 106 | input_ids, 107 | do_sample=True if (args.temperature > 0 and args.ratio == 0) else False, 108 | temperature=args.temperature if args.ratio == 0 else 0, 109 | top_p=args.top_p, 110 | num_beams=args.num_beams, 111 | max_new_tokens=1024, 112 | use_cache=True, 113 | kv_cache_criteria=kv_cache, 114 | attention_mask=None) 115 | 116 | outputs_generate = tokenizer.decode(output_ids[0, input_ids.shape[1]:]).strip() 117 | rouge = Rouge() 118 | scores = rouge.get_scores(outputs_generate, answer) 119 | score_all.append(scores[0]['rouge-l']['f']) 120 | 121 | rouge = sum(score_all) / len(score_all) 122 | with open(f"logs_rouge_qwen/{args.exp_name}.txt", "a") as f: 123 | f.write(f"{rouge}\n") 124 | f.write("\n") 125 | 126 | if __name__ == "__main__": 127 | parser = argparse.ArgumentParser() 128 | parser.add_argument("--model-path", type=str, default="./models/qwen_vl_chat/") 129 | parser.add_argument("--model-base", type=str, default=None) 130 | parser.add_argument("--data-path", type=str, default="./playground/data/mm-vet/rouge-qwen-detail.json") 131 | parser.add_argument("--image-path", type=str, default="./playground/data") 132 | parser.add_argument("--device", type=str, default="cuda") 133 | parser.add_argument("--conv-mode", type=str, default=None) 134 | parser.add_argument("--temperature", type=float, default=0.2) 135 | parser.add_argument("--max-new-tokens", type=int, default=512) 136 | parser.add_argument("--num-chunks", type=int, default=1) 137 | parser.add_argument("--chunk-idx", type=int, default=0) 138 | parser.add_argument("--top_p", type=float, default=None) 139 | parser.add_argument("--num_beams", type=int, default=1) 140 | parser.add_argument("--load-8bit", action="store_true") 141 | parser.add_argument("--load-4bit", action="store_true") 142 | parser.add_argument("--debug", action="store_true") 143 | parser.add_argument("--image-aspect-ratio", type=str, default='pad') 144 | parser.add_argument("--start-size", type=int, default=1) 145 | parser.add_argument("--recent-size", type=int, default=2047) 146 | parser.add_argument("--eval-samples", type=int, default=218) 147 | parser.add_argument("--exp-name", type=str, default='') 148 | parser.add_argument("--method", type=str, default="elastic") 149 | parser.add_argument("--ratio", type=float, default=0.0) 150 | args = parser.parse_args() 151 | main(args) 152 | -------------------------------------------------------------------------------- /kv_cache.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import random 3 | import numpy as np 4 | def slice2d(x, start, end): 5 | return x[:, :, start:end, ...] 6 | 7 | 8 | def slice3d(x, start, end): 9 | return x[:, :, :, start:end, ...] 10 | 11 | 12 | def slice1d(x, start, end): 13 | return x[:, start:end, ...] 
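# The helpers above slice a cached key/value tensor along its sequence dimension:
# slice2d (dim 2) matches the k_seq_dim/v_seq_dim = 2 setting used by the LLaVA
# eval scripts, while slice1d and slice3d cover layouts whose sequence dimension
# sits at index 1 or 3. DIM_TO_SLICE below maps that setting to the right helper.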
14 | 15 | import torch.nn.functional as F 16 | DIM_TO_SLICE = { 17 | 1: slice1d, 18 | 2: slice2d, 19 | 3: slice3d, 20 | } 21 | 22 | class ElasticCache: 23 | def __init__( 24 | self, 25 | start_size=4, 26 | recent_size=512, 27 | k_seq_dim=2, 28 | v_seq_dim=2, 29 | ratio=0., 30 | distance=-25, 31 | layer_num=40, 32 | ): 33 | self.start_size = start_size 34 | self.recent_size = recent_size 35 | self.cache_size = start_size + recent_size 36 | self.k_seq_dim = k_seq_dim 37 | self.v_seq_dim = v_seq_dim 38 | self.k_slice = DIM_TO_SLICE[k_seq_dim] 39 | self.v_slice = DIM_TO_SLICE[v_seq_dim] 40 | 41 | self.score_sum = torch.zeros(layer_num, self.cache_size + 1).cuda() 42 | self.ratio = ratio 43 | self.protect_size = 1 44 | self.flag = True 45 | self.distance = distance 46 | self.layer_num = layer_num 47 | 48 | self.selected_idx = 0 49 | 50 | def __call__(self, past_key_values, num_of_token=None, attentions=None): 51 | if past_key_values is None: 52 | return None 53 | attn_score = [attention for attention in attentions] 54 | seq_len = past_key_values[0][0].size(self.k_seq_dim) 55 | 56 | # update attn score 57 | attn_score = torch.cat(attn_score, dim=0) 58 | attn_score = attn_score.mean(dim=1, keepdim=False) 59 | if attn_score.shape[-2] > 1: 60 | assert self.flag is True # only use for the first time 61 | for idx in range(attn_score.shape[-1]): 62 | cur_score = attn_score[:, idx, :idx+1] 63 | self.score_sum[:, :(cur_score.shape[-1])] += cur_score 64 | else: 65 | pass 66 | 67 | forget_num = int(seq_len - num_of_token * (1 - self.ratio)) 68 | if forget_num <= 0: 69 | return past_key_values 70 | else: 71 | if forget_num > 1: 72 | assert self.flag is True 73 | self.flag = False 74 | 75 | selected_idx_all = [] 76 | merge_idx_all = [] 77 | throw_idx_all = [] 78 | for idx in range(self.layer_num): 79 | selected_idx = torch.where(torch.argsort(self.score_sum[idx, self.start_size:(seq_len - self.protect_size)]) > forget_num)[0] + self.start_size 80 | throw_idx = torch.where(torch.argsort(self.score_sum[idx, self.start_size:(seq_len - self.protect_size)]) <= forget_num)[0] 81 | merge_idx = [] 82 | for i in range(len(throw_idx)): 83 | merge_idx.append(selected_idx[torch.abs((selected_idx - throw_idx[i])).argmin()].unsqueeze(0)) 84 | merge_idx = torch.cat(merge_idx) 85 | 86 | selected_idx = torch.cat([torch.arange(self.start_size).cuda(), selected_idx, torch.tensor([seq_len - self.protect_size]).cuda()], dim=0) # the last token is always kept 87 | 88 | selected_idx_all.append(selected_idx) 89 | merge_idx_all.append(merge_idx) 90 | throw_idx_all.append(throw_idx) 91 | 92 | if self.distance > 0: 93 | self.selected_idx = self.distance 94 | else: 95 | self.selected_idx = seq_len - forget_num + self.distance 96 | 97 | past_key_values_return = [] 98 | for idx, (k, v) in enumerate(past_key_values): 99 | selected_idx = selected_idx_all[idx] 100 | merge_idx = merge_idx_all[idx] 101 | throw_idx = throw_idx_all[idx] 102 | 103 | k_forget = k.gather(dim=-2, index=throw_idx.view(1,1,-1,1).expand(k.shape[0], k.shape[1], -1 ,k.shape[-1])) 104 | v_forget = v.gather(dim=-2, index=throw_idx.view(1,1,-1,1).expand(v.shape[0], v.shape[1], -1 ,v.shape[-1])) 105 | 106 | k = k.scatter_reduce(-2, merge_idx.view(1,1,-1,1).expand(k.shape[0], k.shape[1], -1 ,k.shape[-1]), k_forget, 'mean') 107 | v = v.scatter_reduce(-2, merge_idx.view(1,1,-1,1).expand(v.shape[0], v.shape[1], -1 ,v.shape[-1]), v_forget, 'mean') 108 | 109 | k_new = k.gather(dim=-2, index=selected_idx.view(1,1,-1,1).expand(k.shape[0], k.shape[1], -1 ,k.shape[-1])) 110 | v_new 
= v.gather(dim=-2, index=selected_idx.view(1,1,-1,1).expand(v.shape[0], v.shape[1], -1 ,v.shape[-1])) 111 | 112 | past_key_values_return.append([k_new, v_new]) 113 | return past_key_values_return 114 | else: 115 | selected_idx = self.selected_idx 116 | return [[torch.cat([self.k_slice(k, 0, selected_idx), self.k_slice(k, (selected_idx+1), seq_len),], 117 | dim=self.k_seq_dim,), 118 | torch.cat([self.v_slice(v, 0, selected_idx), self.v_slice(v, (selected_idx+1), seq_len),], 119 | dim=self.v_seq_dim,)] 120 | for k, v in past_key_values] 121 | 122 | 123 | class LocalCache: 124 | def __init__( 125 | self, 126 | start_size=4, 127 | recent_size=512, 128 | k_seq_dim=2, 129 | v_seq_dim=2, 130 | ratio=0. 131 | ): 132 | self.start_size = start_size 133 | self.recent_size = recent_size 134 | self.cache_size = start_size + recent_size 135 | self.k_seq_dim = k_seq_dim 136 | self.v_seq_dim = v_seq_dim 137 | self.k_slice = DIM_TO_SLICE[k_seq_dim] 138 | self.v_slice = DIM_TO_SLICE[v_seq_dim] 139 | self.ratio = ratio 140 | 141 | def __call__(self, past_key_values, num_of_token=None, attentions=None): 142 | if past_key_values is None: 143 | return None 144 | seq_len = past_key_values[0][0].size(self.k_seq_dim) 145 | 146 | forget_num = int(seq_len - num_of_token * (1 - self.ratio)) 147 | if forget_num <= 0: 148 | return past_key_values 149 | else: 150 | return [[torch.cat([self.k_slice(k, 0, self.start_size), self.k_slice(k, forget_num + self.start_size, seq_len),], 151 | dim=self.k_seq_dim,), 152 | torch.cat([self.v_slice(v, 0, self.start_size), self.v_slice(v, forget_num + self.start_size, seq_len),], 153 | dim=self.v_seq_dim,),] 154 | for k, v in past_key_values] 155 | 156 | 157 | class H2OCache: 158 | def __init__( 159 | self, 160 | start_size=4, 161 | recent_size=512, 162 | k_seq_dim=2, 163 | v_seq_dim=2, 164 | ratio=0. 
165 | ): 166 | self.start_size = start_size 167 | self.recent_size = recent_size 168 | self.cache_size = start_size + recent_size 169 | self.k_seq_dim = k_seq_dim 170 | self.v_seq_dim = v_seq_dim 171 | self.k_slice = DIM_TO_SLICE[k_seq_dim] 172 | self.v_slice = DIM_TO_SLICE[v_seq_dim] 173 | 174 | self.score_sum = torch.zeros(self.cache_size + 1).cuda() 175 | self.ratio = ratio 176 | self.protect_size = 1 177 | self.flag = True 178 | 179 | def __call__(self, past_key_values, num_of_token=None, attentions=None): 180 | if past_key_values is None: 181 | return None 182 | attn_score = [attention for attention in attentions] 183 | past_key_values_new = tuple(x for x in past_key_values) 184 | seq_len = past_key_values_new[0][0].size(self.k_seq_dim) 185 | # update attn score 186 | attn_score = torch.cat(attn_score, dim=0) 187 | attn_score = attn_score.mean(dim=1, keepdim=False).mean(dim=0, keepdim=False) 188 | 189 | if attn_score.shape[-2] > 1: 190 | assert self.flag is True # only use for the first time 191 | for idx in range(attn_score.shape[-1]): 192 | cur_score = attn_score[idx][:idx+1] 193 | self.score_sum[:len(cur_score)] += cur_score 194 | else: 195 | attn_score = attn_score.squeeze(0) 196 | self.score_sum[:seq_len] += attn_score 197 | 198 | forget_num = int(seq_len - num_of_token * (1 - self.ratio)) 199 | self.protect_size = 1 200 | if forget_num <= 0: 201 | return past_key_values_new 202 | else: 203 | if forget_num > 1: 204 | assert self.flag is True 205 | self.flag = False 206 | selected_idx = torch.where(torch.argsort(self.score_sum[:(seq_len - self.protect_size)]) > forget_num)[0] 207 | selected_idx = torch.cat([selected_idx, torch.arange(seq_len - self.protect_size, seq_len).cuda()], dim=0) 208 | past_key_values_return = [] 209 | for k, v in past_key_values_new: 210 | k_new = k.gather(dim=-2, index=selected_idx.view(1,1,-1,1).expand(k.shape[0], k.shape[1], -1 ,k.shape[-1])) 211 | v_new = v.gather(dim=-2, index=selected_idx.view(1,1,-1,1).expand(v.shape[0], v.shape[1], -1 ,v.shape[-1])) 212 | past_key_values_return.append([k_new, v_new]) 213 | 214 | return past_key_values_return 215 | else: 216 | selected_idx = self.score_sum[self.start_size:(seq_len - self.protect_size)].argmin() + self.start_size 217 | self.score_sum[(selected_idx):-1] = self.score_sum[(selected_idx+1):].clone() 218 | 219 | return [[torch.cat([self.k_slice(k, 0, selected_idx), self.k_slice(k, (selected_idx+1), seq_len),], 220 | dim=self.k_seq_dim,), 221 | torch.cat([self.v_slice(v, 0, selected_idx), self.v_slice(v, (selected_idx+1), seq_len),], 222 | dim=self.v_seq_dim,)] 223 | for k, v in past_key_values_new] -------------------------------------------------------------------------------- /kv_cache_qwen.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import random 3 | 4 | def slice2d(x, start, end): 5 | return x[:, :, start:end, ...] 6 | 7 | 8 | def slice3d(x, start, end): 9 | return x[:, :, :, start:end, ...] 10 | 11 | 12 | def slice1d(x, start, end): 13 | return x[:, start:end, ...] 
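# Illustrative note (an inference, based on the dim=1 gather/scatter calls later in
# this file and on eval_*_qwen.py passing k_seq_dim = v_seq_dim = 1): Qwen-VL stores
# cached keys/values as (batch, seq_len, num_heads, head_dim), so the sequence axis is
# 1 here rather than 2 as in kv_cache.py; apart from that layout difference the
# eviction logic mirrors the LLaVA version. For example:
#
#     k = torch.randn(1, 600, 32, 128)             # hypothetical Qwen-style cached keys
#     recent = DIM_TO_SLICE[1](k, 600 - 512, 600)  # keep the most recent 512 tokens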
14 | 15 | import torch.nn.functional as F 16 | DIM_TO_SLICE = { 17 | 1: slice1d, 18 | 2: slice2d, 19 | 3: slice3d, 20 | } 21 | 22 | class ElasticCache: 23 | def __init__( 24 | self, 25 | start_size=4, 26 | recent_size=512, 27 | k_seq_dim=2, 28 | v_seq_dim=2, 29 | ratio=0., 30 | distance=-25, 31 | layer_num=40, 32 | ): 33 | self.start_size = start_size 34 | self.recent_size = recent_size 35 | self.cache_size = start_size + recent_size 36 | self.k_seq_dim = k_seq_dim 37 | self.v_seq_dim = v_seq_dim 38 | self.k_slice = DIM_TO_SLICE[k_seq_dim] 39 | self.v_slice = DIM_TO_SLICE[v_seq_dim] 40 | 41 | self.score_sum = torch.zeros(layer_num, self.cache_size + 1).cuda() 42 | self.ratio = ratio 43 | self.protect_size = 1 44 | self.lazy_size = 0 45 | self.flag = True 46 | self.distance = distance 47 | self.layer_num = layer_num 48 | 49 | self.selected_idx = 0 50 | 51 | def __call__(self, past_key_values, num_of_token=None, attentions=None): 52 | if past_key_values is None: 53 | return None 54 | attn_score = [attention for attention in attentions] 55 | seq_len = past_key_values[0][0].size(self.k_seq_dim) 56 | 57 | # update attn score 58 | attn_score = torch.cat(attn_score, dim=0) 59 | attn_score = attn_score.mean(dim=1, keepdim=False) 60 | 61 | if attn_score.shape[-2] > 1: 62 | assert self.flag is True # only use for the first time 63 | for idx in range(attn_score.shape[-1]): 64 | cur_score = attn_score[:, idx, :idx+1] 65 | self.score_sum[:, :(cur_score.shape[-1])] += cur_score 66 | else: 67 | pass 68 | 69 | forget_num = int(seq_len - num_of_token * (1 - self.ratio)) 70 | if forget_num <= 0: 71 | return past_key_values 72 | else: 73 | if forget_num > 1: 74 | assert self.flag is True 75 | self.flag = False 76 | 77 | selected_idx_all = [] 78 | merge_idx_all = [] 79 | throw_idx_all = [] 80 | for idx in range(self.layer_num): 81 | selected_idx = torch.where(torch.argsort(self.score_sum[idx, self.start_size:(seq_len - self.protect_size)]) > forget_num)[0] + self.start_size 82 | throw_idx = torch.where(torch.argsort(self.score_sum[idx, self.start_size:(seq_len - self.protect_size)]) <= forget_num)[0] 83 | merge_idx = [] 84 | for i in range(len(throw_idx)): 85 | merge_idx.append(selected_idx[torch.abs((selected_idx - throw_idx[i])).argmin()].unsqueeze(0)) 86 | merge_idx = torch.cat(merge_idx) 87 | 88 | selected_idx = torch.cat([torch.arange(self.start_size).cuda(), selected_idx, torch.tensor([seq_len - self.protect_size]).cuda()], dim=0) # the last token is always kept 89 | 90 | selected_idx_all.append(selected_idx) 91 | merge_idx_all.append(merge_idx) 92 | throw_idx_all.append(throw_idx) 93 | 94 | if self.distance > 0: 95 | self.selected_idx = self.distance 96 | else: 97 | self.selected_idx = seq_len - forget_num + self.distance 98 | 99 | past_key_values_return = [] 100 | for idx, (k, v) in enumerate(past_key_values): 101 | selected_idx = selected_idx_all[idx] 102 | merge_idx = merge_idx_all[idx] 103 | throw_idx = throw_idx_all[idx] 104 | 105 | k_forget = k.gather(dim=1, index=throw_idx.view(1,-1,1,1).expand(k.shape[0], -1, k.shape[2], k.shape[-1])) 106 | v_forget = v.gather(dim=1, index=throw_idx.view(1,-1,1,1).expand(v.shape[0], -1, v.shape[2], v.shape[-1])) 107 | 108 | k = k.scatter_reduce(1, merge_idx.view(1,-1,1,1).expand(k.shape[0], -1, k.shape[2], k.shape[-1]), k_forget, 'mean') 109 | v = v.scatter_reduce(1, merge_idx.view(1,-1,1,1).expand(v.shape[0], -1, v.shape[2], v.shape[-1]), v_forget, 'mean') 110 | 111 | k_new = k.gather(dim=1, index=selected_idx.view(1,-1,1,1).expand(k.shape[0], -1, k.shape[2] 
,k.shape[-1])) 112 | v_new = v.gather(dim=1, index=selected_idx.view(1,-1,1,1).expand(v.shape[0], -1, v.shape[2] ,v.shape[-1])) 113 | 114 | past_key_values_return.append([k_new, v_new]) 115 | return past_key_values_return 116 | else: 117 | selected_idx = self.selected_idx 118 | 119 | return [[torch.cat([self.k_slice(k, 0, selected_idx), self.k_slice(k, (selected_idx+1), seq_len),], 120 | dim=self.k_seq_dim,), 121 | torch.cat([self.v_slice(v, 0, selected_idx), self.v_slice(v, (selected_idx+1), seq_len),], 122 | dim=self.v_seq_dim,)] 123 | for k, v in past_key_values] 124 | 125 | 126 | class LocalCache: 127 | def __init__( 128 | self, 129 | start_size=4, 130 | recent_size=512, 131 | k_seq_dim=2, 132 | v_seq_dim=2, 133 | ratio=0. 134 | ): 135 | self.start_size = start_size 136 | self.recent_size = recent_size 137 | self.cache_size = start_size + recent_size 138 | self.k_seq_dim = k_seq_dim 139 | self.v_seq_dim = v_seq_dim 140 | self.k_slice = DIM_TO_SLICE[k_seq_dim] 141 | self.v_slice = DIM_TO_SLICE[v_seq_dim] 142 | self.ratio = ratio 143 | 144 | def __call__(self, past_key_values, num_of_token=None, attentions=None): 145 | if past_key_values is None: 146 | return None 147 | seq_len = past_key_values[0][0].size(self.k_seq_dim) 148 | 149 | forget_num = int(seq_len - num_of_token * (1 - self.ratio)) 150 | if forget_num <= 0: 151 | return past_key_values 152 | else: 153 | return [[torch.cat([self.k_slice(k, 0, self.start_size), self.k_slice(k, forget_num + self.start_size, seq_len),], 154 | dim=self.k_seq_dim,), 155 | torch.cat([self.v_slice(v, 0, self.start_size), self.v_slice(v, forget_num + self.start_size, seq_len),], 156 | dim=self.v_seq_dim,),] 157 | for k, v in past_key_values] 158 | 159 | 160 | class H2OCache: 161 | def __init__( 162 | self, 163 | start_size=4, 164 | recent_size=512, 165 | k_seq_dim=2, 166 | v_seq_dim=2, 167 | ratio=0. 
168 | ): 169 | self.start_size = start_size 170 | self.recent_size = recent_size 171 | self.cache_size = start_size + recent_size 172 | self.k_seq_dim = k_seq_dim 173 | self.v_seq_dim = v_seq_dim 174 | self.k_slice = DIM_TO_SLICE[k_seq_dim] 175 | self.v_slice = DIM_TO_SLICE[v_seq_dim] 176 | 177 | self.score_sum = torch.zeros(self.cache_size + 1).cuda() 178 | self.ratio = ratio 179 | self.protect_size = 1 180 | self.flag = True 181 | 182 | def __call__(self, past_key_values, num_of_token=None, attentions=None): 183 | if past_key_values is None: 184 | return None 185 | attn_score = [attention for attention in attentions] 186 | past_key_values_new = tuple(x for x in past_key_values) 187 | seq_len = past_key_values_new[0][0].size(self.k_seq_dim) 188 | # update attn score 189 | attn_score = torch.cat(attn_score, dim=0) 190 | attn_score = attn_score.mean(dim=1, keepdim=False).mean(dim=0, keepdim=False) 191 | 192 | if attn_score.shape[-2] > 1: 193 | assert self.flag is True # only use for the first time 194 | for idx in range(attn_score.shape[-1]): 195 | cur_score = attn_score[idx][:idx+1] 196 | self.score_sum[:len(cur_score)] += cur_score 197 | else: 198 | attn_score = attn_score.squeeze(0) 199 | self.score_sum[:seq_len] += attn_score 200 | 201 | forget_num = int(seq_len - num_of_token * (1 - self.ratio)) 202 | if forget_num <= 0: 203 | return past_key_values_new 204 | else: 205 | if forget_num > 1: 206 | assert self.flag is True 207 | self.flag = False 208 | selected_idx = torch.where(torch.argsort(self.score_sum[:(seq_len - self.protect_size)]) > forget_num)[0] 209 | selected_idx = torch.cat([selected_idx, torch.arange(seq_len - self.protect_size, seq_len).cuda()], dim=0) 210 | past_key_values_return = [] 211 | for k, v in past_key_values_new: 212 | k_new = k.gather(dim=1, index=selected_idx.view(1,-1,1,1).expand(k.shape[0], -1, k.shape[2], k.shape[-1])) 213 | v_new = v.gather(dim=1, index=selected_idx.view(1,-1,1,1).expand(v.shape[0], -1, v.shape[2], v.shape[-1])) 214 | past_key_values_return.append([k_new, v_new]) 215 | 216 | return past_key_values_return 217 | else: 218 | selected_idx = self.score_sum[self.start_size:(seq_len - self.protect_size)].argmin() + self.start_size 219 | self.score_sum[(selected_idx):-1] = self.score_sum[(selected_idx+1):].clone() 220 | 221 | return [[torch.cat([self.k_slice(k, 0, selected_idx), self.k_slice(k, (selected_idx+1), seq_len),], 222 | dim=self.k_seq_dim,), 223 | torch.cat([self.v_slice(v, 0, selected_idx), self.v_slice(v, (selected_idx+1), seq_len),], 224 | dim=self.v_seq_dim,)] 225 | for k, v in past_key_values_new] -------------------------------------------------------------------------------- /llava/__init__.py: -------------------------------------------------------------------------------- 1 | from .model import LlavaLlamaForCausalLM 2 | -------------------------------------------------------------------------------- /llava/constants.py: -------------------------------------------------------------------------------- 1 | CONTROLLER_HEART_BEAT_EXPIRATION = 30 2 | WORKER_HEART_BEAT_INTERVAL = 15 3 | 4 | LOGDIR = "." 
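# Clarifying comment (the token literals are assumptions taken from upstream LLaVA):
# the model constants below are sentinels, not real vocabulary ids. IGNORE_INDEX (-100)
# masks positions out of the loss; IMAGE_TOKEN_INDEX (-200) is the placeholder id that
# tokenizer_image_token() in llava/mm_utils.py splices into input_ids wherever the
# image tag (the literal "<image>" upstream) appears in the prompt; the image start/end
# tokens ("<im_start>"/"<im_end>" upstream) wrap image features only when
# mm_use_im_start_end is enabled in the model config (see llava/model/builder.py).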
5 | 6 | # Model Constants 7 | IGNORE_INDEX = -100 8 | IMAGE_TOKEN_INDEX = -200 9 | DEFAULT_IMAGE_TOKEN = "" 10 | DEFAULT_IMAGE_PATCH_TOKEN = "" 11 | DEFAULT_IM_START_TOKEN = "" 12 | DEFAULT_IM_END_TOKEN = "" 13 | IMAGE_PLACEHOLDER = "" 14 | -------------------------------------------------------------------------------- /llava/mm_utils.py: -------------------------------------------------------------------------------- 1 | from PIL import Image 2 | from io import BytesIO 3 | import base64 4 | 5 | import torch 6 | from transformers import StoppingCriteria 7 | from llava.constants import IMAGE_TOKEN_INDEX 8 | 9 | 10 | def load_image_from_base64(image): 11 | return Image.open(BytesIO(base64.b64decode(image))) 12 | 13 | 14 | def expand2square(pil_img, background_color): 15 | width, height = pil_img.size 16 | if width == height: 17 | return pil_img 18 | elif width > height: 19 | result = Image.new(pil_img.mode, (width, width), background_color) 20 | result.paste(pil_img, (0, (width - height) // 2)) 21 | return result 22 | else: 23 | result = Image.new(pil_img.mode, (height, height), background_color) 24 | result.paste(pil_img, ((height - width) // 2, 0)) 25 | return result 26 | 27 | 28 | def process_images(images, image_processor, model_cfg): 29 | image_aspect_ratio = getattr(model_cfg, "image_aspect_ratio", None) 30 | new_images = [] 31 | if image_aspect_ratio == 'pad': 32 | for image in images: 33 | image = expand2square(image, tuple(int(x*255) for x in image_processor.image_mean)) 34 | image = image_processor.preprocess(image, return_tensors='pt')['pixel_values'][0] 35 | new_images.append(image) 36 | else: 37 | return image_processor(images, return_tensors='pt')['pixel_values'] 38 | if all(x.shape == new_images[0].shape for x in new_images): 39 | new_images = torch.stack(new_images, dim=0) 40 | return new_images 41 | 42 | 43 | def tokenizer_image_token(prompt, tokenizer, image_token_index=IMAGE_TOKEN_INDEX, return_tensors=None): 44 | prompt_chunks = [tokenizer(chunk).input_ids for chunk in prompt.split('')] 45 | 46 | def insert_separator(X, sep): 47 | return [ele for sublist in zip(X, [sep]*len(X)) for ele in sublist][:-1] 48 | 49 | input_ids = [] 50 | offset = 0 51 | if len(prompt_chunks) > 0 and len(prompt_chunks[0]) > 0 and prompt_chunks[0][0] == tokenizer.bos_token_id: 52 | offset = 1 53 | input_ids.append(prompt_chunks[0][0]) 54 | 55 | for x in insert_separator(prompt_chunks, [image_token_index] * (offset + 1)): 56 | input_ids.extend(x[offset:]) 57 | 58 | if return_tensors is not None: 59 | if return_tensors == 'pt': 60 | return torch.tensor(input_ids, dtype=torch.long) 61 | raise ValueError(f'Unsupported tensor type: {return_tensors}') 62 | return input_ids 63 | 64 | 65 | def get_model_name_from_path(model_path): 66 | model_path = model_path.strip("/") 67 | model_paths = model_path.split("/") 68 | if model_paths[-1].startswith('checkpoint-'): 69 | return model_paths[-2] + "_" + model_paths[-1] 70 | else: 71 | return model_paths[-1] 72 | 73 | 74 | 75 | 76 | class KeywordsStoppingCriteria(StoppingCriteria): 77 | def __init__(self, keywords, tokenizer, input_ids): 78 | self.keywords = keywords 79 | self.keyword_ids = [] 80 | self.max_keyword_len = 0 81 | for keyword in keywords: 82 | cur_keyword_ids = tokenizer(keyword).input_ids 83 | if len(cur_keyword_ids) > 1 and cur_keyword_ids[0] == tokenizer.bos_token_id: 84 | cur_keyword_ids = cur_keyword_ids[1:] 85 | if len(cur_keyword_ids) > self.max_keyword_len: 86 | self.max_keyword_len = len(cur_keyword_ids) 87 | 
self.keyword_ids.append(torch.tensor(cur_keyword_ids)) 88 | self.tokenizer = tokenizer 89 | self.start_len = input_ids.shape[1] 90 | 91 | def __call__(self, output_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool: 92 | assert output_ids.shape[0] == 1, "Only support batch size 1 (yet)" # TODO 93 | offset = min(output_ids.shape[1] - self.start_len, self.max_keyword_len) 94 | self.keyword_ids = [keyword_id.to(output_ids.device) for keyword_id in self.keyword_ids] 95 | for keyword_id in self.keyword_ids: 96 | if (output_ids[0, -keyword_id.shape[0]:] == keyword_id).all(): 97 | return True 98 | outputs = self.tokenizer.batch_decode(output_ids[:, -offset:], skip_special_tokens=True)[0] 99 | for keyword in self.keywords: 100 | if keyword in outputs: 101 | return True 102 | return False -------------------------------------------------------------------------------- /llava/model/__init__.py: -------------------------------------------------------------------------------- 1 | from .language_model.llava_llama import LlavaLlamaForCausalLM, LlavaConfig 2 | from .language_model.llava_mpt import LlavaMPTForCausalLM, LlavaMPTConfig 3 | -------------------------------------------------------------------------------- /llava/model/apply_delta.py: -------------------------------------------------------------------------------- 1 | """ 2 | Usage: 3 | python3 -m fastchat.model.apply_delta --base ~/model_weights/llama-7b --target ~/model_weights/vicuna-7b --delta lmsys/vicuna-7b-delta 4 | """ 5 | import argparse 6 | 7 | import torch 8 | from tqdm import tqdm 9 | from transformers import AutoTokenizer, AutoModelForCausalLM 10 | from llava import LlavaLlamaForCausalLM 11 | 12 | 13 | def apply_delta(base_model_path, target_model_path, delta_path): 14 | print("Loading base model") 15 | base = AutoModelForCausalLM.from_pretrained( 16 | base_model_path, torch_dtype=torch.float16, low_cpu_mem_usage=True) 17 | 18 | print("Loading delta") 19 | delta = LlavaLlamaForCausalLM.from_pretrained(delta_path, torch_dtype=torch.float16, low_cpu_mem_usage=True) 20 | delta_tokenizer = AutoTokenizer.from_pretrained(delta_path) 21 | 22 | print("Applying delta") 23 | for name, param in tqdm(delta.state_dict().items(), desc="Applying delta"): 24 | if name not in base.state_dict(): 25 | assert name in ['model.mm_projector.weight', 'model.mm_projector.bias'], f'{name} not in base model' 26 | continue 27 | if param.data.shape == base.state_dict()[name].shape: 28 | param.data += base.state_dict()[name] 29 | else: 30 | assert name in ['model.embed_tokens.weight', 'lm_head.weight'], \ 31 | f'{name} dimension mismatch: {param.data.shape} vs {base.state_dict()[name].shape}' 32 | bparam = base.state_dict()[name] 33 | param.data[:bparam.shape[0], :bparam.shape[1]] += bparam 34 | 35 | print("Saving target model") 36 | delta.save_pretrained(target_model_path) 37 | delta_tokenizer.save_pretrained(target_model_path) 38 | 39 | 40 | if __name__ == "__main__": 41 | parser = argparse.ArgumentParser() 42 | parser.add_argument("--base-model-path", type=str, required=True) 43 | parser.add_argument("--target-model-path", type=str, required=True) 44 | parser.add_argument("--delta-path", type=str, required=True) 45 | 46 | args = parser.parse_args() 47 | 48 | apply_delta(args.base_model_path, args.target_model_path, args.delta_path) 49 | -------------------------------------------------------------------------------- /llava/model/builder.py: -------------------------------------------------------------------------------- 1 | # Copyright 2023 
Haotian Liu 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | 16 | import os 17 | import warnings 18 | import shutil 19 | 20 | from transformers import AutoTokenizer, AutoModelForCausalLM, AutoConfig, BitsAndBytesConfig 21 | import torch 22 | from llava.model import * 23 | from llava.constants import DEFAULT_IMAGE_PATCH_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN 24 | 25 | 26 | def load_pretrained_model(model_path, model_base, model_name, load_8bit=False, load_4bit=False, device_map="auto", device="cuda"): 27 | kwargs = {"device_map": device_map} 28 | 29 | if device != "cuda": 30 | kwargs['device_map'] = {"": device} 31 | 32 | if load_8bit: 33 | kwargs['load_in_8bit'] = True 34 | elif load_4bit: 35 | kwargs['load_in_4bit'] = True 36 | kwargs['quantization_config'] = BitsAndBytesConfig( 37 | load_in_4bit=True, 38 | bnb_4bit_compute_dtype=torch.float16, 39 | bnb_4bit_use_double_quant=True, 40 | bnb_4bit_quant_type='nf4' 41 | ) 42 | else: 43 | # kwargs['torch_dtype'] = torch.float16 44 | kwargs['torch_dtype'] = torch.bfloat16 45 | 46 | if 'llava' in model_name.lower(): 47 | # Load LLaVA model 48 | if 'lora' in model_name.lower() and model_base is None: 49 | warnings.warn('There is `lora` in model name but no `model_base` is provided. If you are loading a LoRA model, please provide the `model_base` argument. 
Detailed instruction: https://github.com/haotian-liu/LLaVA#launch-a-model-worker-lora-weights-unmerged.') 50 | if 'lora' in model_name.lower() and model_base is not None: 51 | lora_cfg_pretrained = AutoConfig.from_pretrained(model_path) 52 | tokenizer = AutoTokenizer.from_pretrained(model_base, use_fast=False) 53 | print('Loading LLaVA from base model...') 54 | model = LlavaLlamaForCausalLM.from_pretrained(model_base, low_cpu_mem_usage=True, config=lora_cfg_pretrained, **kwargs) 55 | token_num, tokem_dim = model.lm_head.out_features, model.lm_head.in_features 56 | if model.lm_head.weight.shape[0] != token_num: 57 | model.lm_head.weight = torch.nn.Parameter(torch.empty(token_num, tokem_dim, device=model.device, dtype=model.dtype)) 58 | model.model.embed_tokens.weight = torch.nn.Parameter(torch.empty(token_num, tokem_dim, device=model.device, dtype=model.dtype)) 59 | 60 | print('Loading additional LLaVA weights...') 61 | if os.path.exists(os.path.join(model_path, 'non_lora_trainables.bin')): 62 | non_lora_trainables = torch.load(os.path.join(model_path, 'non_lora_trainables.bin'), map_location='cpu') 63 | else: 64 | # this is probably from HF Hub 65 | from huggingface_hub import hf_hub_download 66 | def load_from_hf(repo_id, filename, subfolder=None): 67 | cache_file = hf_hub_download( 68 | repo_id=repo_id, 69 | filename=filename, 70 | subfolder=subfolder) 71 | return torch.load(cache_file, map_location='cpu') 72 | non_lora_trainables = load_from_hf(model_path, 'non_lora_trainables.bin') 73 | non_lora_trainables = {(k[11:] if k.startswith('base_model.') else k): v for k, v in non_lora_trainables.items()} 74 | if any(k.startswith('model.model.') for k in non_lora_trainables): 75 | non_lora_trainables = {(k[6:] if k.startswith('model.') else k): v for k, v in non_lora_trainables.items()} 76 | model.load_state_dict(non_lora_trainables, strict=False) 77 | 78 | from peft import PeftModel 79 | print('Loading LoRA weights...') 80 | model = PeftModel.from_pretrained(model, model_path) 81 | print('Merging LoRA weights...') 82 | model = model.merge_and_unload() 83 | print('Model is loaded...') 84 | elif model_base is not None: 85 | # this may be mm projector only 86 | print('Loading LLaVA from base model...') 87 | if 'mpt' in model_name.lower(): 88 | if not os.path.isfile(os.path.join(model_path, 'configuration_mpt.py')): 89 | shutil.copyfile(os.path.join(model_base, 'configuration_mpt.py'), os.path.join(model_path, 'configuration_mpt.py')) 90 | tokenizer = AutoTokenizer.from_pretrained(model_base, use_fast=True) 91 | cfg_pretrained = AutoConfig.from_pretrained(model_path, trust_remote_code=True) 92 | model = LlavaMPTForCausalLM.from_pretrained(model_base, low_cpu_mem_usage=True, config=cfg_pretrained, **kwargs) 93 | else: 94 | tokenizer = AutoTokenizer.from_pretrained(model_base, use_fast=False) 95 | cfg_pretrained = AutoConfig.from_pretrained(model_path) 96 | model = LlavaLlamaForCausalLM.from_pretrained(model_base, low_cpu_mem_usage=True, config=cfg_pretrained, **kwargs) 97 | 98 | mm_projector_weights = torch.load(os.path.join(model_path, 'mm_projector.bin'), map_location='cpu') 99 | mm_projector_weights = {k: v.to(torch.float16) for k, v in mm_projector_weights.items()} 100 | model.load_state_dict(mm_projector_weights, strict=False) 101 | else: 102 | if 'mpt' in model_name.lower(): 103 | tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=True) 104 | model = LlavaMPTForCausalLM.from_pretrained(model_path, low_cpu_mem_usage=True, **kwargs) 105 | else: 106 | tokenizer = 
AutoTokenizer.from_pretrained(model_path, use_fast=False) 107 | model = LlavaLlamaForCausalLM.from_pretrained(model_path, low_cpu_mem_usage=True, **kwargs) 108 | else: 109 | # Load language model 110 | if model_base is not None: 111 | # PEFT model 112 | from peft import PeftModel 113 | tokenizer = AutoTokenizer.from_pretrained(model_base, use_fast=False) 114 | model = AutoModelForCausalLM.from_pretrained(model_base, low_cpu_mem_usage=True, **kwargs) 115 | print(f"Loading LoRA weights from {model_path}") 116 | model = PeftModel.from_pretrained(model, model_path) 117 | print(f"Merging weights") 118 | model = model.merge_and_unload() 119 | print('Convert to FP16...') 120 | model.to(torch.float16) 121 | else: 122 | use_fast = False 123 | if 'mpt' in model_name.lower(): 124 | tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=True) 125 | model = AutoModelForCausalLM.from_pretrained(model_path, low_cpu_mem_usage=True, trust_remote_code=True, **kwargs) 126 | else: 127 | tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=False) 128 | model = AutoModelForCausalLM.from_pretrained(model_path, low_cpu_mem_usage=True, **kwargs) 129 | 130 | image_processor = None 131 | 132 | if 'llava' in model_name.lower(): 133 | mm_use_im_start_end = getattr(model.config, "mm_use_im_start_end", False) 134 | mm_use_im_patch_token = getattr(model.config, "mm_use_im_patch_token", True) 135 | if mm_use_im_patch_token: 136 | tokenizer.add_tokens([DEFAULT_IMAGE_PATCH_TOKEN], special_tokens=True) 137 | if mm_use_im_start_end: 138 | tokenizer.add_tokens([DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN], special_tokens=True) 139 | model.resize_token_embeddings(len(tokenizer)) 140 | 141 | vision_tower = model.get_vision_tower() 142 | if not vision_tower.is_loaded: 143 | vision_tower.load_model() 144 | vision_tower.to(device=device, dtype=torch.float16) 145 | image_processor = vision_tower.image_processor 146 | 147 | if hasattr(model.config, "max_sequence_length"): 148 | context_len = model.config.max_sequence_length 149 | else: 150 | context_len = 2048 151 | 152 | return tokenizer, model, image_processor, context_len 153 | -------------------------------------------------------------------------------- /llava/model/consolidate.py: -------------------------------------------------------------------------------- 1 | """ 2 | Usage: 3 | python3 -m llava.model.consolidate --src ~/model_weights/llava-7b --dst ~/model_weights/llava-7b_consolidate 4 | """ 5 | import argparse 6 | 7 | import torch 8 | from transformers import AutoTokenizer, AutoModelForCausalLM 9 | from llava.model import * 10 | from llava.model.utils import auto_upgrade 11 | 12 | 13 | def consolidate_ckpt(src_path, dst_path): 14 | print("Loading model") 15 | auto_upgrade(src_path) 16 | src_model = AutoModelForCausalLM.from_pretrained(src_path, torch_dtype=torch.float16, low_cpu_mem_usage=True) 17 | src_tokenizer = AutoTokenizer.from_pretrained(src_path, use_fast=False) 18 | src_model.save_pretrained(dst_path) 19 | src_tokenizer.save_pretrained(dst_path) 20 | 21 | 22 | if __name__ == "__main__": 23 | parser = argparse.ArgumentParser() 24 | parser.add_argument("--src", type=str, required=True) 25 | parser.add_argument("--dst", type=str, required=True) 26 | 27 | args = parser.parse_args() 28 | 29 | consolidate_ckpt(args.src, args.dst) 30 | -------------------------------------------------------------------------------- /llava/model/language_model/llava_llama.py: -------------------------------------------------------------------------------- 1 | # 
Copyright 2023 Haotian Liu 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | 16 | from typing import List, Optional, Tuple, Union 17 | 18 | import torch 19 | import torch.nn as nn 20 | from torch.nn import CrossEntropyLoss 21 | 22 | from transformers import AutoConfig, AutoModelForCausalLM, \ 23 | LlamaConfig, LlamaModel, LlamaForCausalLM 24 | 25 | from transformers.modeling_outputs import CausalLMOutputWithPast 26 | 27 | from ..llava_arch import LlavaMetaModel, LlavaMetaForCausalLM 28 | 29 | 30 | class LlavaConfig(LlamaConfig): 31 | model_type = "llava" 32 | 33 | 34 | class LlavaLlamaModel(LlavaMetaModel, LlamaModel): 35 | config_class = LlavaConfig 36 | 37 | def __init__(self, config: LlamaConfig): 38 | super(LlavaLlamaModel, self).__init__(config) 39 | 40 | 41 | class LlavaLlamaForCausalLM(LlamaForCausalLM, LlavaMetaForCausalLM): 42 | config_class = LlavaConfig 43 | 44 | def __init__(self, config): 45 | super(LlamaForCausalLM, self).__init__(config) 46 | self.model = LlavaLlamaModel(config) 47 | 48 | self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) 49 | 50 | # Initialize weights and apply final processing 51 | self.post_init() 52 | 53 | def get_model(self): 54 | return self.model 55 | 56 | def forward( 57 | self, 58 | input_ids: torch.LongTensor = None, 59 | attention_mask: Optional[torch.Tensor] = None, 60 | past_key_values: Optional[List[torch.FloatTensor]] = None, 61 | inputs_embeds: Optional[torch.FloatTensor] = None, 62 | labels: Optional[torch.LongTensor] = None, 63 | use_cache: Optional[bool] = None, 64 | output_attentions: Optional[bool] = None, 65 | output_hidden_states: Optional[bool] = None, 66 | images: Optional[torch.FloatTensor] = None, 67 | return_dict: Optional[bool] = None, 68 | ) -> Union[Tuple, CausalLMOutputWithPast]: 69 | output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions 70 | output_hidden_states = ( 71 | output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states 72 | ) 73 | return_dict = return_dict if return_dict is not None else self.config.use_return_dict 74 | 75 | input_ids, attention_mask, past_key_values, inputs_embeds, labels = self.prepare_inputs_labels_for_multimodal(input_ids, attention_mask, past_key_values, labels, images) 76 | 77 | # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn) 78 | outputs = self.model( 79 | input_ids=input_ids, 80 | attention_mask=attention_mask, 81 | past_key_values=past_key_values, 82 | inputs_embeds=inputs_embeds, 83 | use_cache=use_cache, 84 | output_attentions=output_attentions, 85 | output_hidden_states=output_hidden_states, 86 | return_dict=return_dict 87 | ) 88 | 89 | hidden_states = outputs[0] 90 | logits = self.lm_head(hidden_states) 91 | 92 | loss = None 93 | if labels is not None: 94 | # Shift so that tokens < n predict n 95 | shift_logits = logits[..., :-1, :].contiguous() 96 | shift_labels = labels[..., 1:].contiguous() 97 | # 
Flatten the tokens 98 | loss_fct = CrossEntropyLoss() 99 | shift_logits = shift_logits.view(-1, self.config.vocab_size) 100 | shift_labels = shift_labels.view(-1) 101 | # Enable model/pipeline parallelism 102 | shift_labels = shift_labels.to(shift_logits.device) 103 | loss = loss_fct(shift_logits, shift_labels) 104 | 105 | if not return_dict: 106 | output = (logits,) + outputs[1:] 107 | return (loss,) + output if loss is not None else output 108 | 109 | return CausalLMOutputWithPast( 110 | loss=loss, 111 | logits=logits, 112 | past_key_values=outputs.past_key_values, 113 | hidden_states=outputs.hidden_states, 114 | attentions=outputs.attentions, 115 | ) 116 | 117 | def prepare_inputs_for_generation( 118 | self, input_ids, past_key_values=None, attention_mask=None, inputs_embeds=None, **kwargs 119 | ): 120 | if past_key_values: 121 | input_ids = input_ids[:, -1:] 122 | 123 | # if `inputs_embeds` are passed, we only want to use them in the 1st generation step 124 | if inputs_embeds is not None and past_key_values is None: 125 | model_inputs = {"inputs_embeds": inputs_embeds} 126 | else: 127 | model_inputs = {"input_ids": input_ids} 128 | 129 | model_inputs.update( 130 | { 131 | "past_key_values": past_key_values, 132 | "use_cache": kwargs.get("use_cache"), 133 | "attention_mask": attention_mask, 134 | "images": kwargs.get("images", None), 135 | } 136 | ) 137 | return model_inputs 138 | 139 | AutoConfig.register("llava", LlavaConfig) 140 | AutoModelForCausalLM.register(LlavaConfig, LlavaLlamaForCausalLM) 141 | -------------------------------------------------------------------------------- /llava/model/language_model/llava_mistral.py: -------------------------------------------------------------------------------- 1 | # Copyright 2023 Haotian Liu 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
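# Usage sketch (illustrative; the variable names are placeholders, not part of this
# file): as with llava_llama.py above, multimodal inputs are routed through
# prepare_inputs_labels_for_multimodal() from llava/model/llava_arch.py, which swaps
# image placeholder tokens for projected vision-encoder features before the base
# MistralForCausalLM forward/generate runs, so a caller only passes images alongside
# the token ids:
#
#     out = model.generate(inputs=input_ids, images=image_tensor,
#                          image_sizes=image_sizes, max_new_tokens=64)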
14 | 15 | 16 | from typing import List, Optional, Tuple, Union 17 | 18 | import torch 19 | import torch.nn as nn 20 | from torch.nn import CrossEntropyLoss 21 | 22 | from transformers import AutoConfig, AutoModelForCausalLM, \ 23 | MistralConfig, MistralModel, MistralForCausalLM 24 | 25 | from transformers.modeling_outputs import CausalLMOutputWithPast 26 | from transformers.generation.utils import GenerateOutput 27 | 28 | from ..llava_arch import LlavaMetaModel, LlavaMetaForCausalLM 29 | 30 | 31 | class LlavaMistralConfig(MistralConfig): 32 | model_type = "llava_mistral" 33 | 34 | 35 | class LlavaMistralModel(LlavaMetaModel, MistralModel): 36 | config_class = LlavaMistralConfig 37 | 38 | def __init__(self, config: MistralConfig): 39 | super(LlavaMistralModel, self).__init__(config) 40 | 41 | 42 | class LlavaMistralForCausalLM(MistralForCausalLM, LlavaMetaForCausalLM): 43 | config_class = LlavaMistralConfig 44 | 45 | def __init__(self, config): 46 | super(MistralForCausalLM, self).__init__(config) 47 | self.model = LlavaMistralModel(config) 48 | 49 | self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) 50 | 51 | # Initialize weights and apply final processing 52 | self.post_init() 53 | 54 | def get_model(self): 55 | return self.model 56 | 57 | def forward( 58 | self, 59 | input_ids: torch.LongTensor = None, 60 | attention_mask: Optional[torch.Tensor] = None, 61 | position_ids: Optional[torch.LongTensor] = None, 62 | past_key_values: Optional[List[torch.FloatTensor]] = None, 63 | inputs_embeds: Optional[torch.FloatTensor] = None, 64 | labels: Optional[torch.LongTensor] = None, 65 | use_cache: Optional[bool] = None, 66 | output_attentions: Optional[bool] = None, 67 | output_hidden_states: Optional[bool] = None, 68 | images: Optional[torch.FloatTensor] = None, 69 | image_sizes: Optional[List[List[int]]] = None, 70 | return_dict: Optional[bool] = None, 71 | ) -> Union[Tuple, CausalLMOutputWithPast]: 72 | 73 | if inputs_embeds is None: 74 | ( 75 | input_ids, 76 | position_ids, 77 | attention_mask, 78 | past_key_values, 79 | inputs_embeds, 80 | labels 81 | ) = self.prepare_inputs_labels_for_multimodal( 82 | input_ids, 83 | position_ids, 84 | attention_mask, 85 | past_key_values, 86 | labels, 87 | images, 88 | image_sizes 89 | ) 90 | 91 | return super().forward( 92 | input_ids=input_ids, 93 | attention_mask=attention_mask, 94 | position_ids=position_ids, 95 | past_key_values=past_key_values, 96 | inputs_embeds=inputs_embeds, 97 | labels=labels, 98 | use_cache=use_cache, 99 | output_attentions=output_attentions, 100 | output_hidden_states=output_hidden_states, 101 | return_dict=return_dict 102 | ) 103 | 104 | @torch.no_grad() 105 | def generate( 106 | self, 107 | inputs: Optional[torch.Tensor] = None, 108 | images: Optional[torch.Tensor] = None, 109 | image_sizes: Optional[torch.Tensor] = None, 110 | **kwargs, 111 | ) -> Union[GenerateOutput, torch.LongTensor]: 112 | position_ids = kwargs.pop("position_ids", None) 113 | attention_mask = kwargs.pop("attention_mask", None) 114 | if "inputs_embeds" in kwargs: 115 | raise NotImplementedError("`inputs_embeds` is not supported") 116 | 117 | if images is not None: 118 | ( 119 | inputs, 120 | position_ids, 121 | attention_mask, 122 | _, 123 | inputs_embeds, 124 | _ 125 | ) = self.prepare_inputs_labels_for_multimodal( 126 | inputs, 127 | position_ids, 128 | attention_mask, 129 | None, 130 | None, 131 | images, 132 | image_sizes=image_sizes 133 | ) 134 | else: 135 | inputs_embeds = self.get_model().embed_tokens(inputs) 136 | 137 | 
return super().generate( 138 | position_ids=position_ids, 139 | attention_mask=attention_mask, 140 | inputs_embeds=inputs_embeds, 141 | **kwargs 142 | ) 143 | 144 | def prepare_inputs_for_generation(self, input_ids, past_key_values=None, 145 | inputs_embeds=None, **kwargs): 146 | images = kwargs.pop("images", None) 147 | image_sizes = kwargs.pop("image_sizes", None) 148 | inputs = super().prepare_inputs_for_generation( 149 | input_ids, past_key_values=past_key_values, inputs_embeds=inputs_embeds, **kwargs 150 | ) 151 | if images is not None: 152 | inputs['images'] = images 153 | if image_sizes is not None: 154 | inputs['image_sizes'] = image_sizes 155 | return inputs 156 | 157 | AutoConfig.register("llava_mistral", LlavaMistralConfig) 158 | AutoModelForCausalLM.register(LlavaMistralConfig, LlavaMistralForCausalLM) 159 | -------------------------------------------------------------------------------- /llava/model/language_model/llava_mpt.py: -------------------------------------------------------------------------------- 1 | # Copyright 2023 Haotian Liu 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | 16 | from typing import List, Optional, Tuple 17 | import warnings 18 | 19 | import torch 20 | import torch.nn.functional as F 21 | import math 22 | 23 | from transformers import AutoConfig, AutoModelForCausalLM 24 | from transformers.modeling_outputs import CausalLMOutputWithPast 25 | 26 | from .mpt.modeling_mpt import MPTConfig, MPTForCausalLM, MPTModel 27 | from llava.model.llava_arch import LlavaMetaModel, LlavaMetaForCausalLM 28 | 29 | 30 | class LlavaMPTConfig(MPTConfig): 31 | model_type = "llava_mpt" 32 | 33 | 34 | class LlavaMPTModel(LlavaMetaModel, MPTModel): 35 | config_class = LlavaMPTConfig 36 | 37 | def __init__(self, config: MPTConfig): 38 | config.hidden_size = config.d_model 39 | super(LlavaMPTModel, self).__init__(config) 40 | 41 | def embed_tokens(self, x): 42 | return self.wte(x) 43 | 44 | 45 | class LlavaMPTForCausalLM(MPTForCausalLM, LlavaMetaForCausalLM): 46 | config_class = LlavaMPTConfig 47 | supports_gradient_checkpointing = True 48 | 49 | def __init__(self, config): 50 | super(MPTForCausalLM, self).__init__(config) 51 | 52 | if not config.tie_word_embeddings: 53 | raise ValueError('MPTForCausalLM only supports tied word embeddings') 54 | self.transformer = LlavaMPTModel(config) 55 | self.logit_scale = None 56 | if config.logit_scale is not None: 57 | logit_scale = config.logit_scale 58 | if isinstance(logit_scale, str): 59 | if logit_scale == 'inv_sqrt_d_model': 60 | logit_scale = 1 / math.sqrt(config.d_model) 61 | else: 62 | raise ValueError(f"logit_scale={logit_scale!r} is not recognized as an option; use numeric value or 'inv_sqrt_d_model'.") 63 | self.logit_scale = logit_scale 64 | 65 | def get_model(self): 66 | return self.transformer 67 | 68 | def _set_gradient_checkpointing(self, module, value=False): 69 | if isinstance(module, LlavaMPTModel): 70 | module.gradient_checkpointing = value 71 | 72 | def forward(self, 
input_ids: torch.LongTensor, past_key_values: Optional[List[Tuple[torch.FloatTensor]]]=None, attention_mask: Optional[torch.ByteTensor]=None, prefix_mask: Optional[torch.ByteTensor]=None, sequence_id: Optional[torch.LongTensor]=None, labels: Optional[torch.LongTensor]=None, return_dict: Optional[bool]=None, output_attentions: Optional[bool]=None, output_hidden_states: Optional[bool]=None, use_cache: Optional[bool]=None, images=None): 73 | return_dict = return_dict if return_dict is not None else self.config.return_dict 74 | use_cache = use_cache if use_cache is not None else self.config.use_cache 75 | 76 | input_ids, attention_mask, past_key_values, inputs_embeds, labels = self.prepare_inputs_labels_for_multimodal(input_ids, attention_mask, past_key_values, labels, images) 77 | outputs = self.transformer(input_ids=input_ids, inputs_embeds=inputs_embeds, past_key_values=past_key_values, attention_mask=attention_mask, prefix_mask=prefix_mask, sequence_id=sequence_id, return_dict=return_dict, output_attentions=output_attentions, output_hidden_states=output_hidden_states, use_cache=use_cache) 78 | # FIXME: this is a hack to fix the multiple gpu inference issue in https://github.com/haotian-liu/LLaVA/issues/338 79 | logits = F.linear(outputs.last_hidden_state.to(self.transformer.wte.weight.device), self.transformer.wte.weight) 80 | if self.logit_scale is not None: 81 | if self.logit_scale == 0: 82 | warnings.warn(f'Multiplying logits by self.logit_scale={self.logit_scale!r}. This will produce uniform (uninformative) outputs.') 83 | logits *= self.logit_scale 84 | loss = None 85 | if labels is not None: 86 | labels = torch.roll(labels, shifts=-1) 87 | labels[:, -1] = -100 88 | loss = F.cross_entropy(logits.view(-1, logits.size(-1)), labels.to(logits.device).view(-1)) 89 | return CausalLMOutputWithPast(loss=loss, logits=logits, past_key_values=outputs.past_key_values, hidden_states=outputs.hidden_states) 90 | 91 | def prepare_inputs_for_generation(self, input_ids, past_key_values=None, inputs_embeds=None, **kwargs): 92 | if inputs_embeds is not None: 93 | raise NotImplementedError('inputs_embeds is not implemented for MPT yet') 94 | attention_mask = kwargs['attention_mask'].bool() 95 | if attention_mask[:, -1].sum() != attention_mask.shape[0]: 96 | raise NotImplementedError('MPT does not support generation with right padding.') 97 | if self.transformer.attn_uses_sequence_id and self.training: 98 | sequence_id = torch.zeros_like(input_ids[:1]) 99 | else: 100 | sequence_id = None 101 | if past_key_values is not None: 102 | input_ids = input_ids[:, -1].unsqueeze(-1) 103 | if self.transformer.prefix_lm: 104 | prefix_mask = torch.ones_like(attention_mask) 105 | if kwargs.get('use_cache') == False: 106 | raise NotImplementedError('MPT with prefix_lm=True does not support use_cache=False.') 107 | else: 108 | prefix_mask = None 109 | return {'input_ids': input_ids, 'attention_mask': attention_mask, 'prefix_mask': prefix_mask, 'sequence_id': sequence_id, 'past_key_values': past_key_values, 'use_cache': kwargs.get('use_cache', True), "images": kwargs.get("images", None)} 110 | 111 | 112 | AutoConfig.register("llava_mpt", LlavaMPTConfig) 113 | AutoModelForCausalLM.register(LlavaMPTConfig, LlavaMPTForCausalLM) 114 | -------------------------------------------------------------------------------- /llava/model/language_model/mpt/adapt_tokenizer.py: -------------------------------------------------------------------------------- 1 | from typing import Union 2 | from transformers import AutoTokenizer, 
PreTrainedTokenizer, PreTrainedTokenizerFast 3 | Tokenizer = Union[PreTrainedTokenizer, PreTrainedTokenizerFast] 4 | NUM_SENTINEL_TOKENS: int = 100 5 | 6 | def adapt_tokenizer_for_denoising(tokenizer: Tokenizer): 7 | """Adds sentinel tokens and padding token (if missing). 8 | 9 | Expands the tokenizer vocabulary to include sentinel tokens 10 | used in mixture-of-denoiser tasks as well as a padding token. 11 | 12 | All added tokens are added as special tokens. No tokens are 13 | added if sentinel tokens and padding token already exist. 14 | """ 15 | sentinels_to_add = [f'<extra_id_{i}>' for i in range(NUM_SENTINEL_TOKENS)] 16 | tokenizer.add_tokens(sentinels_to_add, special_tokens=True) 17 | if tokenizer.pad_token is None: 18 | tokenizer.add_tokens('<pad>', special_tokens=True) 19 | tokenizer.pad_token = '<pad>' 20 | assert tokenizer.pad_token_id is not None 21 | sentinels = ''.join([f'<extra_id_{i}>' for i in range(NUM_SENTINEL_TOKENS)]) 22 | _sentinel_token_ids = tokenizer(sentinels, add_special_tokens=False).input_ids 23 | tokenizer.sentinel_token_ids = _sentinel_token_ids 24 | 25 | class AutoTokenizerForMOD(AutoTokenizer): 26 | """AutoTokenizer + Adaptation for MOD. 27 | 28 | A simple wrapper around AutoTokenizer to make instantiating 29 | an MOD-adapted tokenizer a bit easier. 30 | 31 | MOD-adapted tokenizers have sentinel tokens (e.g., <extra_id_0>), 32 | a padding token, and a property to get the token ids of the 33 | sentinel tokens. 34 | """ 35 | 36 | @classmethod 37 | def from_pretrained(cls, *args, **kwargs): 38 | """See `AutoTokenizer.from_pretrained` docstring.""" 39 | tokenizer = super().from_pretrained(*args, **kwargs) 40 | adapt_tokenizer_for_denoising(tokenizer) 41 | return tokenizer -------------------------------------------------------------------------------- /llava/model/language_model/mpt/blocks.py: -------------------------------------------------------------------------------- 1 | """GPT Blocks used for the GPT Model.""" 2 | from typing import Dict, Optional, Tuple 3 | import torch 4 | import torch.nn as nn 5 | from .attention import ATTN_CLASS_REGISTRY 6 | from .norm import NORM_CLASS_REGISTRY 7 | 8 | class MPTMLP(nn.Module): 9 | 10 | def __init__(self, d_model: int, expansion_ratio: int, device: Optional[str]=None): 11 | super().__init__() 12 | self.up_proj = nn.Linear(d_model, expansion_ratio * d_model, device=device) 13 | self.act = nn.GELU(approximate='none') 14 | self.down_proj = nn.Linear(expansion_ratio * d_model, d_model, device=device) 15 | self.down_proj._is_residual = True 16 | 17 | def forward(self, x): 18 | return self.down_proj(self.act(self.up_proj(x))) 19 | 20 | class MPTBlock(nn.Module): 21 | 22 | def __init__(self, d_model: int, n_heads: int, expansion_ratio: int, attn_config: Dict={'attn_type': 'multihead_attention', 'attn_pdrop': 0.0, 'attn_impl': 'triton', 'qk_ln': False, 'clip_qkv': None, 'softmax_scale': None, 'prefix_lm': False, 'attn_uses_sequence_id': False, 'alibi': False, 'alibi_bias_max': 8}, resid_pdrop: float=0.0, norm_type: str='low_precision_layernorm', verbose: int=0, device: Optional[str]=None, **kwargs): 23 | del kwargs 24 | super().__init__() 25 | norm_class = NORM_CLASS_REGISTRY[norm_type.lower()] 26 | attn_class = ATTN_CLASS_REGISTRY[attn_config['attn_type']] 27 | self.norm_1 = norm_class(d_model, device=device) 28 | self.attn = attn_class(attn_impl=attn_config['attn_impl'], clip_qkv=attn_config['clip_qkv'], qk_ln=attn_config['qk_ln'], softmax_scale=attn_config['softmax_scale'], attn_pdrop=attn_config['attn_pdrop'], d_model=d_model, n_heads=n_heads, verbose=verbose, 
device=device) 29 | self.norm_2 = norm_class(d_model, device=device) 30 | self.ffn = MPTMLP(d_model=d_model, expansion_ratio=expansion_ratio, device=device) 31 | self.resid_attn_dropout = nn.Dropout(resid_pdrop) 32 | self.resid_ffn_dropout = nn.Dropout(resid_pdrop) 33 | 34 | def forward(self, x: torch.Tensor, past_key_value: Optional[Tuple[torch.Tensor]]=None, attn_bias: Optional[torch.Tensor]=None, attention_mask: Optional[torch.ByteTensor]=None, is_causal: bool=True) -> Tuple[torch.Tensor, Optional[Tuple[torch.Tensor]]]: 35 | a = self.norm_1(x) 36 | (b, attn_weights, past_key_value) = self.attn(a, past_key_value=past_key_value, attn_bias=attn_bias, attention_mask=attention_mask, is_causal=is_causal) 37 | x = x + self.resid_attn_dropout(b) 38 | m = self.norm_2(x) 39 | n = self.ffn(m) 40 | x = x + self.resid_ffn_dropout(n) 41 | return (x, attn_weights, past_key_value) -------------------------------------------------------------------------------- /llava/model/language_model/mpt/configuration_mpt.py: -------------------------------------------------------------------------------- 1 | """A HuggingFace-style model configuration.""" 2 | from typing import Dict, Optional, Union 3 | from transformers import PretrainedConfig 4 | attn_config_defaults: Dict = {'attn_type': 'multihead_attention', 'attn_pdrop': 0.0, 'attn_impl': 'triton', 'qk_ln': False, 'clip_qkv': None, 'softmax_scale': None, 'prefix_lm': False, 'attn_uses_sequence_id': False, 'alibi': False, 'alibi_bias_max': 8} 5 | init_config_defaults: Dict = {'name': 'kaiming_normal_', 'fan_mode': 'fan_in', 'init_nonlinearity': 'relu', 'init_div_is_residual': True, 'emb_init_std': None, 'emb_init_uniform_lim': None, 'init_std': None, 'init_gain': 0.0} 6 | 7 | class MPTConfig(PretrainedConfig): 8 | model_type = 'mpt' 9 | 10 | def __init__(self, d_model: int=2048, n_heads: int=16, n_layers: int=24, expansion_ratio: int=4, max_seq_len: int=2048, vocab_size: int=50368, resid_pdrop: float=0.0, emb_pdrop: float=0.0, learned_pos_emb: bool=True, attn_config: Dict=attn_config_defaults, init_device: str='cpu', logit_scale: Optional[Union[float, str]]=None, no_bias: bool=False, verbose: int=0, embedding_fraction: float=1.0, norm_type: str='low_precision_layernorm', use_cache: bool=False, init_config: Dict=init_config_defaults, **kwargs): 11 | """The MPT configuration class. 12 | 13 | Args: 14 | d_model (int): The size of the embedding dimension of the model. 15 | n_heads (int): The number of attention heads. 16 | n_layers (int): The number of layers in the model. 17 | expansion_ratio (int): The ratio of the up/down scale in the MLP. 18 | max_seq_len (int): The maximum sequence length of the model. 19 | vocab_size (int): The size of the vocabulary. 20 | resid_pdrop (float): The dropout probability applied to the attention output before combining with residual. 21 | emb_pdrop (float): The dropout probability for the embedding layer. 22 | learned_pos_emb (bool): Whether to use learned positional embeddings 23 | attn_config (Dict): A dictionary used to configure the model's attention module: 24 | attn_type (str): type of attention to use. Options: multihead_attention, multiquery_attention 25 | attn_pdrop (float): The dropout probability for the attention layers. 26 | attn_impl (str): The attention implementation to use. One of 'torch', 'flash', or 'triton'. 27 | qk_ln (bool): Whether to apply layer normalization to the queries and keys in the attention layer. 
28 | clip_qkv (Optional[float]): If not None, clip the queries, keys, and values in the attention layer to 29 | this value. 30 | softmax_scale (Optional[float]): If not None, scale the softmax in the attention layer by this value. If None, 31 | use the default scale of ``1/sqrt(d_keys)``. 32 | prefix_lm (Optional[bool]): Whether the model should operate as a Prefix LM. This requires passing an 33 | extra `prefix_mask` argument which indicates which tokens belong to the prefix. Tokens in the prefix 34 | can attend to one another bi-directionally. Tokens outside the prefix use causal attention. 35 | attn_uses_sequence_id (Optional[bool]): Whether to restrict attention to tokens that have the same sequence_id. 36 | When the model is in `train` mode, this requires passing an extra `sequence_id` argument which indicates 37 | which sub-sequence each token belongs to. 38 | Defaults to ``False`` meaning any provided `sequence_id` will be ignored. 39 | alibi (bool): Whether to use the alibi bias instead of position embeddings. 40 | alibi_bias_max (int): The maximum value of the alibi bias. 41 | init_device (str): The device to use for parameter initialization. 42 | logit_scale (Optional[Union[float, str]]): If not None, scale the logits by this value. 43 | no_bias (bool): Whether to use bias in all layers. 44 | verbose (int): The verbosity level. 0 is silent. 45 | embedding_fraction (float): The fraction to scale the gradients of the embedding layer by. 46 | norm_type (str): choose type of norm to use 47 | multiquery_attention (bool): Whether to use multiquery attention implementation. 48 | use_cache (bool): Whether or not the model should return the last key/values attentions 49 | init_config (Dict): A dictionary used to configure the model initialization: 50 | init_config.name: The parameter initialization scheme to use. Options: 'default_', 'baseline_', 51 | 'kaiming_uniform_', 'kaiming_normal_', 'neox_init_', 'small_init_', 'xavier_uniform_', or 52 | 'xavier_normal_'. These mimic the parameter initialization methods in PyTorch. 53 | init_div_is_residual (Union[int, float, str, bool]): Value to divide initial weights by if ``module._is_residual`` is True. 54 | emb_init_std (Optional[float]): The standard deviation of the normal distribution used to initialize the embedding layer. 55 | emb_init_uniform_lim (Optional[Union[Tuple[float, float], float]]): The lower and upper limits of the uniform distribution 56 | used to initialize the embedding layer. Mutually exclusive with ``emb_init_std``. 57 | init_std (float): The standard deviation of the normal distribution used to initialize the model, 58 | if using the baseline_ parameter initialization scheme. 59 | init_gain (float): The gain to use for parameter initialization with kaiming or xavier initialization schemes. 60 | fan_mode (str): The fan mode to use for parameter initialization with kaiming initialization schemes. 61 | init_nonlinearity (str): The nonlinearity to use for parameter initialization with kaiming initialization schemes. 
62 | --- 63 | See llmfoundry.models.utils.param_init_fns.py for info on other param init config options 64 | """ 65 | self.d_model = d_model 66 | self.n_heads = n_heads 67 | self.n_layers = n_layers 68 | self.expansion_ratio = expansion_ratio 69 | self.max_seq_len = max_seq_len 70 | self.vocab_size = vocab_size 71 | self.resid_pdrop = resid_pdrop 72 | self.emb_pdrop = emb_pdrop 73 | self.learned_pos_emb = learned_pos_emb 74 | self.attn_config = attn_config 75 | self.init_device = init_device 76 | self.logit_scale = logit_scale 77 | self.no_bias = no_bias 78 | self.verbose = verbose 79 | self.embedding_fraction = embedding_fraction 80 | self.norm_type = norm_type 81 | self.use_cache = use_cache 82 | self.init_config = init_config 83 | if 'name' in kwargs: 84 | del kwargs['name'] 85 | if 'loss_fn' in kwargs: 86 | del kwargs['loss_fn'] 87 | super().__init__(**kwargs) 88 | self._validate_config() 89 | 90 | def _set_config_defaults(self, config, config_defaults): 91 | for (k, v) in config_defaults.items(): 92 | if k not in config: 93 | config[k] = v 94 | return config 95 | 96 | def _validate_config(self): 97 | self.attn_config = self._set_config_defaults(self.attn_config, attn_config_defaults) 98 | self.init_config = self._set_config_defaults(self.init_config, init_config_defaults) 99 | if self.d_model % self.n_heads != 0: 100 | raise ValueError('d_model must be divisible by n_heads') 101 | if any((prob < 0 or prob > 1 for prob in [self.attn_config['attn_pdrop'], self.resid_pdrop, self.emb_pdrop])): 102 | raise ValueError("self.attn_config['attn_pdrop'], resid_pdrop, emb_pdrop are probabilities and must be between 0 and 1") 103 | if self.attn_config['attn_impl'] not in ['torch', 'flash', 'triton']: 104 | raise ValueError(f"Unknown attn_impl={self.attn_config['attn_impl']}") 105 | if self.attn_config['prefix_lm'] and self.attn_config['attn_impl'] not in ['torch', 'triton']: 106 | raise NotImplementedError('prefix_lm only implemented with torch and triton attention.') 107 | if self.attn_config['alibi'] and self.attn_config['attn_impl'] not in ['torch', 'triton']: 108 | raise NotImplementedError('alibi only implemented with torch and triton attention.') 109 | if self.attn_config['attn_uses_sequence_id'] and self.attn_config['attn_impl'] not in ['torch', 'triton']: 110 | raise NotImplementedError('attn_uses_sequence_id only implemented with torch and triton attention.') 111 | if self.embedding_fraction > 1 or self.embedding_fraction <= 0: 112 | raise ValueError('model.embedding_fraction must be between 0 (exclusive) and 1 (inclusive)!') 113 | if isinstance(self.logit_scale, str) and self.logit_scale != 'inv_sqrt_d_model': 114 | raise ValueError(f"self.logit_scale={self.logit_scale!r} is not recognized as an option; use numeric value or 'inv_sqrt_d_model'.") 115 | if self.init_config.get('name', None) is None: 116 | raise ValueError(f"self.init_config={self.init_config!r} 'name' needs to be set.") 117 | if not self.learned_pos_emb and (not self.attn_config['alibi']): 118 | raise ValueError(f'Positional information must be provided to the model using either learned_pos_emb or alibi.') -------------------------------------------------------------------------------- /llava/model/language_model/mpt/custom_embedding.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from torch import Tensor 5 | 6 | class SharedEmbedding(nn.Embedding): 7 | 8 | def forward(self, input: Tensor, unembed: 
bool=False) -> Tensor: 9 | if unembed: 10 | return F.linear(input, self.weight) 11 | return super().forward(input) -------------------------------------------------------------------------------- /llava/model/language_model/mpt/meta_init_context.py: -------------------------------------------------------------------------------- 1 | from contextlib import contextmanager 2 | import torch 3 | import torch.nn as nn 4 | 5 | @contextmanager 6 | def init_empty_weights(include_buffers: bool=False): 7 | """Meta initialization context manager. 8 | 9 | A context manager under which models are initialized with all parameters 10 | on the meta device, therefore creating an empty model. Useful when just 11 | initializing the model would blow the available RAM. 12 | 13 | Args: 14 | include_buffers (`bool`, *optional*, defaults to `False`): Whether or 15 | not to also put all buffers on the meta device while initializing. 16 | 17 | Example: 18 | ```python 19 | import torch.nn as nn 20 | 21 | # Initialize a model with 100 billions parameters in no time and without using any RAM. 22 | with init_empty_weights(): 23 | tst = nn.Sequential(*[nn.Linear(10000, 10000) for _ in range(1000)]) 24 | ``` 25 | 26 | 27 | 28 | Any model created under this context manager has no weights. As such you can't do something like 29 | `model.to(some_device)` with it. To load weights inside your empty model, see [`load_checkpoint_and_dispatch`]. 30 | 31 | 32 | """ 33 | with init_on_device(torch.device('meta'), include_buffers=include_buffers) as f: 34 | yield f 35 | 36 | @contextmanager 37 | def init_on_device(device: torch.device, include_buffers: bool=False): 38 | """Device initialization context manager. 39 | 40 | A context manager under which models are initialized with all parameters 41 | on the specified device. 42 | 43 | Args: 44 | device (`torch.device`): Device to initialize all parameters on. 45 | include_buffers (`bool`, *optional*, defaults to `False`): Whether or 46 | not to also put all buffers on the meta device while initializing. 
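    Note (added illustration, not from the upstream docstring): under ``init_empty_weights``
    above, which wraps this context manager with the meta device, parameters are created
    without allocating any storage, so a quick sanity check looks like:

    ```python
    import torch.nn as nn

    with init_empty_weights():
        layer = nn.Linear(4096, 4096)

    assert layer.weight.device.type == "meta"   # nothing was allocated
    # materialize real tensors later (e.g. layer.to_empty(device="cpu"))
    # before loading checkpoint weights into the module
    ```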
47 | 48 | Example: 49 | ```python 50 | import torch.nn as nn 51 | 52 | with init_on_device(device=torch.device("cuda")): 53 | tst = nn.Liner(100, 100) # on `cuda` device 54 | ``` 55 | """ 56 | old_register_parameter = nn.Module.register_parameter 57 | if include_buffers: 58 | old_register_buffer = nn.Module.register_buffer 59 | 60 | def register_empty_parameter(module, name, param): 61 | old_register_parameter(module, name, param) 62 | if param is not None: 63 | param_cls = type(module._parameters[name]) 64 | kwargs = module._parameters[name].__dict__ 65 | module._parameters[name] = param_cls(module._parameters[name].to(device), **kwargs) 66 | 67 | def register_empty_buffer(module, name, buffer): 68 | old_register_buffer(module, name, buffer) 69 | if buffer is not None: 70 | module._buffers[name] = module._buffers[name].to(device) 71 | if include_buffers: 72 | tensor_constructors_to_patch = {torch_function_name: getattr(torch, torch_function_name) for torch_function_name in ['empty', 'zeros', 'ones', 'full']} 73 | else: 74 | tensor_constructors_to_patch = {} 75 | 76 | def patch_tensor_constructor(fn): 77 | 78 | def wrapper(*args, **kwargs): 79 | kwargs['device'] = device 80 | return fn(*args, **kwargs) 81 | return wrapper 82 | try: 83 | nn.Module.register_parameter = register_empty_parameter 84 | if include_buffers: 85 | nn.Module.register_buffer = register_empty_buffer 86 | for torch_function_name in tensor_constructors_to_patch.keys(): 87 | setattr(torch, torch_function_name, patch_tensor_constructor(getattr(torch, torch_function_name))) 88 | yield 89 | finally: 90 | nn.Module.register_parameter = old_register_parameter 91 | if include_buffers: 92 | nn.Module.register_buffer = old_register_buffer 93 | for (torch_function_name, old_torch_function) in tensor_constructors_to_patch.items(): 94 | setattr(torch, torch_function_name, old_torch_function) -------------------------------------------------------------------------------- /llava/model/language_model/mpt/norm.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | def _cast_if_autocast_enabled(tensor): 4 | if torch.is_autocast_enabled(): 5 | if tensor.device.type == 'cuda': 6 | dtype = torch.get_autocast_gpu_dtype() 7 | elif tensor.device.type == 'cpu': 8 | dtype = torch.get_autocast_cpu_dtype() 9 | else: 10 | raise NotImplementedError() 11 | return tensor.to(dtype=dtype) 12 | return tensor 13 | 14 | class LPLayerNorm(torch.nn.LayerNorm): 15 | 16 | def __init__(self, normalized_shape, eps=1e-05, elementwise_affine=True, device=None, dtype=None): 17 | super().__init__(normalized_shape=normalized_shape, eps=eps, elementwise_affine=elementwise_affine, device=device, dtype=dtype) 18 | 19 | def forward(self, x): 20 | module_device = x.device 21 | downcast_x = _cast_if_autocast_enabled(x) 22 | downcast_weight = _cast_if_autocast_enabled(self.weight) if self.weight is not None else self.weight 23 | downcast_bias = _cast_if_autocast_enabled(self.bias) if self.bias is not None else self.bias 24 | with torch.autocast(enabled=False, device_type=module_device.type): 25 | return torch.nn.functional.layer_norm(downcast_x, self.normalized_shape, downcast_weight, downcast_bias, self.eps) 26 | 27 | def rms_norm(x, weight=None, eps=1e-05): 28 | output = x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + eps) 29 | if weight is not None: 30 | return output * weight 31 | return output 32 | 33 | class RMSNorm(torch.nn.Module): 34 | 35 | def __init__(self, normalized_shape, eps=1e-05, weight=True, 
dtype=None, device=None): 36 | super().__init__() 37 | self.eps = eps 38 | if weight: 39 | self.weight = torch.nn.Parameter(torch.ones(normalized_shape, dtype=dtype, device=device)) 40 | else: 41 | self.register_parameter('weight', None) 42 | 43 | def forward(self, x): 44 | return rms_norm(x.float(), self.weight, self.eps).to(dtype=x.dtype) 45 | 46 | class LPRMSNorm(RMSNorm): 47 | 48 | def __init__(self, normalized_shape, eps=1e-05, weight=True, dtype=None, device=None): 49 | super().__init__(normalized_shape=normalized_shape, eps=eps, weight=weight, dtype=dtype, device=device) 50 | 51 | def forward(self, x): 52 | downcast_x = _cast_if_autocast_enabled(x) 53 | downcast_weight = _cast_if_autocast_enabled(self.weight) if self.weight is not None else self.weight 54 | with torch.autocast(enabled=False, device_type=x.device.type): 55 | return rms_norm(downcast_x, downcast_weight, self.eps).to(dtype=x.dtype) 56 | NORM_CLASS_REGISTRY = {'layernorm': torch.nn.LayerNorm, 'low_precision_layernorm': LPLayerNorm, 'rmsnorm': RMSNorm, 'low_precision_rmsnorm': LPRMSNorm} -------------------------------------------------------------------------------- /llava/model/language_model/mpt/param_init_fns.py: -------------------------------------------------------------------------------- 1 | import math 2 | import warnings 3 | from collections.abc import Sequence 4 | from functools import partial 5 | from typing import Optional, Tuple, Union 6 | import torch 7 | from torch import nn 8 | from .norm import NORM_CLASS_REGISTRY 9 | 10 | def torch_default_param_init_fn_(module: nn.Module, verbose: int=0, **kwargs): 11 | del kwargs 12 | if verbose > 1: 13 | warnings.warn(f"Initializing network using module's reset_parameters attribute") 14 | if hasattr(module, 'reset_parameters'): 15 | module.reset_parameters() 16 | 17 | def fused_init_helper_(module: nn.Module, init_fn_): 18 | _fused = getattr(module, '_fused', None) 19 | if _fused is None: 20 | raise RuntimeError(f'Internal logic error') 21 | (dim, splits) = _fused 22 | splits = (0, *splits, module.weight.size(dim)) 23 | for (s, e) in zip(splits[:-1], splits[1:]): 24 | slice_indices = [slice(None)] * module.weight.ndim 25 | slice_indices[dim] = slice(s, e) 26 | init_fn_(module.weight[slice_indices]) 27 | 28 | def generic_param_init_fn_(module: nn.Module, init_fn_, n_layers: int, d_model: Optional[int]=None, init_div_is_residual: Union[int, float, str, bool]=True, emb_init_std: Optional[float]=None, emb_init_uniform_lim: Optional[Union[Tuple[float, float], float]]=None, verbose: int=0, **kwargs): 29 | del kwargs 30 | if verbose > 1: 31 | warnings.warn(f'If model has bias parameters they are initialized to 0.') 32 | init_div_is_residual = init_div_is_residual 33 | if init_div_is_residual is False: 34 | div_is_residual = 1.0 35 | elif init_div_is_residual is True: 36 | div_is_residual = math.sqrt(2 * n_layers) 37 | elif isinstance(init_div_is_residual, float) or isinstance(init_div_is_residual, int): 38 | div_is_residual = init_div_is_residual 39 | elif isinstance(init_div_is_residual, str) and init_div_is_residual.isnumeric(): 40 | div_is_residual = float(init_div_is_residual) 41 | else: 42 | div_is_residual = 1.0 43 | raise ValueError(f'Expected init_div_is_residual to be boolean or numeric, got {init_div_is_residual}') 44 | if init_div_is_residual is not False: 45 | if verbose > 1: 46 | warnings.warn(f'Initializing _is_residual layers then dividing them by {div_is_residual:.3f}. 
' + f'Set `init_div_is_residual: false` in init config to disable this.') 47 | if isinstance(module, nn.Linear): 48 | if hasattr(module, '_fused'): 49 | fused_init_helper_(module, init_fn_) 50 | else: 51 | init_fn_(module.weight) 52 | if module.bias is not None: 53 | torch.nn.init.zeros_(module.bias) 54 | if init_div_is_residual is not False and getattr(module, '_is_residual', False): 55 | with torch.no_grad(): 56 | module.weight.div_(div_is_residual) 57 | elif isinstance(module, nn.Embedding): 58 | if emb_init_std is not None: 59 | std = emb_init_std 60 | if std == 0: 61 | warnings.warn(f'Embedding layer initialized to 0.') 62 | emb_init_fn_ = partial(torch.nn.init.normal_, mean=0.0, std=std) 63 | if verbose > 1: 64 | warnings.warn(f'Embedding layer initialized using normal distribution with mean=0 and std={std!r}.') 65 | elif emb_init_uniform_lim is not None: 66 | lim = emb_init_uniform_lim 67 | if isinstance(lim, Sequence): 68 | if len(lim) > 2: 69 | raise ValueError(f'Uniform init requires a min and a max limit. User input: {lim}.') 70 | if lim[0] == lim[1]: 71 | warnings.warn(f'Embedding layer initialized to {lim[0]}.') 72 | else: 73 | if lim == 0: 74 | warnings.warn(f'Embedding layer initialized to 0.') 75 | lim = [-lim, lim] 76 | (a, b) = lim 77 | emb_init_fn_ = partial(torch.nn.init.uniform_, a=a, b=b) 78 | if verbose > 1: 79 | warnings.warn(f'Embedding layer initialized using uniform distribution in range {lim}.') 80 | else: 81 | emb_init_fn_ = init_fn_ 82 | emb_init_fn_(module.weight) 83 | elif isinstance(module, tuple(set(NORM_CLASS_REGISTRY.values()))): 84 | if verbose > 1: 85 | warnings.warn(f'Norm weights are set to 1. If norm layer has a bias it is initialized to 0.') 86 | if hasattr(module, 'weight') and module.weight is not None: 87 | torch.nn.init.ones_(module.weight) 88 | if hasattr(module, 'bias') and module.bias is not None: 89 | torch.nn.init.zeros_(module.bias) 90 | elif isinstance(module, nn.MultiheadAttention): 91 | if module._qkv_same_embed_dim: 92 | assert module.in_proj_weight is not None 93 | assert module.q_proj_weight is None and module.k_proj_weight is None and (module.v_proj_weight is None) 94 | assert d_model is not None 95 | _d = d_model 96 | splits = (0, _d, 2 * _d, 3 * _d) 97 | for (s, e) in zip(splits[:-1], splits[1:]): 98 | init_fn_(module.in_proj_weight[s:e]) 99 | else: 100 | assert module.q_proj_weight is not None and module.k_proj_weight is not None and (module.v_proj_weight is not None) 101 | assert module.in_proj_weight is None 102 | init_fn_(module.q_proj_weight) 103 | init_fn_(module.k_proj_weight) 104 | init_fn_(module.v_proj_weight) 105 | if module.in_proj_bias is not None: 106 | torch.nn.init.zeros_(module.in_proj_bias) 107 | if module.bias_k is not None: 108 | torch.nn.init.zeros_(module.bias_k) 109 | if module.bias_v is not None: 110 | torch.nn.init.zeros_(module.bias_v) 111 | init_fn_(module.out_proj.weight) 112 | if init_div_is_residual is not False and getattr(module.out_proj, '_is_residual', False): 113 | with torch.no_grad(): 114 | module.out_proj.weight.div_(div_is_residual) 115 | if module.out_proj.bias is not None: 116 | torch.nn.init.zeros_(module.out_proj.bias) 117 | else: 118 | for _ in module.parameters(recurse=False): 119 | raise NotImplementedError(f'{module.__class__.__name__} parameters are not initialized by param_init_fn.') 120 | 121 | def _normal_init_(std, mean=0.0): 122 | return partial(torch.nn.init.normal_, mean=mean, std=std) 123 | 124 | def _normal_param_init_fn_(module: nn.Module, std: float, n_layers: int, 
d_model: Optional[int]=None, init_div_is_residual: Union[int, float, str, bool]=True, emb_init_std: Optional[float]=None, emb_init_uniform_lim: Optional[Union[Tuple[float, float], float]]=None, verbose: int=0, **kwargs): 125 | del kwargs 126 | init_fn_ = _normal_init_(std=std) 127 | if verbose > 1: 128 | warnings.warn(f'Using torch.nn.init.normal_ init fn mean=0.0, std={std}') 129 | generic_param_init_fn_(module=module, init_fn_=init_fn_, d_model=d_model, n_layers=n_layers, init_div_is_residual=init_div_is_residual, emb_init_std=emb_init_std, emb_init_uniform_lim=emb_init_uniform_lim, verbose=verbose) 130 | 131 | def baseline_param_init_fn_(module: nn.Module, init_std: float, n_layers: int, d_model: Optional[int]=None, init_div_is_residual: Union[int, float, str, bool]=True, emb_init_std: Optional[float]=None, emb_init_uniform_lim: Optional[Union[Tuple[float, float], float]]=None, verbose: int=0, **kwargs): 132 | del kwargs 133 | if init_std is None: 134 | raise ValueError("You must set model.init_config['init_std'] to a float value to use the default initialization scheme.") 135 | _normal_param_init_fn_(module=module, std=init_std, d_model=d_model, n_layers=n_layers, init_div_is_residual=init_div_is_residual, emb_init_std=emb_init_std, emb_init_uniform_lim=emb_init_uniform_lim, verbose=verbose) 136 | 137 | def small_param_init_fn_(module: nn.Module, n_layers: int, d_model: int, init_div_is_residual: Union[int, float, str, bool]=True, emb_init_std: Optional[float]=None, emb_init_uniform_lim: Optional[Union[Tuple[float, float], float]]=None, verbose: int=0, **kwargs): 138 | del kwargs 139 | std = math.sqrt(2 / (5 * d_model)) 140 | _normal_param_init_fn_(module=module, std=std, d_model=d_model, n_layers=n_layers, init_div_is_residual=init_div_is_residual, emb_init_std=emb_init_std, emb_init_uniform_lim=emb_init_uniform_lim, verbose=verbose) 141 | 142 | def neox_param_init_fn_(module: nn.Module, n_layers: int, d_model: int, emb_init_std: Optional[float]=None, emb_init_uniform_lim: Optional[Union[Tuple[float, float], float]]=None, verbose: int=0, **kwargs): 143 | """From section 2.3.1 of GPT-NeoX-20B: 144 | 145 | An Open-Source AutoregressiveLanguage Model — Black et. al. 
(2022) 146 | see https://github.com/EleutherAI/gpt-neox/blob/9610391ab319403cef079b438edd016a2443af54/megatron/model/init_functions.py#L151 147 | and https://github.com/EleutherAI/gpt-neox/blob/main/megatron/model/transformer.py 148 | """ 149 | del kwargs 150 | residual_div = n_layers / math.sqrt(10) 151 | if verbose > 1: 152 | warnings.warn(f'setting init_div_is_residual to {residual_div}') 153 | small_param_init_fn_(module=module, d_model=d_model, n_layers=n_layers, init_div_is_residual=residual_div, emb_init_std=emb_init_std, emb_init_uniform_lim=emb_init_uniform_lim, verbose=verbose) 154 | 155 | def kaiming_uniform_param_init_fn_(module: nn.Module, n_layers: int, d_model: Optional[int]=None, init_div_is_residual: Union[int, float, str, bool]=True, emb_init_std: Optional[float]=None, emb_init_uniform_lim: Optional[Union[Tuple[float, float], float]]=None, init_gain: float=0, fan_mode: str='fan_in', init_nonlinearity: str='leaky_relu', verbose: int=0, **kwargs): 156 | del kwargs 157 | if verbose > 1: 158 | warnings.warn(f'Using nn.init.kaiming_uniform_ init fn with parameters: ' + f'a={init_gain}, mode={fan_mode}, nonlinearity={init_nonlinearity}') 159 | kaiming_uniform_ = partial(nn.init.kaiming_uniform_, a=init_gain, mode=fan_mode, nonlinearity=init_nonlinearity) 160 | generic_param_init_fn_(module=module, init_fn_=kaiming_uniform_, d_model=d_model, n_layers=n_layers, init_div_is_residual=init_div_is_residual, emb_init_std=emb_init_std, emb_init_uniform_lim=emb_init_uniform_lim, verbose=verbose) 161 | 162 | def kaiming_normal_param_init_fn_(module: nn.Module, n_layers: int, d_model: Optional[int]=None, init_div_is_residual: Union[int, float, str, bool]=True, emb_init_std: Optional[float]=None, emb_init_uniform_lim: Optional[Union[Tuple[float, float], float]]=None, init_gain: float=0, fan_mode: str='fan_in', init_nonlinearity: str='leaky_relu', verbose: int=0, **kwargs): 163 | del kwargs 164 | if verbose > 1: 165 | warnings.warn(f'Using nn.init.kaiming_normal_ init fn with parameters: ' + f'a={init_gain}, mode={fan_mode}, nonlinearity={init_nonlinearity}') 166 | kaiming_normal_ = partial(torch.nn.init.kaiming_normal_, a=init_gain, mode=fan_mode, nonlinearity=init_nonlinearity) 167 | generic_param_init_fn_(module=module, init_fn_=kaiming_normal_, d_model=d_model, n_layers=n_layers, init_div_is_residual=init_div_is_residual, emb_init_std=emb_init_std, emb_init_uniform_lim=emb_init_uniform_lim, verbose=verbose) 168 | 169 | def xavier_uniform_param_init_fn_(module: nn.Module, n_layers: int, d_model: Optional[int]=None, init_div_is_residual: Union[int, float, str, bool]=True, emb_init_std: Optional[float]=None, emb_init_uniform_lim: Optional[Union[Tuple[float, float], float]]=None, init_gain: float=0, verbose: int=0, **kwargs): 170 | del kwargs 171 | xavier_uniform_ = partial(torch.nn.init.xavier_uniform_, gain=init_gain) 172 | if verbose > 1: 173 | warnings.warn(f'Using torch.nn.init.xavier_uniform_ init fn with parameters: ' + f'gain={init_gain}') 174 | generic_param_init_fn_(module=module, init_fn_=xavier_uniform_, d_model=d_model, n_layers=n_layers, init_div_is_residual=init_div_is_residual, emb_init_std=emb_init_std, emb_init_uniform_lim=emb_init_uniform_lim, verbose=verbose) 175 | 176 | def xavier_normal_param_init_fn_(module: nn.Module, n_layers: int, d_model: Optional[int]=None, init_div_is_residual: Union[int, float, str, bool]=True, emb_init_std: Optional[float]=None, emb_init_uniform_lim: Optional[Union[Tuple[float, float], float]]=None, init_gain: float=0, verbose: int=0, 
**kwargs): 177 | xavier_normal_ = partial(torch.nn.init.xavier_normal_, gain=init_gain) 178 | if verbose > 1: 179 | warnings.warn(f'Using torch.nn.init.xavier_normal_ init fn with parameters: ' + f'gain={init_gain}') 180 | generic_param_init_fn_(module=module, init_fn_=xavier_normal_, d_model=d_model, n_layers=n_layers, init_div_is_residual=init_div_is_residual, emb_init_std=emb_init_std, emb_init_uniform_lim=emb_init_uniform_lim, verbose=verbose) 181 | MODEL_INIT_REGISTRY = {'default_': torch_default_param_init_fn_, 'baseline_': baseline_param_init_fn_, 'kaiming_uniform_': kaiming_uniform_param_init_fn_, 'kaiming_normal_': kaiming_normal_param_init_fn_, 'neox_init_': neox_param_init_fn_, 'small_init_': small_param_init_fn_, 'xavier_uniform_': xavier_uniform_param_init_fn_, 'xavier_normal_': xavier_normal_param_init_fn_} -------------------------------------------------------------------------------- /llava/model/llava_arch.py: -------------------------------------------------------------------------------- 1 | # Copyright 2023 Haotian Liu 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | 16 | from abc import ABC, abstractmethod 17 | 18 | import torch 19 | import torch.nn as nn 20 | 21 | from .multimodal_encoder.builder import build_vision_tower 22 | from .multimodal_projector.builder import build_vision_projector 23 | 24 | from llava.constants import IGNORE_INDEX, IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_PATCH_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN 25 | 26 | 27 | class LlavaMetaModel: 28 | 29 | def __init__(self, config): 30 | super(LlavaMetaModel, self).__init__(config) 31 | 32 | if hasattr(config, "mm_vision_tower"): 33 | self.vision_tower = build_vision_tower(config, delay_load=True) 34 | self.mm_projector = build_vision_projector(config) 35 | 36 | def get_vision_tower(self): 37 | vision_tower = getattr(self, 'vision_tower', None) 38 | if type(vision_tower) is list: 39 | vision_tower = vision_tower[0] 40 | return vision_tower 41 | 42 | def initialize_vision_modules(self, model_args, fsdp=None): 43 | vision_tower = model_args.vision_tower 44 | mm_vision_select_layer = model_args.mm_vision_select_layer 45 | mm_vision_select_feature = model_args.mm_vision_select_feature 46 | pretrain_mm_mlp_adapter = model_args.pretrain_mm_mlp_adapter 47 | 48 | self.config.mm_vision_tower = vision_tower 49 | 50 | if self.get_vision_tower() is None: 51 | vision_tower = build_vision_tower(model_args) 52 | 53 | if fsdp is not None and len(fsdp) > 0: 54 | self.vision_tower = [vision_tower] 55 | else: 56 | self.vision_tower = vision_tower 57 | else: 58 | if fsdp is not None and len(fsdp) > 0: 59 | vision_tower = self.vision_tower[0] 60 | else: 61 | vision_tower = self.vision_tower 62 | vision_tower.load_model() 63 | 64 | self.config.use_mm_proj = True 65 | self.config.mm_projector_type = getattr(model_args, 'mm_projector_type', 'linear') 66 | self.config.mm_hidden_size = vision_tower.hidden_size 67 | self.config.mm_vision_select_layer = mm_vision_select_layer 68 
| self.config.mm_vision_select_feature = mm_vision_select_feature 69 | 70 | if getattr(self, 'mm_projector', None) is None: 71 | self.mm_projector = build_vision_projector(self.config) 72 | else: 73 | # In case it is frozen by LoRA 74 | for p in self.mm_projector.parameters(): 75 | p.requires_grad = True 76 | 77 | if pretrain_mm_mlp_adapter is not None: 78 | mm_projector_weights = torch.load(pretrain_mm_mlp_adapter, map_location='cpu') 79 | def get_w(weights, keyword): 80 | return {k.split(keyword + '.')[1]: v for k, v in weights.items() if keyword in k} 81 | 82 | self.mm_projector.load_state_dict(get_w(mm_projector_weights, 'mm_projector')) 83 | 84 | 85 | class LlavaMetaForCausalLM(ABC): 86 | 87 | @abstractmethod 88 | def get_model(self): 89 | pass 90 | 91 | def get_vision_tower(self): 92 | return self.get_model().get_vision_tower() 93 | 94 | def encode_images(self, images): 95 | image_features = self.get_model().get_vision_tower()(images) 96 | image_features = self.get_model().mm_projector(image_features) 97 | return image_features 98 | 99 | def prepare_inputs_labels_for_multimodal( 100 | self, input_ids, attention_mask, past_key_values, labels, images 101 | ): 102 | vision_tower = self.get_vision_tower() 103 | if vision_tower is None or images is None or input_ids.shape[1] == 1: 104 | if past_key_values is not None and vision_tower is not None and images is not None and input_ids.shape[1] == 1: 105 | attention_mask = torch.ones((attention_mask.shape[0], past_key_values[-1][-1].shape[-2] + 1), dtype=attention_mask.dtype, device=attention_mask.device) 106 | return input_ids, attention_mask, past_key_values, None, labels 107 | 108 | if type(images) is list or images.ndim == 5: 109 | concat_images = torch.cat([image for image in images], dim=0) 110 | image_features = self.encode_images(concat_images) 111 | split_sizes = [image.shape[0] for image in images] 112 | image_features = torch.split(image_features, split_sizes, dim=0) 113 | image_features = [x.flatten(0, 1) for x in image_features] 114 | else: 115 | image_features = self.encode_images(images) 116 | 117 | new_input_embeds = [] 118 | new_labels = [] if labels is not None else None 119 | cur_image_idx = 0 120 | for batch_idx, cur_input_ids in enumerate(input_ids): 121 | if (cur_input_ids == IMAGE_TOKEN_INDEX).sum() == 0: 122 | # multimodal LLM, but the current sample is not multimodal 123 | # FIXME: this is a hacky fix, for deepspeed zero3 to work 124 | half_len = cur_input_ids.shape[0] // 2 125 | cur_image_features = image_features[cur_image_idx] 126 | cur_input_embeds_1 = self.get_model().embed_tokens(cur_input_ids[:half_len]) 127 | cur_input_embeds_2 = self.get_model().embed_tokens(cur_input_ids[half_len:]) 128 | cur_input_embeds = torch.cat([cur_input_embeds_1, cur_image_features[0:0], cur_input_embeds_2], dim=0) 129 | new_input_embeds.append(cur_input_embeds) 130 | if labels is not None: 131 | new_labels.append(labels[batch_idx]) 132 | cur_image_idx += 1 133 | continue 134 | image_token_indices = torch.where(cur_input_ids == IMAGE_TOKEN_INDEX)[0] 135 | cur_new_input_embeds = [] 136 | if labels is not None: 137 | cur_labels = labels[batch_idx] 138 | cur_new_labels = [] 139 | assert cur_labels.shape == cur_input_ids.shape 140 | while image_token_indices.numel() > 0: 141 | cur_image_features = image_features[cur_image_idx] 142 | image_token_start = image_token_indices[0] 143 | if getattr(self.config, 'tune_mm_mlp_adapter', False) and getattr(self.config, 'mm_use_im_start_end', False): 144 | 
cur_new_input_embeds.append(self.get_model().embed_tokens(cur_input_ids[:image_token_start-1]).detach()) 145 | cur_new_input_embeds.append(self.get_model().embed_tokens(cur_input_ids[image_token_start-1:image_token_start])) 146 | cur_new_input_embeds.append(cur_image_features) 147 | cur_new_input_embeds.append(self.get_model().embed_tokens(cur_input_ids[image_token_start+1:image_token_start+2])) 148 | if labels is not None: 149 | cur_new_labels.append(cur_labels[:image_token_start]) 150 | cur_new_labels.append(torch.full((cur_image_features.shape[0],), IGNORE_INDEX, device=labels.device, dtype=labels.dtype)) 151 | cur_new_labels.append(cur_labels[image_token_start+1:image_token_start+2]) 152 | cur_labels = cur_labels[image_token_start+2:] 153 | else: 154 | cur_new_input_embeds.append(self.get_model().embed_tokens(cur_input_ids[:image_token_start])) 155 | cur_new_input_embeds.append(cur_image_features) 156 | if labels is not None: 157 | cur_new_labels.append(cur_labels[:image_token_start]) 158 | cur_new_labels.append(torch.full((cur_image_features.shape[0],), IGNORE_INDEX, device=labels.device, dtype=labels.dtype)) 159 | cur_labels = cur_labels[image_token_start+1:] 160 | cur_image_idx += 1 161 | if getattr(self.config, 'tune_mm_mlp_adapter', False) and getattr(self.config, 'mm_use_im_start_end', False): 162 | cur_input_ids = cur_input_ids[image_token_start+2:] 163 | else: 164 | cur_input_ids = cur_input_ids[image_token_start+1:] 165 | image_token_indices = torch.where(cur_input_ids == IMAGE_TOKEN_INDEX)[0] 166 | if cur_input_ids.numel() > 0: 167 | if getattr(self.config, 'tune_mm_mlp_adapter', False) and getattr(self.config, 'mm_use_im_start_end', False): 168 | cur_new_input_embeds.append(self.get_model().embed_tokens(cur_input_ids).detach()) 169 | else: 170 | cur_new_input_embeds.append(self.get_model().embed_tokens(cur_input_ids)) 171 | if labels is not None: 172 | cur_new_labels.append(cur_labels) 173 | cur_new_input_embeds = [x.to(device=self.device) for x in cur_new_input_embeds] 174 | cur_new_input_embeds = torch.cat(cur_new_input_embeds, dim=0) 175 | new_input_embeds.append(cur_new_input_embeds) 176 | if labels is not None: 177 | cur_new_labels = torch.cat(cur_new_labels, dim=0) 178 | new_labels.append(cur_new_labels) 179 | 180 | if any(x.shape != new_input_embeds[0].shape for x in new_input_embeds): 181 | max_len = max(x.shape[0] for x in new_input_embeds) 182 | 183 | new_input_embeds_align = [] 184 | for cur_new_embed in new_input_embeds: 185 | cur_new_embed = torch.cat((cur_new_embed, torch.zeros((max_len - cur_new_embed.shape[0], cur_new_embed.shape[1]), dtype=cur_new_embed.dtype, device=cur_new_embed.device)), dim=0) 186 | new_input_embeds_align.append(cur_new_embed) 187 | new_input_embeds = torch.stack(new_input_embeds_align, dim=0) 188 | 189 | if labels is not None: 190 | new_labels_align = [] 191 | _new_labels = new_labels 192 | for cur_new_label in new_labels: 193 | cur_new_label = torch.cat((cur_new_label, torch.full((max_len - cur_new_label.shape[0],), IGNORE_INDEX, dtype=cur_new_label.dtype, device=cur_new_label.device)), dim=0) 194 | new_labels_align.append(cur_new_label) 195 | new_labels = torch.stack(new_labels_align, dim=0) 196 | 197 | if attention_mask is not None: 198 | new_attention_mask = [] 199 | for cur_attention_mask, cur_new_labels, cur_new_labels_align in zip(attention_mask, _new_labels, new_labels): 200 | new_attn_mask_pad_left = torch.full((cur_new_labels.shape[0] - labels.shape[1],), True, dtype=attention_mask.dtype, device=attention_mask.device) 201 | 
new_attn_mask_pad_right = torch.full((cur_new_labels_align.shape[0] - cur_new_labels.shape[0],), False, dtype=attention_mask.dtype, device=attention_mask.device) 202 | cur_new_attention_mask = torch.cat((new_attn_mask_pad_left, cur_attention_mask, new_attn_mask_pad_right), dim=0) 203 | new_attention_mask.append(cur_new_attention_mask) 204 | attention_mask = torch.stack(new_attention_mask, dim=0) 205 | assert attention_mask.shape == new_labels.shape 206 | else: 207 | new_input_embeds = torch.stack(new_input_embeds, dim=0) 208 | if labels is not None: 209 | new_labels = torch.stack(new_labels, dim=0) 210 | 211 | if attention_mask is not None: 212 | new_attn_mask_pad_left = torch.full((attention_mask.shape[0], new_input_embeds.shape[1] - input_ids.shape[1]), True, dtype=attention_mask.dtype, device=attention_mask.device) 213 | attention_mask = torch.cat((new_attn_mask_pad_left, attention_mask), dim=1) 214 | assert attention_mask.shape == new_input_embeds.shape[:2] 215 | 216 | return None, attention_mask, past_key_values, new_input_embeds, new_labels 217 | 218 | def initialize_vision_tokenizer(self, model_args, tokenizer): 219 | if model_args.mm_use_im_patch_token: 220 | tokenizer.add_tokens([DEFAULT_IMAGE_PATCH_TOKEN], special_tokens=True) 221 | self.resize_token_embeddings(len(tokenizer)) 222 | 223 | if model_args.mm_use_im_start_end: 224 | num_new_tokens = tokenizer.add_tokens([DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN], special_tokens=True) 225 | self.resize_token_embeddings(len(tokenizer)) 226 | 227 | if num_new_tokens > 0: 228 | input_embeddings = self.get_input_embeddings().weight.data 229 | output_embeddings = self.get_output_embeddings().weight.data 230 | 231 | input_embeddings_avg = input_embeddings[:-num_new_tokens].mean( 232 | dim=0, keepdim=True) 233 | output_embeddings_avg = output_embeddings[:-num_new_tokens].mean( 234 | dim=0, keepdim=True) 235 | 236 | input_embeddings[-num_new_tokens:] = input_embeddings_avg 237 | output_embeddings[-num_new_tokens:] = output_embeddings_avg 238 | 239 | if model_args.tune_mm_mlp_adapter: 240 | for p in self.get_input_embeddings().parameters(): 241 | p.requires_grad = True 242 | for p in self.get_output_embeddings().parameters(): 243 | p.requires_grad = False 244 | 245 | if model_args.pretrain_mm_mlp_adapter: 246 | mm_projector_weights = torch.load(model_args.pretrain_mm_mlp_adapter, map_location='cpu') 247 | embed_tokens_weight = mm_projector_weights['model.embed_tokens.weight'] 248 | assert num_new_tokens == 2 249 | if input_embeddings.shape == embed_tokens_weight.shape: 250 | input_embeddings[-num_new_tokens:] = embed_tokens_weight[-num_new_tokens:] 251 | elif embed_tokens_weight.shape[0] == num_new_tokens: 252 | input_embeddings[-num_new_tokens:] = embed_tokens_weight 253 | else: 254 | raise ValueError(f"Unexpected embed_tokens_weight shape. Pretrained: {embed_tokens_weight.shape}. Current: {input_embeddings.shape}. 
Numer of new tokens: {num_new_tokens}.") 255 | elif model_args.mm_use_im_patch_token: 256 | if model_args.tune_mm_mlp_adapter: 257 | for p in self.get_input_embeddings().parameters(): 258 | p.requires_grad = False 259 | for p in self.get_output_embeddings().parameters(): 260 | p.requires_grad = False 261 | -------------------------------------------------------------------------------- /llava/model/make_delta.py: -------------------------------------------------------------------------------- 1 | """ 2 | Usage: 3 | python3 -m llava.model.make_delta --base ~/model_weights/llama-7b --target ~/model_weights/llava-7b --delta ~/model_weights/llava-7b-delta --hub-repo-id liuhaotian/llava-7b-delta 4 | """ 5 | import argparse 6 | 7 | import torch 8 | from tqdm import tqdm 9 | from transformers import AutoTokenizer, AutoModelForCausalLM 10 | from llava.model.utils import auto_upgrade 11 | 12 | 13 | def make_delta(base_model_path, target_model_path, delta_path, hub_repo_id): 14 | print("Loading base model") 15 | base = AutoModelForCausalLM.from_pretrained( 16 | base_model_path, torch_dtype=torch.float16, low_cpu_mem_usage=True) 17 | 18 | print("Loading target model") 19 | auto_upgrade(target_model_path) 20 | target = AutoModelForCausalLM.from_pretrained(target_model_path, torch_dtype=torch.float16, low_cpu_mem_usage=True) 21 | 22 | print("Calculating delta") 23 | for name, param in tqdm(target.state_dict().items(), desc="Calculating delta"): 24 | if name not in base.state_dict(): 25 | assert name in ['model.mm_projector.weight', 'model.mm_projector.bias'], f'{name} not in base model' 26 | continue 27 | if param.data.shape == base.state_dict()[name].shape: 28 | param.data -= base.state_dict()[name] 29 | else: 30 | assert name in ['model.embed_tokens.weight', 'lm_head.weight'], f'{name} dimension mismatch: {param.data.shape} vs {base.state_dict()[name].shape}' 31 | bparam = base.state_dict()[name] 32 | param.data[:bparam.shape[0], :bparam.shape[1]] -= bparam 33 | 34 | print("Saving delta") 35 | if hub_repo_id: 36 | kwargs = {"push_to_hub": True, "repo_id": hub_repo_id} 37 | else: 38 | kwargs = {} 39 | target.save_pretrained(delta_path, **kwargs) 40 | target_tokenizer = AutoTokenizer.from_pretrained(target_model_path) 41 | target_tokenizer.save_pretrained(delta_path, **kwargs) 42 | 43 | 44 | if __name__ == "__main__": 45 | parser = argparse.ArgumentParser() 46 | parser.add_argument("--base-model-path", type=str, required=True) 47 | parser.add_argument("--target-model-path", type=str, required=True) 48 | parser.add_argument("--delta-path", type=str, required=True) 49 | parser.add_argument("--hub-repo-id", type=str, default=None) 50 | args = parser.parse_args() 51 | 52 | make_delta(args.base_model_path, args.target_model_path, args.delta_path, args.hub_repo_id) 53 | -------------------------------------------------------------------------------- /llava/model/multimodal_encoder/builder.py: -------------------------------------------------------------------------------- 1 | import os 2 | from .clip_encoder import CLIPVisionTower 3 | 4 | 5 | def build_vision_tower(vision_tower_cfg, **kwargs): 6 | vision_tower = getattr(vision_tower_cfg, 'mm_vision_tower', getattr(vision_tower_cfg, 'vision_tower', None)) 7 | is_absolute_path_exists = os.path.exists(vision_tower) 8 | if is_absolute_path_exists or vision_tower.startswith("openai") or vision_tower.startswith("laion"): 9 | return CLIPVisionTower(vision_tower, args=vision_tower_cfg, **kwargs) 10 | 11 | raise ValueError(f'Unknown vision tower: {vision_tower}') 12 
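# Hedged usage sketch (added for illustration, not upstream code): build_vision_tower only
# needs a config object exposing the attributes it reads above; the model id and values
# below are examples.
if __name__ == "__main__":
    # run as a module (e.g. python -m llava.model.multimodal_encoder.builder)
    # so the relative import at the top of this file resolves
    from types import SimpleNamespace

    example_cfg = SimpleNamespace(
        mm_vision_tower="openai/clip-vit-large-patch14-336",  # any CLIP hub id or local path
        mm_vision_select_layer=-2,
        mm_vision_select_feature="patch",
    )
    # delay_load=True defers weight loading until tower.load_model() is called
    tower = build_vision_tower(example_cfg, delay_load=True)
    print(type(tower).__name__)  # -> CLIPVisionTower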
| -------------------------------------------------------------------------------- /llava/model/multimodal_encoder/clip_encoder.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | from transformers import CLIPVisionModel, CLIPImageProcessor, CLIPVisionConfig 5 | 6 | 7 | class CLIPVisionTower(nn.Module): 8 | def __init__(self, vision_tower, args, delay_load=False): 9 | super().__init__() 10 | 11 | self.is_loaded = False 12 | 13 | self.vision_tower_name = vision_tower 14 | self.select_layer = args.mm_vision_select_layer 15 | self.select_feature = getattr(args, 'mm_vision_select_feature', 'patch') 16 | 17 | if not delay_load: 18 | self.load_model() 19 | else: 20 | self.cfg_only = CLIPVisionConfig.from_pretrained(self.vision_tower_name) 21 | 22 | def load_model(self): 23 | self.image_processor = CLIPImageProcessor.from_pretrained(self.vision_tower_name) 24 | self.vision_tower = CLIPVisionModel.from_pretrained(self.vision_tower_name) 25 | self.vision_tower.requires_grad_(False) 26 | 27 | self.is_loaded = True 28 | 29 | def feature_select(self, image_forward_outs): 30 | image_features = image_forward_outs.hidden_states[self.select_layer] 31 | if self.select_feature == 'patch': 32 | image_features = image_features[:, 1:] 33 | elif self.select_feature == 'cls_patch': 34 | image_features = image_features 35 | else: 36 | raise ValueError(f'Unexpected select feature: {self.select_feature}') 37 | return image_features 38 | 39 | @torch.no_grad() 40 | def forward(self, images): 41 | if type(images) is list: 42 | image_features = [] 43 | for image in images: 44 | image_forward_out = self.vision_tower(image.to(device=self.device, dtype=self.dtype).unsqueeze(0), output_hidden_states=True) 45 | image_feature = self.feature_select(image_forward_out).to(image.dtype) 46 | image_features.append(image_feature) 47 | else: 48 | image_forward_outs = self.vision_tower(images.to(device=self.device, dtype=self.dtype), output_hidden_states=True) 49 | image_features = self.feature_select(image_forward_outs).to(images.dtype) 50 | 51 | return image_features 52 | 53 | @property 54 | def dummy_feature(self): 55 | return torch.zeros(1, self.hidden_size, device=self.device, dtype=self.dtype) 56 | 57 | @property 58 | def dtype(self): 59 | return self.vision_tower.dtype 60 | 61 | @property 62 | def device(self): 63 | return self.vision_tower.device 64 | 65 | @property 66 | def config(self): 67 | if self.is_loaded: 68 | return self.vision_tower.config 69 | else: 70 | return self.cfg_only 71 | 72 | @property 73 | def hidden_size(self): 74 | return self.config.hidden_size 75 | 76 | @property 77 | def num_patches(self): 78 | return (self.config.image_size // self.config.patch_size) ** 2 79 | -------------------------------------------------------------------------------- /llava/model/multimodal_projector/builder.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import re 4 | 5 | 6 | class IdentityMap(nn.Module): 7 | def __init__(self): 8 | super().__init__() 9 | 10 | def forward(self, x, *args, **kwargs): 11 | return x 12 | 13 | @property 14 | def config(self): 15 | return {"mm_projector_type": 'identity'} 16 | 17 | 18 | class SimpleResBlock(nn.Module): 19 | def __init__(self, channels): 20 | super().__init__() 21 | self.pre_norm = nn.LayerNorm(channels) 22 | 23 | self.proj = nn.Sequential( 24 | nn.Linear(channels, channels), 25 | nn.GELU(), 26 | nn.Linear(channels, 
channels) 27 | ) 28 | def forward(self, x): 29 | x = self.pre_norm(x) 30 | return x + self.proj(x) 31 | 32 | 33 | def build_vision_projector(config, delay_load=False, **kwargs): 34 | projector_type = getattr(config, 'mm_projector_type', 'linear') 35 | 36 | if projector_type == 'linear': 37 | return nn.Linear(config.mm_hidden_size, config.hidden_size) 38 | 39 | mlp_gelu_match = re.match(r'^mlp(\d+)x_gelu$', projector_type) 40 | if mlp_gelu_match: 41 | mlp_depth = int(mlp_gelu_match.group(1)) 42 | modules = [nn.Linear(config.mm_hidden_size, config.hidden_size)] 43 | for _ in range(1, mlp_depth): 44 | modules.append(nn.GELU()) 45 | modules.append(nn.Linear(config.hidden_size, config.hidden_size)) 46 | return nn.Sequential(*modules) 47 | 48 | if projector_type == 'identity': 49 | return IdentityMap() 50 | 51 | raise ValueError(f'Unknown projector type: {projector_type}') 52 | -------------------------------------------------------------------------------- /llava/model/utils.py: -------------------------------------------------------------------------------- 1 | from transformers import AutoConfig 2 | 3 | 4 | def auto_upgrade(config): 5 | cfg = AutoConfig.from_pretrained(config) 6 | if 'llava' in config and 'llava' not in cfg.model_type: 7 | assert cfg.model_type == 'llama' 8 | print("You are using newer LLaVA code base, while the checkpoint of v0 is from older code base.") 9 | print("You must upgrade the checkpoint to the new code base (this can be done automatically).") 10 | confirm = input("Please confirm that you want to upgrade the checkpoint. [Y/N]") 11 | if confirm.lower() in ["y", "yes"]: 12 | print("Upgrading checkpoint...") 13 | assert len(cfg.architectures) == 1 14 | setattr(cfg.__class__, "model_type", "llava") 15 | cfg.architectures[0] = 'LlavaLlamaForCausalLM' 16 | cfg.save_pretrained(config) 17 | print("Checkpoint upgraded.") 18 | else: 19 | print("Checkpoint upgrade aborted.") 20 | exit(1) 21 | -------------------------------------------------------------------------------- /llava/train/llama_flash_attn_monkey_patch.py: -------------------------------------------------------------------------------- 1 | from typing import Optional, Tuple 2 | import warnings 3 | 4 | import torch 5 | 6 | import transformers 7 | from transformers.models.llama.modeling_llama import apply_rotary_pos_emb, repeat_kv 8 | 9 | try: 10 | from flash_attn.flash_attn_interface import flash_attn_unpadded_qkvpacked_func 11 | except ImportError: 12 | from flash_attn.flash_attn_interface import flash_attn_varlen_qkvpacked_func as flash_attn_unpadded_qkvpacked_func 13 | from flash_attn.bert_padding import unpad_input, pad_input 14 | 15 | 16 | def forward( 17 | self, 18 | hidden_states: torch.Tensor, 19 | attention_mask: Optional[torch.Tensor] = None, 20 | position_ids: Optional[torch.Tensor] = None, 21 | past_key_value: Optional[Tuple[torch.Tensor]] = None, 22 | output_attentions: bool = False, 23 | use_cache: bool = False, 24 | ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: 25 | if output_attentions: 26 | warnings.warn( 27 | "Output attentions is not supported for patched `LlamaAttention`, returning `None` instead." 
28 | ) 29 | 30 | bsz, q_len, _ = hidden_states.size() 31 | 32 | query_states = ( 33 | self.q_proj(hidden_states) 34 | .view(bsz, q_len, self.num_heads, self.head_dim) 35 | .transpose(1, 2) 36 | ) 37 | key_states = ( 38 | self.k_proj(hidden_states) 39 | .view(bsz, q_len, self.num_key_value_heads, self.head_dim) 40 | .transpose(1, 2) 41 | ) 42 | value_states = ( 43 | self.v_proj(hidden_states) 44 | .view(bsz, q_len, self.num_key_value_heads, self.head_dim) 45 | .transpose(1, 2) 46 | ) # shape: (b, num_heads, s, head_dim) 47 | 48 | kv_seq_len = key_states.shape[-2] 49 | if past_key_value is not None: 50 | kv_seq_len += past_key_value[0].shape[-2] 51 | 52 | cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len) 53 | query_states, key_states = apply_rotary_pos_emb( 54 | query_states, key_states, cos, sin, position_ids 55 | ) 56 | 57 | if past_key_value is not None: 58 | # reuse k, v 59 | key_states = torch.cat([past_key_value[0], key_states], dim=2) 60 | value_states = torch.cat([past_key_value[1], value_states], dim=2) 61 | 62 | past_key_value = (key_states, value_states) if use_cache else None 63 | 64 | # repeat k/v heads if n_kv_heads < n_heads 65 | key_states = repeat_kv(key_states, self.num_key_value_groups) 66 | value_states = repeat_kv(value_states, self.num_key_value_groups) 67 | 68 | # Transform the data into the format required by flash attention 69 | qkv = torch.stack([query_states, key_states, value_states], dim=2) 70 | qkv = qkv.transpose(1, 3) # shape: [b, s, 3, num_heads, head_dim] 71 | key_padding_mask = attention_mask 72 | 73 | if key_padding_mask is None: 74 | qkv = qkv.reshape(-1, 3, self.num_heads, self.head_dim) 75 | cu_q_lens = torch.arange( 76 | 0, (bsz + 1) * q_len, step=q_len, dtype=torch.int32, device=qkv.device 77 | ) 78 | max_s = q_len 79 | output = flash_attn_unpadded_qkvpacked_func( 80 | qkv, cu_q_lens, max_s, 0.0, softmax_scale=None, causal=True 81 | ) 82 | output = output.view(bsz, q_len, -1) 83 | else: 84 | qkv = qkv.reshape(bsz, q_len, -1) 85 | qkv, indices, cu_q_lens, max_s = unpad_input(qkv, key_padding_mask) 86 | qkv = qkv.view(-1, 3, self.num_heads, self.head_dim) 87 | output_unpad = flash_attn_unpadded_qkvpacked_func( 88 | qkv, cu_q_lens, max_s, 0.0, softmax_scale=None, causal=True 89 | ) 90 | output_unpad = output_unpad.reshape(-1, self.num_heads * self.head_dim) 91 | output = pad_input(output_unpad, indices, bsz, q_len) 92 | 93 | return self.o_proj(output), None, past_key_value 94 | 95 | 96 | # Disable the transformation of the attention mask in LlamaModel as the flash attention 97 | # requires the attention mask to be the same as the key_padding_mask 98 | def _prepare_decoder_attention_mask( 99 | self, attention_mask, input_shape, inputs_embeds, past_key_values_length 100 | ): 101 | # [bsz, seq_len] 102 | return attention_mask 103 | 104 | 105 | def replace_llama_attn_with_flash_attn(): 106 | cuda_major, cuda_minor = torch.cuda.get_device_capability() 107 | if cuda_major < 8: 108 | warnings.warn( 109 | "Flash attention is only supported on A100 or H100 GPU during training due to head dim > 64 backward." 
110 | "ref: https://github.com/HazyResearch/flash-attention/issues/190#issuecomment-1523359593" 111 | ) 112 | transformers.models.llama.modeling_llama.LlamaModel._prepare_decoder_attention_mask = ( 113 | _prepare_decoder_attention_mask 114 | ) 115 | transformers.models.llama.modeling_llama.LlamaAttention.forward = forward 116 | -------------------------------------------------------------------------------- /llava/train/llama_xformers_attn_monkey_patch.py: -------------------------------------------------------------------------------- 1 | """ 2 | Directly copied the code from https://raw.githubusercontent.com/oobabooga/text-generation-webui/main/modules/llama_attn_hijack.py and made some adjustments 3 | """ 4 | 5 | import logging 6 | import math 7 | from typing import Optional, Tuple 8 | 9 | import torch 10 | import transformers.models.llama.modeling_llama 11 | from torch import nn 12 | 13 | try: 14 | import xformers.ops 15 | except ImportError: 16 | logging.error("xformers not found! Please install it before trying to use it.") 17 | 18 | 19 | def replace_llama_attn_with_xformers_attn(): 20 | transformers.models.llama.modeling_llama.LlamaAttention.forward = xformers_forward 21 | 22 | 23 | def xformers_forward( 24 | self, 25 | hidden_states: torch.Tensor, 26 | attention_mask: Optional[torch.Tensor] = None, 27 | position_ids: Optional[torch.LongTensor] = None, 28 | past_key_value: Optional[Tuple[torch.Tensor]] = None, 29 | output_attentions: bool = False, 30 | use_cache: bool = False, 31 | ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: 32 | # pylint: disable=duplicate-code 33 | bsz, q_len, _ = hidden_states.size() 34 | 35 | query_states = ( 36 | self.q_proj(hidden_states) 37 | .view(bsz, q_len, self.num_heads, self.head_dim) 38 | .transpose(1, 2) 39 | ) 40 | key_states = ( 41 | self.k_proj(hidden_states) 42 | .view(bsz, q_len, self.num_heads, self.head_dim) 43 | .transpose(1, 2) 44 | ) 45 | value_states = ( 46 | self.v_proj(hidden_states) 47 | .view(bsz, q_len, self.num_heads, self.head_dim) 48 | .transpose(1, 2) 49 | ) 50 | 51 | kv_seq_len = key_states.shape[-2] 52 | if past_key_value is not None: 53 | kv_seq_len += past_key_value[0].shape[-2] 54 | cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len) 55 | ( 56 | query_states, 57 | key_states, 58 | ) = transformers.models.llama.modeling_llama.apply_rotary_pos_emb( 59 | query_states, key_states, cos, sin, position_ids 60 | ) 61 | # [bsz, nh, t, hd] 62 | 63 | if past_key_value is not None: 64 | # reuse k, v, self_attention 65 | key_states = torch.cat([past_key_value[0], key_states], dim=2) 66 | value_states = torch.cat([past_key_value[1], value_states], dim=2) 67 | 68 | past_key_value = (key_states, value_states) if use_cache else None 69 | 70 | # We only apply xformers optimizations if we don't need to output the whole attention matrix 71 | if not output_attentions: 72 | query_states = query_states.transpose(1, 2) 73 | key_states = key_states.transpose(1, 2) 74 | value_states = value_states.transpose(1, 2) 75 | 76 | # This is a nasty hack. We know attention_mask in transformers is either LowerTriangular or all Zeros. 77 | # We therefore check if one element in the upper triangular portion is zero. If it is, then the mask is all zeros. 
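        # (added clarification) the attention_mask transformers builds here is additive:
        # 0.0 where a query may attend and a large negative value where it may not, so a
        # zero at the upper-triangular probe position [0, 0, 0, 1] means nothing is being
        # masked and the xformers call below can skip the causal bias entirely.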
78 | if attention_mask is None or attention_mask[0, 0, 0, 1] == 0: 79 | # input and output should be of form (bsz, q_len, num_heads, head_dim) 80 | attn_output = xformers.ops.memory_efficient_attention( 81 | query_states, key_states, value_states, attn_bias=None 82 | ) 83 | else: 84 | # input and output should be of form (bsz, q_len, num_heads, head_dim) 85 | attn_output = xformers.ops.memory_efficient_attention( 86 | query_states, 87 | key_states, 88 | value_states, 89 | attn_bias=xformers.ops.LowerTriangularMask(), 90 | ) 91 | attn_weights = None 92 | else: 93 | attn_weights = torch.matmul( 94 | query_states, key_states.transpose(2, 3) 95 | ) / math.sqrt(self.head_dim) 96 | 97 | if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len): 98 | raise ValueError( 99 | f"Attention weights should be of size {(bsz * self.num_heads, q_len, kv_seq_len)}, but is" 100 | f" {attn_weights.size()}" 101 | ) 102 | 103 | if attention_mask is not None: 104 | if attention_mask.size() != (bsz, 1, q_len, kv_seq_len): 105 | raise ValueError( 106 | f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}" 107 | ) 108 | attn_weights = attn_weights + attention_mask 109 | attn_weights = torch.max( 110 | attn_weights, torch.tensor(torch.finfo(attn_weights.dtype).min) 111 | ) 112 | 113 | # upcast attention to fp32 114 | attn_weights = nn.functional.softmax( 115 | attn_weights, dim=-1, dtype=torch.float32 116 | ).to(query_states.dtype) 117 | attn_output = torch.matmul(attn_weights, value_states) 118 | 119 | if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim): 120 | raise ValueError( 121 | f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is" 122 | f" {attn_output.size()}" 123 | ) 124 | 125 | attn_output = attn_output.transpose(1, 2) 126 | 127 | attn_output = attn_output.reshape(bsz, q_len, self.hidden_size) 128 | attn_output = self.o_proj(attn_output) 129 | return attn_output, attn_weights, past_key_value 130 | -------------------------------------------------------------------------------- /llava/train/llava_trainer.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | 4 | from torch.utils.data import Sampler 5 | 6 | from transformers import Trainer 7 | from transformers.trainer import ( 8 | is_sagemaker_mp_enabled, 9 | get_parameter_names, 10 | has_length, 11 | ALL_LAYERNORM_LAYERS, 12 | ShardedDDPOption, 13 | logger, 14 | ) 15 | from typing import List, Optional 16 | 17 | 18 | def maybe_zero_3(param, ignore_status=False, name=None): 19 | from deepspeed import zero 20 | from deepspeed.runtime.zero.partition_parameters import ZeroParamStatus 21 | if hasattr(param, "ds_id"): 22 | if param.ds_status == ZeroParamStatus.NOT_AVAILABLE: 23 | if not ignore_status: 24 | print(name, 'no ignore status') 25 | with zero.GatheredParameters([param]): 26 | param = param.data.detach().cpu().clone() 27 | else: 28 | param = param.detach().cpu().clone() 29 | return param 30 | 31 | 32 | def get_mm_adapter_state_maybe_zero_3(named_params, keys_to_match): 33 | to_return = {k: t for k, t in named_params if any(key_match in k for key_match in keys_to_match)} 34 | to_return = {k: maybe_zero_3(v, ignore_status=True, name=k).cpu() for k, v in to_return.items()} 35 | return to_return 36 | 37 | 38 | def split_to_even_chunks(indices, lengths, num_chunks): 39 | """ 40 | Split a list of indices into `chunks` chunks of roughly equal lengths. 
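    Illustrative example (added, not from the original docstring): calling
    ``split_to_even_chunks([0, 1, 2, 3, 4, 5], lengths=[9, 1, 8, 2, 7, 3], num_chunks=2)``
    assigns each index to whichever chunk currently has the smaller total length,
    giving ``[[0, 3, 5], [1, 2, 4]]`` (total lengths 14 and 16).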
41 | """ 42 | 43 | if len(indices) % num_chunks != 0: 44 | return [indices[i::num_chunks] for i in range(num_chunks)] 45 | 46 | num_indices_per_chunk = len(indices) // num_chunks 47 | 48 | chunks = [[] for _ in range(num_chunks)] 49 | chunks_lengths = [0 for _ in range(num_chunks)] 50 | for index in indices: 51 | shortest_chunk = chunks_lengths.index(min(chunks_lengths)) 52 | chunks[shortest_chunk].append(index) 53 | chunks_lengths[shortest_chunk] += lengths[index] 54 | if len(chunks[shortest_chunk]) == num_indices_per_chunk: 55 | chunks_lengths[shortest_chunk] = float("inf") 56 | 57 | return chunks 58 | 59 | 60 | def get_modality_length_grouped_indices(lengths, batch_size, world_size, generator=None): 61 | # We need to use torch for the random part as a distributed sampler will set the random seed for torch. 62 | assert all(l != 0 for l in lengths), "Should not have zero length." 63 | if all(l > 0 for l in lengths) or all(l < 0 for l in lengths): 64 | # all samples are in the same modality 65 | return get_length_grouped_indices(lengths, batch_size, world_size, generator=generator) 66 | mm_indices, mm_lengths = zip(*[(i, l) for i, l in enumerate(lengths) if l > 0]) 67 | lang_indices, lang_lengths = zip(*[(i, -l) for i, l in enumerate(lengths) if l < 0]) 68 | 69 | mm_shuffle = [mm_indices[i] for i in get_length_grouped_indices(mm_lengths, batch_size, world_size, generator=None)] 70 | lang_shuffle = [lang_indices[i] for i in get_length_grouped_indices(lang_lengths, batch_size, world_size, generator=None)] 71 | megabatch_size = world_size * batch_size 72 | mm_megabatches = [mm_shuffle[i : i + megabatch_size] for i in range(0, len(mm_shuffle), megabatch_size)] 73 | lang_megabatches = [lang_shuffle[i : i + megabatch_size] for i in range(0, len(lang_shuffle), megabatch_size)] 74 | 75 | last_mm = mm_megabatches[-1] 76 | last_lang = lang_megabatches[-1] 77 | additional_batch = last_mm + last_lang 78 | megabatches = mm_megabatches[:-1] + lang_megabatches[:-1] 79 | megabatch_indices = torch.randperm(len(megabatches), generator=generator) 80 | megabatches = [megabatches[i] for i in megabatch_indices] 81 | 82 | if len(additional_batch) > 0: 83 | megabatches.append(sorted(additional_batch)) 84 | 85 | return [i for megabatch in megabatches for i in megabatch] 86 | 87 | 88 | def get_length_grouped_indices(lengths, batch_size, world_size, generator=None, merge=True): 89 | # We need to use torch for the random part as a distributed sampler will set the random seed for torch. 90 | indices = torch.randperm(len(lengths), generator=generator) 91 | megabatch_size = world_size * batch_size 92 | megabatches = [indices[i : i + megabatch_size].tolist() for i in range(0, len(lengths), megabatch_size)] 93 | megabatches = [sorted(megabatch, key=lambda i: lengths[i], reverse=True) for megabatch in megabatches] 94 | megabatches = [split_to_even_chunks(megabatch, lengths, world_size) for megabatch in megabatches] 95 | 96 | return [i for megabatch in megabatches for batch in megabatch for i in batch] 97 | 98 | 99 | class LengthGroupedSampler(Sampler): 100 | r""" 101 | Sampler that samples indices in a way that groups together features of the dataset of roughly the same length while 102 | keeping a bit of randomness. 
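    Example (an added, hedged usage sketch; ``train_dataset`` and its precomputed lengths
    are placeholders, not repo objects):

        sampler = LengthGroupedSampler(
            batch_size=4,
            world_size=1,
            lengths=train_dataset.modality_lengths,  # >0 multimodal, <0 text-only
            group_by_modality=True,
        )
        loader = torch.utils.data.DataLoader(train_dataset, batch_size=4, sampler=sampler)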
103 | """ 104 | 105 | def __init__( 106 | self, 107 | batch_size: int, 108 | world_size: int, 109 | lengths: Optional[List[int]] = None, 110 | generator=None, 111 | group_by_modality: bool = False, 112 | ): 113 | if lengths is None: 114 | raise ValueError("Lengths must be provided.") 115 | 116 | self.batch_size = batch_size 117 | self.world_size = world_size 118 | self.lengths = lengths 119 | self.generator = generator 120 | self.group_by_modality = group_by_modality 121 | 122 | def __len__(self): 123 | return len(self.lengths) 124 | 125 | def __iter__(self): 126 | if self.group_by_modality: 127 | indices = get_modality_length_grouped_indices(self.lengths, self.batch_size, self.world_size, generator=self.generator) 128 | else: 129 | indices = get_length_grouped_indices(self.lengths, self.batch_size, self.world_size, generator=self.generator) 130 | return iter(indices) 131 | 132 | 133 | class LLaVATrainer(Trainer): 134 | 135 | def _get_train_sampler(self) -> Optional[torch.utils.data.Sampler]: 136 | if self.train_dataset is None or not has_length(self.train_dataset): 137 | return None 138 | 139 | if self.args.group_by_modality_length: 140 | lengths = self.train_dataset.modality_lengths 141 | return LengthGroupedSampler( 142 | self.args.train_batch_size, 143 | world_size=self.args.world_size * self.args.gradient_accumulation_steps, 144 | lengths=lengths, 145 | group_by_modality=True, 146 | ) 147 | else: 148 | return super()._get_train_sampler() 149 | 150 | def create_optimizer(self): 151 | """ 152 | Setup the optimizer. 153 | 154 | We provide a reasonable default that works well. If you want to use something else, you can pass a tuple in the 155 | Trainer's init through `optimizers`, or subclass and override this method in a subclass. 156 | """ 157 | if is_sagemaker_mp_enabled(): 158 | return super().create_optimizer() 159 | if self.sharded_ddp == ShardedDDPOption.SIMPLE: 160 | return super().create_optimizer() 161 | 162 | opt_model = self.model 163 | 164 | if self.optimizer is None: 165 | decay_parameters = get_parameter_names(opt_model, ALL_LAYERNORM_LAYERS) 166 | decay_parameters = [name for name in decay_parameters if "bias" not in name] 167 | if self.args.mm_projector_lr is not None: 168 | projector_parameters = [name for name, _ in opt_model.named_parameters() if "mm_projector" in name] 169 | optimizer_grouped_parameters = [ 170 | { 171 | "params": [ 172 | p for n, p in opt_model.named_parameters() if (n in decay_parameters and n not in projector_parameters and p.requires_grad) 173 | ], 174 | "weight_decay": self.args.weight_decay, 175 | }, 176 | { 177 | "params": [ 178 | p for n, p in opt_model.named_parameters() if (n not in decay_parameters and n not in projector_parameters and p.requires_grad) 179 | ], 180 | "weight_decay": 0.0, 181 | }, 182 | { 183 | "params": [ 184 | p for n, p in opt_model.named_parameters() if (n in decay_parameters and n in projector_parameters and p.requires_grad) 185 | ], 186 | "weight_decay": self.args.weight_decay, 187 | "lr": self.args.mm_projector_lr, 188 | }, 189 | { 190 | "params": [ 191 | p for n, p in opt_model.named_parameters() if (n not in decay_parameters and n in projector_parameters and p.requires_grad) 192 | ], 193 | "weight_decay": 0.0, 194 | "lr": self.args.mm_projector_lr, 195 | }, 196 | ] 197 | else: 198 | optimizer_grouped_parameters = [ 199 | { 200 | "params": [ 201 | p for n, p in opt_model.named_parameters() if (n in decay_parameters and p.requires_grad) 202 | ], 203 | "weight_decay": self.args.weight_decay, 204 | }, 205 | { 206 | 
"params": [ 207 | p for n, p in opt_model.named_parameters() if (n not in decay_parameters and p.requires_grad) 208 | ], 209 | "weight_decay": 0.0, 210 | }, 211 | ] 212 | 213 | optimizer_cls, optimizer_kwargs = Trainer.get_optimizer_cls_and_kwargs(self.args) 214 | 215 | if self.sharded_ddp == ShardedDDPOption.SIMPLE: 216 | self.optimizer = OSS( 217 | params=optimizer_grouped_parameters, 218 | optim=optimizer_cls, 219 | **optimizer_kwargs, 220 | ) 221 | else: 222 | self.optimizer = optimizer_cls(optimizer_grouped_parameters, **optimizer_kwargs) 223 | if optimizer_cls.__name__ == "Adam8bit": 224 | import bitsandbytes 225 | 226 | manager = bitsandbytes.optim.GlobalOptimManager.get_instance() 227 | 228 | skipped = 0 229 | for module in opt_model.modules(): 230 | if isinstance(module, nn.Embedding): 231 | skipped += sum({p.data_ptr(): p.numel() for p in module.parameters()}.values()) 232 | logger.info(f"skipped {module}: {skipped/2**20}M params") 233 | manager.register_module_override(module, "weight", {"optim_bits": 32}) 234 | logger.debug(f"bitsandbytes: will optimize {module} in fp32") 235 | logger.info(f"skipped: {skipped/2**20}M params") 236 | 237 | return self.optimizer 238 | 239 | def _save_checkpoint(self, model, trial, metrics=None): 240 | if getattr(self.args, 'tune_mm_mlp_adapter', False): 241 | from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR 242 | checkpoint_folder = f"{PREFIX_CHECKPOINT_DIR}-{self.state.global_step}" 243 | 244 | run_dir = self._get_output_dir(trial=trial) 245 | output_dir = os.path.join(run_dir, checkpoint_folder) 246 | 247 | # Only save Adapter 248 | keys_to_match = ['mm_projector', 'vision_resampler'] 249 | if getattr(self.args, "use_im_start_end", False): 250 | keys_to_match.extend(['embed_tokens', 'embed_in']) 251 | 252 | weight_to_save = get_mm_adapter_state_maybe_zero_3(self.model.named_parameters(), keys_to_match) 253 | 254 | if self.args.local_rank == 0 or self.args.local_rank == -1: 255 | self.model.config.save_pretrained(output_dir) 256 | torch.save(weight_to_save, os.path.join(output_dir, f'mm_projector.bin')) 257 | else: 258 | super(LLaVATrainer, self)._save_checkpoint(model, trial, metrics) 259 | 260 | def _save(self, output_dir: Optional[str] = None, state_dict=None): 261 | if getattr(self.args, 'tune_mm_mlp_adapter', False): 262 | pass 263 | else: 264 | super(LLaVATrainer, self)._save(output_dir, state_dict) 265 | -------------------------------------------------------------------------------- /llava/train/train_mem.py: -------------------------------------------------------------------------------- 1 | # Adopted from https://github.com/lm-sys/FastChat. Below is the original copyright: 2 | # Adopted from tatsu-lab@stanford_alpaca. Below is the original copyright: 3 | # Make it more memory efficient by monkey patching the LLaMA model with FlashAttn. 4 | 5 | # Need to call this before importing transformers. 6 | from llava.train.llama_flash_attn_monkey_patch import replace_llama_attn_with_flash_attn 7 | 8 | replace_llama_attn_with_flash_attn() 9 | 10 | from llava.train.train import train 11 | 12 | if __name__ == "__main__": 13 | train() 14 | -------------------------------------------------------------------------------- /llava/train/train_xformers.py: -------------------------------------------------------------------------------- 1 | # Make it more memory efficient by monkey patching the LLaMA model with xformers attention. 2 | 3 | # Need to call this before importing transformers. 
4 | from llava.train.llama_xformers_attn_monkey_patch import ( 5 | replace_llama_attn_with_xformers_attn, 6 | ) 7 | 8 | replace_llama_attn_with_xformers_attn() 9 | 10 | from llava.train.train import train 11 | 12 | if __name__ == "__main__": 13 | train() 14 | -------------------------------------------------------------------------------- /llava/utils.py: -------------------------------------------------------------------------------- 1 | import datetime 2 | import logging 3 | import logging.handlers 4 | import os 5 | import sys 6 | 7 | import requests 8 | 9 | from llava.constants import LOGDIR 10 | 11 | server_error_msg = "**NETWORK ERROR DUE TO HIGH TRAFFIC. PLEASE REGENERATE OR REFRESH THIS PAGE.**" 12 | moderation_msg = "YOUR INPUT VIOLATES OUR CONTENT MODERATION GUIDELINES. PLEASE TRY AGAIN." 13 | 14 | handler = None 15 | 16 | 17 | def build_logger(logger_name, logger_filename): 18 | global handler 19 | 20 | formatter = logging.Formatter( 21 | fmt="%(asctime)s | %(levelname)s | %(name)s | %(message)s", 22 | datefmt="%Y-%m-%d %H:%M:%S", 23 | ) 24 | 25 | # Set the format of root handlers 26 | if not logging.getLogger().handlers: 27 | logging.basicConfig(level=logging.INFO) 28 | logging.getLogger().handlers[0].setFormatter(formatter) 29 | 30 | # Redirect stdout and stderr to loggers 31 | stdout_logger = logging.getLogger("stdout") 32 | stdout_logger.setLevel(logging.INFO) 33 | sl = StreamToLogger(stdout_logger, logging.INFO) 34 | sys.stdout = sl 35 | 36 | stderr_logger = logging.getLogger("stderr") 37 | stderr_logger.setLevel(logging.ERROR) 38 | sl = StreamToLogger(stderr_logger, logging.ERROR) 39 | sys.stderr = sl 40 | 41 | # Get logger 42 | logger = logging.getLogger(logger_name) 43 | logger.setLevel(logging.INFO) 44 | 45 | # Add a file handler for all loggers 46 | if handler is None: 47 | os.makedirs(LOGDIR, exist_ok=True) 48 | filename = os.path.join(LOGDIR, logger_filename) 49 | handler = logging.handlers.TimedRotatingFileHandler( 50 | filename, when='D', utc=True) 51 | handler.setFormatter(formatter) 52 | 53 | for name, item in logging.root.manager.loggerDict.items(): 54 | if isinstance(item, logging.Logger): 55 | item.addHandler(handler) 56 | 57 | return logger 58 | 59 | 60 | class StreamToLogger(object): 61 | """ 62 | Fake file-like stream object that redirects writes to a logger instance. 63 | """ 64 | def __init__(self, logger, log_level=logging.INFO): 65 | self.terminal = sys.stdout 66 | self.logger = logger 67 | self.log_level = log_level 68 | self.linebuf = '' 69 | 70 | def __getattr__(self, attr): 71 | return getattr(self.terminal, attr) 72 | 73 | def write(self, buf): 74 | temp_linebuf = self.linebuf + buf 75 | self.linebuf = '' 76 | for line in temp_linebuf.splitlines(True): 77 | # From the io.TextIOWrapper docs: 78 | # On output, if newline is None, any '\n' characters written 79 | # are translated to the system default line separator. 80 | # By default sys.stdout.write() expects '\n' newlines and then 81 | # translates them so this is still cross platform. 82 | if line[-1] == '\n': 83 | self.logger.log(self.log_level, line.rstrip()) 84 | else: 85 | self.linebuf += line 86 | 87 | def flush(self): 88 | if self.linebuf != '': 89 | self.logger.log(self.log_level, self.linebuf.rstrip()) 90 | self.linebuf = '' 91 | 92 | 93 | def disable_torch_init(): 94 | """ 95 | Disable the redundant torch default initialization to accelerate model creation. 
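    It replaces reset_parameters on torch.nn.Linear and torch.nn.LayerNorm with no-ops;
    since the weights are loaded from a pretrained checkpoint anyway, skipping the default
    random initialization only speeds up model construction.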
96 | """ 97 | import torch 98 | setattr(torch.nn.Linear, "reset_parameters", lambda self: None) 99 | setattr(torch.nn.LayerNorm, "reset_parameters", lambda self: None) 100 | 101 | 102 | def violates_moderation(text): 103 | """ 104 | Check whether the text violates OpenAI moderation API. 105 | """ 106 | url = "https://api.openai.com/v1/moderations" 107 | headers = {"Content-Type": "application/json", 108 | "Authorization": "Bearer " + os.environ["OPENAI_API_KEY"]} 109 | text = text.replace("\n", "") 110 | data = "{" + '"input": ' + f'"{text}"' + "}" 111 | data = data.encode("utf-8") 112 | try: 113 | ret = requests.post(url, headers=headers, data=data, timeout=5) 114 | flagged = ret.json()["results"][0]["flagged"] 115 | except requests.exceptions.RequestException as e: 116 | flagged = False 117 | except KeyError as e: 118 | flagged = False 119 | 120 | return flagged 121 | 122 | 123 | def pretty_print_semaphore(semaphore): 124 | if semaphore is None: 125 | return "None" 126 | return f"Semaphore(value={semaphore._value}, locked={semaphore.locked()})" 127 | -------------------------------------------------------------------------------- /qwen_generation_utils.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Alibaba Cloud. 2 | # 3 | # This source code is licensed under the license found in the 4 | # LICENSE file in the root directory of this source tree. 5 | 6 | """Generation support.""" 7 | 8 | from typing import Tuple, List, Union, Iterable 9 | 10 | import numpy as np 11 | import torch 12 | import torch.nn.functional as F 13 | from transformers import PreTrainedTokenizer 14 | from transformers import logging 15 | from transformers.generation import LogitsProcessor 16 | 17 | logger = logging.get_logger(__name__) 18 | 19 | # Types. 20 | HistoryType = List[Tuple[str, str]] 21 | TokensType = List[int] 22 | BatchTokensType = List[List[int]] 23 | 24 | 25 | def pad_batch(batch: BatchTokensType, pad_id: int, seq_length: int) -> BatchTokensType: 26 | for tokens in batch: 27 | context_length = len(tokens) 28 | if context_length < seq_length: 29 | tokens.extend([pad_id] * (seq_length - context_length)) 30 | return batch 31 | 32 | 33 | def get_ltor_masks_and_position_ids( 34 | data, 35 | eod_token, 36 | reset_position_ids, 37 | reset_attention_mask, 38 | eod_mask_loss, 39 | ): 40 | """Build masks and position id for left to right model.""" 41 | 42 | # Extract batch size and sequence length. 43 | micro_batch_size, seq_length = data.size() 44 | 45 | # Attention mask (lower triangular). 46 | if reset_attention_mask: 47 | att_mask_batch = micro_batch_size 48 | else: 49 | att_mask_batch = 1 50 | attention_mask = torch.tril( 51 | torch.ones((att_mask_batch, seq_length, seq_length), device=data.device) 52 | ).view(att_mask_batch, 1, seq_length, seq_length) 53 | 54 | # Loss mask. 55 | loss_mask = torch.ones(data.size(), dtype=torch.float, device=data.device) 56 | if eod_mask_loss: 57 | loss_mask[data == eod_token] = 0.0 58 | 59 | # Position ids. 60 | position_ids = torch.arange(seq_length, dtype=torch.long, device=data.device) 61 | position_ids = position_ids.unsqueeze(0).expand_as(data) 62 | # We need to clone as the ids will be modifed based on batch index. 63 | if reset_position_ids: 64 | position_ids = position_ids.clone() 65 | 66 | if reset_position_ids or reset_attention_mask: 67 | # Loop through the batches: 68 | for b in range(micro_batch_size): 69 | 70 | # Find indecies where EOD token is. 
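            # position_ids[b, data[b] == eod_token] gathers the position of every EOD token
            # in sample b; the loop below then blocks attention across each EOD boundary
            # (reset_attention_mask) and/or restarts position ids after it (reset_position_ids).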
71 | eod_index = position_ids[b, data[b] == eod_token] 72 | # Detach indecies from positions if going to modify positions. 73 | if reset_position_ids: 74 | eod_index = eod_index.clone() 75 | 76 | # Loop through EOD indecies: 77 | prev_index = 0 78 | for j in range(eod_index.size()[0]): 79 | i = eod_index[j] 80 | # Mask attention loss. 81 | if reset_attention_mask: 82 | attention_mask[b, 0, (i + 1) :, : (i + 1)] = 0 83 | # Reset positions. 84 | if reset_position_ids: 85 | position_ids[b, (i + 1) :] -= i + 1 - prev_index 86 | prev_index = i + 1 87 | 88 | # Convert attention mask to binary: 89 | attention_mask = attention_mask < 0.5 90 | 91 | return attention_mask, loss_mask, position_ids 92 | 93 | 94 | def get_batch(context_tokens: torch.LongTensor, eod_id: int): 95 | """Generate batch from context tokens.""" 96 | # Move to GPU. 97 | tokens = context_tokens.contiguous().to(context_tokens.device) 98 | # Get the attention mask and postition ids. 99 | attention_mask, _, position_ids = get_ltor_masks_and_position_ids( 100 | tokens, 101 | eod_id, 102 | reset_position_ids=False, 103 | reset_attention_mask=False, 104 | eod_mask_loss=False, 105 | ) 106 | return tokens, attention_mask, position_ids 107 | 108 | 109 | def get_stop_words_ids(chat_format, tokenizer): 110 | if chat_format == "raw": 111 | stop_words_ids = [tokenizer.encode("Human:"), [tokenizer.eod_id]] 112 | elif chat_format == "chatml": 113 | stop_words_ids = [[tokenizer.im_end_id], [tokenizer.im_start_id]] 114 | else: 115 | raise NotImplementedError(f"Unknown chat format {chat_format!r}") 116 | return stop_words_ids 117 | 118 | 119 | def make_context( 120 | tokenizer: PreTrainedTokenizer, 121 | query: str, 122 | history: List[Tuple[str, str]] = None, 123 | system: str = "", 124 | max_window_size: int = 6144, 125 | chat_format: str = "chatml", 126 | ): 127 | if history is None: 128 | history = [] 129 | 130 | if chat_format == "chatml": 131 | im_start, im_end = "<|im_start|>", "<|im_end|>" 132 | im_start_tokens = [tokenizer.im_start_id] 133 | im_end_tokens = [tokenizer.im_end_id] 134 | nl_tokens = tokenizer.encode("\n") 135 | 136 | def _tokenize_str(role, content): 137 | return f"{role}\n{content}", tokenizer.encode( 138 | role, allowed_special=set(tokenizer.IMAGE_ST) 139 | ) + nl_tokens + tokenizer.encode(content, allowed_special=set(tokenizer.IMAGE_ST)) 140 | 141 | system_text, system_tokens_part = _tokenize_str("system", system) 142 | system_tokens = im_start_tokens + system_tokens_part + im_end_tokens 143 | 144 | raw_text = "" 145 | context_tokens = [] 146 | 147 | for turn_query, turn_response in reversed(history): 148 | query_text, query_tokens_part = _tokenize_str("user", turn_query) 149 | query_tokens = im_start_tokens + query_tokens_part + im_end_tokens 150 | if turn_response is not None: 151 | response_text, response_tokens_part = _tokenize_str( 152 | "assistant", turn_response 153 | ) 154 | response_tokens = im_start_tokens + response_tokens_part + im_end_tokens 155 | 156 | next_context_tokens = nl_tokens + query_tokens + nl_tokens + response_tokens 157 | prev_chat = ( 158 | f"\n{im_start}{query_text}{im_end}\n{im_start}{response_text}{im_end}" 159 | ) 160 | else: 161 | next_context_tokens = nl_tokens + query_tokens + nl_tokens 162 | prev_chat = f"\n{im_start}{query_text}{im_end}\n" 163 | 164 | current_context_size = ( 165 | len(system_tokens) + len(next_context_tokens) + len(context_tokens) 166 | ) 167 | if current_context_size < max_window_size: 168 | context_tokens = next_context_tokens + context_tokens 169 | raw_text = 
prev_chat + raw_text 170 | else: 171 | break 172 | 173 | context_tokens = system_tokens + context_tokens 174 | raw_text = f"{im_start}{system_text}{im_end}" + raw_text 175 | context_tokens += ( 176 | nl_tokens 177 | + im_start_tokens 178 | + _tokenize_str("user", query)[1] 179 | + im_end_tokens 180 | + nl_tokens 181 | + im_start_tokens 182 | + tokenizer.encode("assistant") 183 | + nl_tokens 184 | ) 185 | raw_text += f"\n{im_start}user\n{query}{im_end}\n{im_start}assistant\n" 186 | 187 | elif chat_format == "raw": 188 | raw_text = query 189 | context_tokens = tokenizer.encode(raw_text) 190 | else: 191 | raise NotImplementedError(f"Unknown chat format {chat_format!r}") 192 | 193 | return raw_text, context_tokens 194 | 195 | 196 | def _decode_default( 197 | tokens: List[int], 198 | *, 199 | stop_words: List[str], 200 | eod_words: List[str], 201 | tokenizer: PreTrainedTokenizer, 202 | raw_text_len: int, 203 | verbose: bool = False, 204 | return_end_reason: bool = False, 205 | errors: str='replace', 206 | ): 207 | trim_decode_tokens = tokenizer.decode(tokens, errors=errors)[raw_text_len:] 208 | if verbose: 209 | print("\nRaw Generate: ", trim_decode_tokens) 210 | 211 | end_reason = f"Gen length {len(tokens)}" 212 | for stop_word in stop_words: 213 | trim_decode_tokens = trim_decode_tokens.replace(stop_word, "").strip() 214 | for eod_word in eod_words: 215 | if eod_word in trim_decode_tokens: 216 | end_reason = f"Gen {eod_word!r}" 217 | trim_decode_tokens = trim_decode_tokens.split(eod_word)[0] 218 | trim_decode_tokens = trim_decode_tokens.strip() 219 | if verbose: 220 | print("\nEnd Reason:", end_reason) 221 | print("\nGenerate: ", trim_decode_tokens) 222 | 223 | if return_end_reason: 224 | return trim_decode_tokens, end_reason 225 | else: 226 | return trim_decode_tokens 227 | 228 | 229 | def _decode_chatml( 230 | tokens: List[int], 231 | *, 232 | stop_words: List[str], 233 | eod_token_ids: List[int], 234 | tokenizer: PreTrainedTokenizer, 235 | raw_text_len: int, 236 | context_length: int, 237 | verbose: bool = False, 238 | return_end_reason: bool = False, 239 | errors: str='replace' 240 | ): 241 | end_reason = f"Gen length {len(tokens)}" 242 | eod_token_idx = context_length 243 | for eod_token_idx in range(context_length, len(tokens)): 244 | if tokens[eod_token_idx] in eod_token_ids: 245 | end_reason = f"Gen {tokenizer.decode([tokens[eod_token_idx]])!r}" 246 | break 247 | 248 | trim_decode_tokens = tokenizer.decode(tokens[:eod_token_idx], errors=errors)[raw_text_len:] 249 | if verbose: 250 | print("\nRaw Generate w/o EOD:", tokenizer.decode(tokens, errors=errors)[raw_text_len:]) 251 | print("\nRaw Generate:", trim_decode_tokens) 252 | print("\nEnd Reason:", end_reason) 253 | for stop_word in stop_words: 254 | trim_decode_tokens = trim_decode_tokens.replace(stop_word, "").strip() 255 | trim_decode_tokens = trim_decode_tokens.strip() 256 | if verbose: 257 | print("\nGenerate:", trim_decode_tokens) 258 | 259 | if return_end_reason: 260 | return trim_decode_tokens, end_reason 261 | else: 262 | return trim_decode_tokens 263 | 264 | 265 | def decode_tokens( 266 | tokens: Union[torch.LongTensor, TokensType], 267 | tokenizer: PreTrainedTokenizer, 268 | raw_text_len: int, 269 | context_length: int, 270 | chat_format: str, 271 | verbose: bool = False, 272 | return_end_reason: bool = False, 273 | errors: str="replace", 274 | ) -> str: 275 | if torch.is_tensor(tokens): 276 | tokens = tokens.cpu().numpy().tolist() 277 | 278 | if chat_format == "chatml": 279 | return _decode_chatml( 280 | tokens, 281 | 
stop_words=[], 282 | eod_token_ids=[tokenizer.im_start_id, tokenizer.im_end_id], 283 | tokenizer=tokenizer, 284 | raw_text_len=raw_text_len, 285 | context_length=context_length, 286 | verbose=verbose, 287 | return_end_reason=return_end_reason, 288 | errors=errors, 289 | ) 290 | elif chat_format == "raw": 291 | return _decode_default( 292 | tokens, 293 | stop_words=["<|endoftext|>"], 294 | eod_words=["<|endoftext|>"], 295 | tokenizer=tokenizer, 296 | raw_text_len=raw_text_len, 297 | verbose=verbose, 298 | return_end_reason=return_end_reason, 299 | errors=errors, 300 | ) 301 | else: 302 | raise NotImplementedError(f"Unknown chat format {chat_format!r}") 303 | 304 | 305 | class StopWordsLogitsProcessor(LogitsProcessor): 306 | """ 307 | :class:`transformers.LogitsProcessor` that enforces that when specified sequences appear, stop geration. 308 | 309 | Args: 310 | stop_words_ids (:obj:`List[List[int]]`): 311 | List of list of token ids of stop ids. In order to get the tokens of the words 312 | that should not appear in the generated text, use :obj:`tokenizer(bad_word, 313 | add_prefix_space=True).input_ids`. 314 | eos_token_id (:obj:`int`): 315 | The id of the `end-of-sequence` token. 316 | """ 317 | 318 | def __init__(self, stop_words_ids: Iterable[Iterable[int]], eos_token_id: int): 319 | 320 | if not isinstance(stop_words_ids, List) or len(stop_words_ids) == 0: 321 | raise ValueError( 322 | f"`stop_words_ids` has to be a non-emtpy list, but is {stop_words_ids}." 323 | ) 324 | if any(not isinstance(bad_word_ids, list) for bad_word_ids in stop_words_ids): 325 | raise ValueError( 326 | f"`stop_words_ids` has to be a list of lists, but is {stop_words_ids}." 327 | ) 328 | if any( 329 | any( 330 | (not isinstance(token_id, (int, np.integer)) or token_id < 0) 331 | for token_id in stop_word_ids 332 | ) 333 | for stop_word_ids in stop_words_ids 334 | ): 335 | raise ValueError( 336 | f"Each list in `stop_words_ids` has to be a list of positive integers, but is {stop_words_ids}." 
337 | ) 338 | 339 | self.stop_words_ids = list( 340 | filter( 341 | lambda bad_token_seq: bad_token_seq != [eos_token_id], stop_words_ids 342 | ) 343 | ) 344 | self.eos_token_id = eos_token_id 345 | for stop_token_seq in self.stop_words_ids: 346 | assert ( 347 | len(stop_token_seq) > 0 348 | ), "Stop words token sequences {} cannot have an empty list".format( 349 | stop_words_ids 350 | ) 351 | 352 | def __call__( 353 | self, input_ids: torch.LongTensor, scores: torch.FloatTensor 354 | ) -> torch.FloatTensor: 355 | stopped_samples = self._calc_stopped_samples(input_ids) 356 | for i, should_stop in enumerate(stopped_samples): 357 | if should_stop: 358 | scores[i, self.eos_token_id] = float(2**15) 359 | return scores 360 | 361 | def _tokens_match(self, prev_tokens: torch.LongTensor, tokens: List[int]) -> bool: 362 | if len(tokens) == 0: 363 | # if bad word tokens is just one token always ban it 364 | return True 365 | elif len(tokens) > len(prev_tokens): 366 | # if bad word tokens are longer then prev input_ids they can't be equal 367 | return False 368 | elif prev_tokens[-len(tokens) :].tolist() == tokens: 369 | # if tokens match 370 | return True 371 | else: 372 | return False 373 | 374 | def _calc_stopped_samples(self, prev_input_ids: Iterable[int]) -> Iterable[int]: 375 | stopped_samples = [] 376 | for prev_input_ids_slice in prev_input_ids: 377 | match = False 378 | for stop_token_seq in self.stop_words_ids: 379 | if self._tokens_match(prev_input_ids_slice, stop_token_seq): 380 | # if tokens do not match continue 381 | match = True 382 | break 383 | stopped_samples.append(match) 384 | 385 | return stopped_samples 386 | 387 | 388 | def top_k_logits(logits, top_k=0, top_p=0.0, filter_value=-float("Inf")): 389 | """This function has been mostly taken from huggingface conversational 390 | ai code at 391 | https://medium.com/huggingface/how-to-build-a-state-of-the-art- 392 | conversational-ai-with-transfer-learning-2d818ac26313""" 393 | 394 | if top_k > 0: 395 | # Remove all tokens with a probability less than the 396 | # last token of the top-k 397 | indices_to_remove = logits < torch.topk(logits, top_k)[0][..., -1, None] 398 | logits[indices_to_remove] = filter_value 399 | 400 | if top_p > 0.0: 401 | # Cconvert to 1D 402 | sorted_logits, sorted_indices = torch.sort(logits, descending=True, dim=-1) 403 | cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1) 404 | 405 | # Remove tokens with cumulative probability above the threshold 406 | sorted_indices_to_remove = cumulative_probs > top_p 407 | # Shift the indices to the right to keep also the first token 408 | # above the threshold 409 | sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone() 410 | sorted_indices_to_remove[..., 0] = 0 411 | for i in range(sorted_indices.size(0)): 412 | indices_to_remove = sorted_indices[i][sorted_indices_to_remove[i]] 413 | logits[i][indices_to_remove] = filter_value 414 | 415 | return logits 416 | 417 | 418 | def switch(val1, val2, boolean): 419 | boolean = boolean.type_as(val1) 420 | return (1 - boolean) * val1 + boolean * val2 421 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | transformers==4.31.0 2 | torch==2.1.2 3 | torchvision==0.16.2 4 | tokenizers==0.13.1 5 | sentencepiece 6 | shortuuid 7 | accelerate==0.21.0 8 | peft 9 | bitsandbytes 10 | numpy==1.23 11 | scikit-learn==1.2.2 12 | requests 13 | httpx==0.24.0 14 | 
uvicorn 15 | fastapi 16 | einops==0.6.1 17 | einops-exts==0.0.4 18 | timm==0.6.13 19 | rouge 20 | matplotlib 21 | tiktoken 22 | transformers_stream_generator -------------------------------------------------------------------------------- /run.sh: -------------------------------------------------------------------------------- 1 | python3 eval_ppl.py\ 2 | --model-path ./models/llava-v1.5-7b \ 3 | --data-path ./playground/data/mm-vet/mm-vet.json \ 4 | --image-path ./playground/data/mm-vet/images \ 5 | --eval-samples 218 \ 6 | --method "elastic" \ 7 | --ratio 0.2 \ 8 | --exp-name "llava-7b-ppl-mmvet" --------------------------------------------------------------------------------
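As a usage note for `qwen_generation_utils.py` above: `make_context` builds the ChatML prompt, `get_stop_words_ids` returns the ChatML stop sequences, `StopWordsLogitsProcessor` forces an EOS once a stop sequence has been generated, and `decode_tokens` strips the prompt and stop words from the model output. The sketch below shows one way these helpers could be wired into a plain Hugging Face `generate` call; it is illustrative only (not one of the repo's entry points), assumes a local Qwen-VL-Chat checkpoint at `./models/qwen-vl-chat` (the path used elsewhere in this repository), and uses arbitrary prompt text and generation settings.

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, LogitsProcessorList

from qwen_generation_utils import (
    make_context,
    get_stop_words_ids,
    decode_tokens,
    StopWordsLogitsProcessor,
)

model_path = "./models/qwen-vl-chat"  # assumed local checkpoint
tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(
    model_path, trust_remote_code=True, torch_dtype=torch.float16, device_map="auto"
).eval()

# Build the ChatML prompt and its token ids.
query = "Give a one-sentence description of the ChatML format."
raw_text, context_tokens = make_context(
    tokenizer, query, history=[], system="You are a helpful assistant.", chat_format="chatml"
)
input_ids = torch.tensor([context_tokens], device=model.device)

# Force EOS as soon as an <|im_end|>/<|im_start|> stop sequence appears
# (eod_id is an attribute of the Qwen tokenizer).
stop_words_ids = get_stop_words_ids("chatml", tokenizer)
logits_processor = LogitsProcessorList(
    [StopWordsLogitsProcessor(stop_words_ids=stop_words_ids, eos_token_id=tokenizer.eod_id)]
)

outputs = model.generate(input_ids, logits_processor=logits_processor, max_new_tokens=128)

# Strip the prompt and any trailing stop words from the decoded sequence.
response = decode_tokens(
    outputs[0],
    tokenizer,
    raw_text_len=len(raw_text),
    context_length=len(context_tokens),
    chat_format="chatml",
)
print(response)
```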