├── README.md
├── Small Language Model Meets with Reinforced Vision Vocabulary.pdf
├── Vary-master
│   ├── pyproject.toml
│   ├── pyvenv.cfg
│   ├── vary
│   │   ├── __init__.py
│   │   ├── data
│   │   │   ├── __init__.py
│   │   │   ├── base_dataset.py
│   │   │   ├── caption_opt.py
│   │   │   └── conversation_dataset_qwen.py
│   │   ├── demo
│   │   │   ├── run_opt.py
│   │   │   └── run_qwen_vary.py
│   │   ├── model
│   │   │   ├── __init__.py
│   │   │   ├── llm
│   │   │   │   ├── opt
│   │   │   │   │   └── modeling_opt.py
│   │   │   │   └── qwen
│   │   │   │       ├── configuration_qwen.py
│   │   │   │       ├── modeling_qwen.py
│   │   │   │       ├── qwen_generation_utils.py
│   │   │   │       └── tokenization_qwen.py
│   │   │   ├── plug
│   │   │   │   ├── blip_process.py
│   │   │   │   └── transforms.py
│   │   │   ├── vary_opt.py
│   │   │   ├── vary_qwen_vary.py
│   │   │   ├── vary_toy_qwen1_8.py
│   │   │   └── vision_encoder
│   │   │       ├── __init__.py
│   │   │       └── sam.py
│   │   ├── train
│   │   │   ├── train_flash_attn.py
│   │   │   ├── train_lora.py
│   │   │   ├── train_lora_flash_attn.py
│   │   │   ├── train_opt.py
│   │   │   ├── train_qwen_vary.py
│   │   │   ├── trainer.py
│   │   │   └── trainer_vit_fixlr.py
│   │   └── utils
│   │       ├── arguments.py
│   │       ├── constants.py
│   │       ├── conversation.py
│   │       ├── llama_flash_attn_monkey_patch.py
│   │       └── utils.py
│   └── zero_config
│       └── zero2.json
└── assets
    └── vary-toy-logo.jpg

/README.md:
--------------------------------------------------------------------------------
Small Language Model Meets with Reinforced Vision Vocabulary

[Haoran Wei*](https://scholar.google.com/citations?user=J4naK0MAAAAJ&hl=en), Lingyu Kong*, Jinyue Chen, Liang Zhao, [Zheng Ge](https://joker316701882.github.io/), [En Yu](https://scholar.google.com.hk/citations?user=rWCQMNgAAAAJ&hl=zh-CN&oi=sra), [Jianjian Sun](https://scholar.google.com/citations?user=MVZrGkYAAAAJ&hl=en), Chunrui Han, [Xiangyu Zhang](https://scholar.google.com/citations?user=yuB-cfoAAAAJ&hl=en)

Two-Stream Hypothesis for LVLMs

## Release
- [2024/9/03] 🔥🔥🔥 We release a very strong and comprehensive OCR model, [GOT-OCR2.0](https://github.com/Ucas-HaoranWei/GOT-OCR2.0).
- [2024/7/21] 🎉🎉🎉 OneChart has been accepted by ACM MM 2024 as an **Oral** (3.97%)!
- [2024/7/2] 🔥🔥🔥 Vary has been accepted by ECCV 2024. To thank everyone for their attention, I will soon release a model that performs on par with Vary-document.
- [2024/5/27] 🔥🔥🔥 We present a document understanding benchmark in [Fox](https://github.com/ucaslcl/Fox).
- [2024/5/24] 🔥🔥🔥 We propose [Fox](https://arxiv.org/abs/2405.14295), a multi-page document understanding work that supports 8-page PDF-image input!!!
- [2024/4/21] 🔥🔥🔥 For OneChart, we have released the web demo on the [Project Page](https://onechartt.github.io/). Have fun!!
- [2024/4/21] 🔥🔥🔥 We present a Vary-tiny LAVIS codebase (for training from scratch) and the Vary-600k dataset (300K English and 300K Chinese pages) [here](https://github.com/Ucas-HaoranWei/Vary-tiny-600k)!!!
- [2024/4/15] 🔥🔥🔥 We release a chart parsing model, OneChart, [here](https://github.com/LingyvKong/OneChart).
- [2024/4/12] 🔥🔥🔥 We will release a chart parsing model based on Vary-tiny next week. The model supports both English and Chinese charts.
- [2024/3/16] 🔥🔥🔥 I found that many friends are very interested in Vary-tiny (OPT-125M), so I open-sourced it [here](https://huggingface.co/HaoranWei/Vary-tiny-opt125M/tree/main) as a PDF-dense OCR and object detection version.
- [2024/1/23] 🔥 Evaluation code will be available soon.
- [2024/1/23] 🔥🔥🔥 You only need a single 1080Ti to experience all features of current LVLMs.
[![Code License](https://img.shields.io/badge/Code%20License-Apache_2.0-green.svg)](https://github.com/tatsu-lab/stanford_alpaca/blob/main/LICENSE)
[![Data License](https://img.shields.io/badge/Data%20License-CC%20By%20NC%204.0-red.svg)](https://github.com/tatsu-lab/stanford_alpaca/blob/main/DATA_LICENSE)
**Usage and License Notices**: The data, code, and checkpoint are intended and licensed for research use only. They are also restricted to uses that follow the license agreements of LLaMA, Vicuna, GPT-4, Qwen, and LLaVA.


## Contents
- [Install](#install)
- [Vary-toy Weights](#vary-weights)
- [Demo](#demo)
- [Train](#train)

## Note
If you have built the original [Vary](https://github.com/Ucas-HaoranWei/Vary), please rebuild this repo!!!

## Install

1. Clone this repository and navigate to the Vary folder
```bash
git clone https://github.com/Ucas-HaoranWei/Vary-toy.git
cd /path/to/vary-toy
```
2. Install the package
```Shell
conda create -n vary python=3.10 -y
conda activate vary
pip install -e .
```

3. Install Flash-Attention
```
pip install ninja
pip install flash-attn --no-build-isolation
```

## Vary Weights
- Download the Vary-toy weights [here](https://huggingface.co/Haoran-megvii/Vary-toy).
- Download the CLIP-VIT-L [here](https://huggingface.co/openai/clip-vit-large-patch14/).



## Demo
1. Update the CLIP-VIT path in the code (`/cache/vit-large-patch14/`) to your own path.

2. Run the demo:
```Shell
python vary/demo/run_qwen_vary.py --model-name /vary/model/path/ --image-file /an/image/file.png
```
## Train
```Shell
deepspeed vary/train/train_qwen_vary.py --deepspeed /Vary/zero_config/zero2.json \
  --model_name_or_path /Vary-toy/path/ \
  --vision_tower /vit-large-patch14/path/ \
  --freeze_vision_tower True \
  --freeze_lm_model False \
  --vision_select_layer -2 \
  --use_im_start_end True \
  --bf16 True \
  --per_device_eval_batch_size 4 \
  --gradient_accumulation_steps 1 \
  --evaluation_strategy "no" \
  --save_strategy "steps" \
  --save_steps 5000 \
  --save_total_limit 1 \
  --weight_decay 0. \
  --warmup_ratio 0.03 \
  --lr_scheduler_type "cosine" \
  --logging_steps 1 --tf32 True \
  --model_max_length 4096 \
  --gradient_checkpointing True \
  --dataloader_num_workers 4 \
  --report_to none \
  --per_device_train_batch_size 4 \
  --num_train_epochs 1 \
  --learning_rate 5e-5 \
  --datasets data_name1+data_name2+data_name3 \
  --output_dir /path/to/output/
```
We encourage you to extract the new vision vocabulary weights for your new base language model!!!

## Contact
If you have any questions about the code or the paper, please email (`weihaoran18@mails.ucas.ac.cn`).

## Discussion
Vary-toy is not a toy: we have designed two excellent models based on it, Vary-document (specifically for document/PDF processing) and Vary-plot (for chart analysis). You can see their amazing performance in [Vary-family](https://github.com/Ucas-HaoranWei/Vary-family).
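
The `--datasets` value in the training command above is a `+`-separated list of dataset names. Judging from the dataset classes in `vary/data/`, each name is looked up in a `CONVERSATION_DATA` dictionary (star-imported from `vary.utils.constants`, which is not reproduced here), and each entry supplies an `annotations` JSON file and an `images` path prefix. The sketch below mirrors that lookup loop, minus the random shuffle. The registry contents, dataset names, record text, and the `load_conversations` / `build_demo_registry` helpers are hypothetical, for illustration only.

```python
import json
import os
import tempfile


def load_conversations(datasets, registry):
    """Hypothetical helper: resolve a '+'-separated dataset string the way the
    Vary dataset classes appear to, returning one image prefix per record."""
    list_data_dict, list_image_path = [], []
    for name in datasets.split("+"):
        entry = registry[name]  # KeyError if the name was never registered
        with open(entry["annotations"], "r") as f:
            data = json.load(f)
        list_data_dict.extend(data)
        # Every record of this dataset shares the same image prefix.
        list_image_path.extend([entry["images"]] * len(data))
    assert len(list_data_dict) == len(list_image_path)
    return list_data_dict, list_image_path


def build_demo_registry(root):
    """Hypothetical helper: write two tiny annotation files under `root` and
    return a registry shaped like the assumed CONVERSATION_DATA entries."""
    registry = {}
    for name, n_records in (("data_name1", 2), ("data_name2", 1)):
        records = [
            {
                "image": f"{name}_{k}.png",
                "conversations": [
                    {"from": "human", "value": "Provide the ocr results of this image."},
                    {"from": "gpt", "value": f"text of page {k}"},
                ],
            }
            for k in range(n_records)
        ]
        ann_path = os.path.join(root, f"{name}.json")
        with open(ann_path, "w") as f:
            json.dump(records, f)
        # Trailing slash kept on purpose: images are opened as prefix + file name.
        registry[name] = {"images": os.path.join(root, name) + "/", "annotations": ann_path}
    return registry


if __name__ == "__main__":
    with tempfile.TemporaryDirectory() as root:
        registry = build_demo_registry(root)
        data, image_dirs = load_conversations("data_name1+data_name2", registry)
        print(len(data))  # 3 records: 2 from data_name1, 1 from data_name2
        print(image_dirs[0] + data[0]["image"])
```

Because the dataset classes open each image as `image_path + image_file` (plain string concatenation, not `os.path.join`), an `images` prefix registered without a trailing `/` would yield broken paths, which is why the sketch keeps the slash.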

## Citation
If you find our work useful in your research, please consider citing Vary and Vary-toy:
```bibtex
@article{wei2023vary,
  title={Vary: Scaling up the Vision Vocabulary for Large Vision-Language Models},
  author={Wei, Haoran and Kong, Lingyu and Chen, Jinyue and Zhao, Liang and Ge, Zheng and Yang, Jinrong and Sun, Jianjian and Han, Chunrui and Zhang, Xiangyu},
  journal={arXiv preprint arXiv:2312.06109},
  year={2023}
}

@article{wei2024small,
  title={Small Language Model Meets with Reinforced Vision Vocabulary},
  author={Wei, Haoran and Kong, Lingyu and Chen, Jinyue and Zhao, Liang and Ge, Zheng and Yu, En and Sun, Jianjian and Han, Chunrui and Zhang, Xiangyu},
  journal={arXiv preprint arXiv:2401.12503},
  year={2024}
}
```

--------------------------------------------------------------------------------
/Small Language Model Meets with Reinforced Vision Vocabulary.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Ucas-HaoranWei/Vary-toy/c6e405977aa6e88d4807a6b79f394299623fc6f3/Small Language Model Meets with Reinforced Vision Vocabulary.pdf
--------------------------------------------------------------------------------
/Vary-master/pyproject.toml:
--------------------------------------------------------------------------------
[build-system]
requires = ["setuptools>=61.0"]
build-backend = "setuptools.build_meta"

[project]
name = "vary"
version = "0.1.0"
description = "Towards GPT-4 like large language and visual assistant."
readme = "README.md"
requires-python = ">=3.8"
classifiers = [
    "Programming Language :: Python :: 3",
    "License :: OSI Approved :: Apache Software License",
]
dependencies = [
    "einops", "markdown2[all]", "numpy",
    "requests", "sentencepiece", "tokenizers>=0.12.1",
    "torch", "torchvision", "wandb",
    "shortuuid", "httpx==0.24.0",
    "deepspeed==0.12.3",
    "peft==0.4.0",
    "albumentations",
    "opencv-python",
    "tiktoken",
    "accelerate==0.24.1",
    "transformers==4.32.1",
    "bitsandbytes==0.41.0",
    "scikit-learn==1.2.2",
    "sentencepiece==0.1.99",
    "einops==0.6.1", "einops-exts==0.0.4", "timm==0.6.13",
    "gradio_client==0.2.9"
]

[tool.setuptools.packages.find]
exclude = ["assets*", "benchmark*", "docs", "dist*", "playground*", "scripts*", "tests*"]

[tool.wheel]
exclude = ["assets*", "benchmark*", "docs", "dist*", "playground*", "scripts*", "tests*"]

--------------------------------------------------------------------------------
/Vary-master/pyvenv.cfg:
--------------------------------------------------------------------------------
home = /usr/bin
implementation = CPython
version_info = 3.8.10.final.0
virtualenv = 20.16.7
include-system-site-packages = true
base-prefix = /usr
base-exec-prefix = /usr
base-executable = /usr/bin/python3
--------------------------------------------------------------------------------
/Vary-master/vary/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Ucas-HaoranWei/Vary-toy/c6e405977aa6e88d4807a6b79f394299623fc6f3/Vary-master/vary/__init__.py
--------------------------------------------------------------------------------
/Vary-master/vary/data/__init__.py:
--------------------------------------------------------------------------------

import torch
import transformers
from dataclasses import dataclass, field

from vary.utils.constants import *


@dataclass
class DataCollatorForSupervisedDataset(object):
    tokenizer: transformers.PreTrainedTokenizer

    def __call__(self, instances):

        input_ids, labels = tuple([instance[key] for instance in instances] for key in ("input_ids", "labels"))
        images = [torch.stack(instance['image']) for instance in instances]

        images_high = [torch.stack(instance['image_high']) for instance in instances]

        # Pair each sample's low-resolution (CLIP) image with its high-resolution image.
        images = list(zip(images, images_high))

        input_ids = torch.nn.utils.rnn.pad_sequence(
            input_ids,
            batch_first=True,
            padding_value=self.tokenizer.pad_token_id)

        labels = torch.nn.utils.rnn.pad_sequence(
            labels,
            batch_first=True,
            padding_value=IGNORE_INDEX)

        batch = dict(
            input_ids=input_ids,
            labels=labels,
            attention_mask=input_ids.ne(self.tokenizer.pad_token_id),
            images=images,
        )
        return batch


def make_supervised_data_module(interleave, with_box, tokenizer, data_args):

    if data_args.conversation_version == 'mpt':
        from vary.data.conversation_dataset_qwen import ConversationDataset
        dataset_cls = ConversationDataset
    elif data_args.conversation_version == 'opt':
        from vary.data.caption_opt import CaptionDataset
        dataset_cls = CaptionDataset
    else:
        # Fail early instead of hitting an unbound `dataset_cls` below.
        raise ValueError(f"Unsupported conversation_version: {data_args.conversation_version!r}")

    train_dataset = dataset_cls(
        tokenizer=tokenizer,
        datasets=data_args.datasets,
        multimodal_cfg=dict(
            sep_image_conv_front=data_args.sep_image_conv_front,
            image_token_len=data_args.image_token_len,
            image_aspect_ratio=data_args.image_aspect_ratio,
            use_im_start_end=data_args.use_im_start_end,
            image_processor=data_args.image_processor,
            image_processor_high=data_args.image_processor_high,
            box_limit=data_args.box_limit,
        )
    )
    data_collator = DataCollatorForSupervisedDataset(tokenizer=tokenizer)
    return dict(train_dataset=train_dataset,
                eval_dataset=None,
                data_collator=data_collator)
--------------------------------------------------------------------------------
/Vary-master/vary/data/base_dataset.py:
--------------------------------------------------------------------------------
import io
import os
import copy
import json
import logging
import torch
import transformers
from typing import List, Optional, Tuple, Union, Dict, Sequence
from torch.utils.data import Dataset
from PIL import Image, ImageFile
ImageFile.LOAD_TRUNCATED_IMAGES = True
from vary.utils.constants import *


class BaseDataset(Dataset):
    def __init__(
        self,
        datasets: str,
        tokenizer: transformers.PreTrainedTokenizer,
        multimodal_cfg: dict
    ):
        super(BaseDataset, self).__init__()
        self.tokenizer = tokenizer
        self.multimodal_cfg = multimodal_cfg

        logging.warning(f"Using {multimodal_cfg['image_token_len']} tokens for representing image")

    def image_processor(self, image):
        processor = self.multimodal_cfg['image_processor']  # the first processor, usually the CLIP pretrained model (ViT)
        processor_high = self.multimodal_cfg['image_processor_high']  # the second processor, usually the designed image encoder (SAM/Swin/CNN)
        image_high = image.copy()
        # TODO: 'keep' and 'pad' only apply to the first processor
        if self.multimodal_cfg['image_aspect_ratio'] == 'keep':
            max_hw, min_hw = max(image.size), min(image.size)
            aspect_ratio = max_hw / min_hw
            max_len, min_len = 448, 224
            shortest_edge = int(min(max_len / aspect_ratio, min_len))
            image = processor.preprocess(image, return_tensors='pt', do_center_crop=False, size={"shortest_edge": shortest_edge})['pixel_values'][0]
        elif self.multimodal_cfg['image_aspect_ratio'] == 'pad':
            def expand2square(pil_img, background_color):
                width, height = pil_img.size
                if width == height:
                    return pil_img
                elif width > height:
                    result = Image.new(pil_img.mode, (width, width), background_color)
                    result.paste(pil_img)  # for simpler box processing
                    return result
                else:
                    result = Image.new(pil_img.mode, (height, height), background_color)
                    result.paste(pil_img)  # for simpler box processing
                    return result
            image = expand2square(image, tuple(int(x*255) for x in processor.image_mean))
            image = processor.preprocess(image, return_tensors='pt', do_center_crop=False, size={"shortest_edge": 224})['pixel_values'][0]
        else:
            image = processor.preprocess(image, return_tensors='pt')['pixel_values'][0]

        image_high = processor_high(image_high)

        return image, image_high

    def __len__(self):
        # `list_data_dict` is populated by the subclasses.
        return len(self.list_data_dict)

    def __getitem__(self, i) -> Dict[str, torch.Tensor]:
        # Implemented by the subclasses.
        raise NotImplementedError
--------------------------------------------------------------------------------
/Vary-master/vary/data/caption_opt.py:
--------------------------------------------------------------------------------

import io
import os
import copy
import json
import logging
import torch
import random

from typing import List, Optional, Tuple, Union, Dict, Sequence
from PIL import Image, ImageFile
ImageFile.LOAD_TRUNCATED_IMAGES = True

from vary.data.base_dataset import BaseDataset
from vary.utils.constants import *
from vary.utils import conversation as conversation_lib


class CaptionDataset(BaseDataset):
    """Conversation format dataset stage2 fine-tuning."""

    def __init__(self, datasets, tokenizer, multimodal_cfg):
        super(CaptionDataset, self).__init__(datasets, tokenizer, multimodal_cfg)
        # v0 version format conversation
        conversation_lib.default_conversation = conversation_lib.conv_templates["default"]
        logging.warning("Formatting inputs into conversation type: v0-fixed")
        logging.warning("Loading data...")

        list_data_dict = []
        list_image_path = []
        for name in datasets.split("+"):
            dataset = CONVERSATION_DATA[name]  # in vary.utils

            data_path = dataset['annotations']
            with open(data_path, "r") as f:
                data = json.load(f)

            list_data_dict.extend(data)

            image_path = dataset['images']
            list_image_path.extend([image_path] * len(data))

            logging.warning(f"Data from {data_path} provide {len(data)} conversations.")

        assert len(list_data_dict) == len(list_image_path)
        logging.warning(f"{len(list_data_dict)} conversations in total.")
        a_new_list = list(zip(list_data_dict, list_image_path))
        random.shuffle(a_new_list)
        list_data_dict_new, list_image_path_new = zip(*a_new_list)
        self.list_data_dict = list_data_dict_new
        self.list_image_path = list_image_path_new
        self.im_patch_token, self.im_start_token, self.im_end_token = tokenizer.convert_tokens_to_ids(
            [DEFAULT_IMAGE_PATCH_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN])

    def multimodal_processor(self, sources):
        for source in sources:

            source[0]['value'] = DEFAULT_IMAGE_TOKEN
            for sentence in source:
                replace_token = DEFAULT_IMAGE_PATCH_TOKEN * self.multimodal_cfg['image_token_len']
                if self.multimodal_cfg['use_im_start_end']:
                    replace_token = DEFAULT_IM_START_TOKEN + replace_token + DEFAULT_IM_END_TOKEN

                sentence["value"] = sentence["value"].replace(DEFAULT_IMAGE_TOKEN, replace_token)
        return sources

    def _tokenize_fn(self, strings):
        """Tokenize a list of strings."""
        tokenized_list = [
            self.tokenizer(
                text,
                return_tensors="pt",
                padding="longest",
                max_length=self.tokenizer.model_max_length,
                truncation=True,
            ) for text in strings
        ]
        # Note: input_ids and labels are the same list object.
        input_ids = labels = [
            tokenized.input_ids[0] for tokenized in tokenized_list
        ]

        # Force every sequence to end with token id 2 (OPT's </s>), even after truncation.
        for idx, ii in enumerate(input_ids):
            if ii[-1] != 2:
                input_ids[idx][-1] = 2
                labels[idx][-1] = 2

        input_ids_lens = labels_lens = [
            tokenized.input_ids.ne(self.tokenizer.pad_token_id).sum().item()
            for tokenized in tokenized_list
        ]
        return dict(
            input_ids=input_ids,
            labels=labels,
            input_ids_lens=input_ids_lens,
            labels_lens=labels_lens,
        )

    def _mask_targets(self, target, tokenized_lens, speakers):
        # cur_idx = 0
        cur_idx = tokenized_lens[0]
        tokenized_lens = tokenized_lens[1:]
        target[:cur_idx] = IGNORE_INDEX
        for tokenized_len, speaker in zip(tokenized_lens, speakers):
            if speaker.lower() == "human":
                # The slice end must be absolute (cur_idx + tokenized_len), not the turn length.
                target[cur_idx:cur_idx + tokenized_len] = IGNORE_INDEX
            cur_idx += tokenized_len

    def _add_speaker_and_signal(self, header, source, get_conversation=True):
        """Add speaker and start/end signal on each round."""
        BEGIN_SIGNAL = ""
        END_SIGNAL = "\n"
        conversation = header
        for sentence in source:
            from_str = sentence["from"]
            if from_str.lower() == "human":
                from_str = conversation_lib.default_conversation.roles[0]
            else:
                from_str = conversation_lib.default_conversation.roles[1]

            sentence["value"] = sentence["value"] + END_SIGNAL
            if get_conversation:
                conversation += sentence["value"]
        conversation += BEGIN_SIGNAL
        return conversation

    def token_processor(self, sources):
        """
        Given a list of sources, each of which is a conversation list, this transform:
        1. Adds the signal '### ' at the beginning of each sentence, with the end signal '\n';
        2. Concatenates the conversations together;
        3. Tokenizes the concatenated conversation;
        4. Makes a deepcopy as the target and masks human words with IGNORE_INDEX.
        """
        # add end signal and concatenate together
        conversations = []
        for source in sources:
            header = ''
            conversation = self._add_speaker_and_signal(header, source)
            conversations.append(conversation)

        conversations_tokenized = self._tokenize_fn(conversations)
        input_ids = conversations_tokenized["input_ids"]
        targets = copy.deepcopy(input_ids)
        for target, source in zip(targets, sources):
            tokenized_lens = self._tokenize_fn([header] + [s["value"] for s in source])["input_ids_lens"]
            speakers = [sentence["from"] for sentence in source]
            self._mask_targets(target, tokenized_lens, speakers)

        return dict(input_ids=input_ids, labels=targets)

    def __getitem__(self, i) -> Dict[str, torch.Tensor]:
        # data = self.list_data_dict[i]
        data = copy.deepcopy(self.list_data_dict[i])

        if isinstance(data, dict):
            if 'image' in data:
                image_path = self.list_image_path[i]
                image_file = data['image']
                # TODO this is a bug, because some json has wrong path

                try:
                    image = Image.open(image_path + image_file).convert('RGB')
                except Exception:
                    print(f'cannot identify image file {image_path+image_file}.')
                    return self.__getitem__(0)

                try:
                    image, image_high = self.image_processor(image)
                except Exception:
                    print(f'image {image_file} are broken or grayscale! we thus select 0-th sample instead!')
                    return self.__getitem__(0)

            conversations = self.multimodal_processor([data["conversations"]])
        else:
            conversations = [data]

        # align with fastchat & llava here, put the conversation into a list for tokenization
        data_dict = self.token_processor(conversations)
        data_dict = dict(input_ids=data_dict["input_ids"][0], labels=data_dict["labels"][0])

        if isinstance(data, dict) and 'image' in data:
            data_dict['image'] = [image]
            data_dict['image_high'] = [image_high]
        else:
            crop_size = self.multimodal_cfg['image_processor'].crop_size
            data_dict['image'] = [torch.zeros(3, crop_size['height'], crop_size['width'])]
            # TODO sam is 1024*1024
            data_dict['image_high'] = [torch.zeros(3, 1024, 1024)]
        return data_dict

--------------------------------------------------------------------------------
/Vary-master/vary/data/conversation_dataset_qwen.py:
--------------------------------------------------------------------------------

import io
import os
import copy
import json
import logging
import torch
import random

from typing import List, Optional, Tuple, Union, Dict, Sequence
from PIL import Image, ImageFile
ImageFile.LOAD_TRUNCATED_IMAGES = True

from vary.data.base_dataset import BaseDataset
from vary.utils.constants import *
from vary.utils import conversation as conversation_lib

class ConversationDataset(BaseDataset):
    """Conversation format dataset stage2 fine-tuning."""

    def __init__(self, datasets, tokenizer, multimodal_cfg):
        super(ConversationDataset, self).__init__(datasets, tokenizer, multimodal_cfg)
        # v0 version format conversation
        conversation_lib.default_conversation = conversation_lib.conv_templates["mpt"]
        logging.warning("Formatting inputs into conversation type: mpt-fixed")
        logging.warning("Loading data...")

        list_data_dict = []
        list_image_path = []

        for name in datasets.split("+"):
            # for name in vary_data_dict[name_all]:
            dataset = CONVERSATION_DATA[name]

            data_path = dataset['annotations']
            with open(data_path, "r") as f:
                data = json.load(f)

            list_data_dict.extend(data)

            image_path = dataset['images']

            list_image_path.extend([image_path] * len(data))

            logging.warning(f"Data from {data_path} provide {len(data)} conversations.")

        assert len(list_data_dict) == len(list_image_path)
        logging.warning(f"{len(list_data_dict)} conversations in total.")
        a_new_list = list(zip(list_data_dict, list_image_path))
        random.shuffle(a_new_list)
        list_data_dict_new, list_image_path_new = zip(*a_new_list)
        self.list_data_dict = list_data_dict_new
        self.list_image_path = list_image_path_new

        # Hard-coded Qwen vocabulary ids of the image patch / start / end tokens.
        self.im_patch_token = 151859

        self.im_start_token = 151857

        self.im_end_token = 151858

    def multimodal_processor(self, sources):
        for source in sources:
            if self.multimodal_cfg['sep_image_conv_front']:
                assert DEFAULT_IMAGE_TOKEN in source[0]['value']
                source[0]['value'] = source[0]['value'].replace(DEFAULT_IMAGE_TOKEN, '').strip()
                source[0]['value'] = DEFAULT_IMAGE_TOKEN + conversation_lib.default_conversation.sep + conversation_lib.default_conversation.roles[0] + ": " + source[0]['value']
            for sentence in source:
                replace_token = DEFAULT_IMAGE_PATCH_TOKEN * self.multimodal_cfg['image_token_len']
                # if self.multimodal_cfg['use_im_start_end']:
                replace_token = DEFAULT_IM_START_TOKEN + replace_token + DEFAULT_IM_END_TOKEN
                sentence["value"] = str(sentence["value"]).replace(DEFAULT_IMAGE_TOKEN, replace_token)
        return sources

    def _tokenize_fn(self, strings):
        """Tokenize a list of strings."""
        tokenized_list = [
            self.tokenizer(
                text,
                return_tensors="pt",
                padding="longest",
                max_length=self.tokenizer.model_max_length,
                truncation=True,
            ) for text in strings
        ]
        input_ids = labels = [
            tokenized.input_ids[0] for tokenized in tokenized_list
        ]
        input_ids_lens = labels_lens = [
            tokenized.input_ids.ne(self.tokenizer.pad_token_id).sum().item()
            for tokenized in tokenized_list
        ]
        return dict(
            input_ids=input_ids,
            labels=labels,
            input_ids_lens=input_ids_lens,
            labels_lens=labels_lens,
        )

    def _mask_targets(self, target, tokenized_lens, speakers):
        # cur_idx = 0
        cur_idx = tokenized_lens[0]
        tokenized_lens = tokenized_lens[1:]
        target[:cur_idx] = IGNORE_INDEX
        for tokenized_len, speaker in zip(tokenized_lens, speakers):
            if speaker.lower() == "human":
                target[cur_idx+2:cur_idx + tokenized_len] = IGNORE_INDEX
            cur_idx += tokenized_len

    def token_processor(self, sources):
        conv = conversation_lib.default_conversation.copy()
        roles = {"human": conv.roles[0], "gpt": conv.roles[1]}

        # Apply prompt templates
        conversations = []
        for i, source in enumerate(sources):
            if roles[source[0]["from"]] != conv.roles[0]:
                # Skip the first one if it is not from human
                source = source[1:]

            conv.messages = []
            for j, sentence in enumerate(source):
                role = roles[sentence["from"]]
                assert role == conv.roles[j % 2], f"{i}"
                conv.append_message(role, sentence["value"])
            conversations.append(conv.get_prompt())

        # Tokenize conversations
        input_ids = self.tokenizer(
            conversations,
            return_tensors="pt",
            padding="longest",
            max_length=self.tokenizer.model_max_length,
            truncation=True,
        ).input_ids

        # input_ids = torch.stack([tokenizer_image_token(prompt, tokenizer, return_tensors='pt') for prompt in conversations], dim=0)
        targets = input_ids.clone()
        assert conv.sep_style == conversation_lib.SeparatorStyle.MPT

        # Mask targets: only the assistant replies contribute to the loss.
        sep = conv.sep + conv.roles[1]
        for conversation, target in zip(conversations, targets):
            total_len = int(target.ne(self.tokenizer.pad_token_id).sum())

            rounds = conversation.split(conv.sep)
            re_rounds = [conv.sep.join(rounds[:3])]  # system + user + gpt
            for conv_idx in range(3, len(rounds), 2):
                re_rounds.append(conv.sep.join(rounds[conv_idx:conv_idx+2]))  # user + gpt
            cur_len = 0
            target[:cur_len] = IGNORE_INDEX
            for i, rou in enumerate(re_rounds):
                if rou == "":
                    break

                parts = rou.split(sep)
                if len(parts) != 2:
                    break
                parts[0] += sep
                round_len = len(self.tokenizer(rou).input_ids) + len(self.tokenizer(conv.sep).input_ids)

                instruction_len = len(self.tokenizer(parts[0]).input_ids)
                target[cur_len : cur_len + instruction_len] = IGNORE_INDEX

                cur_len += round_len
            target[cur_len:] = IGNORE_INDEX

            if cur_len < self.tokenizer.model_max_length:
                if cur_len != total_len:
                    target[:] = IGNORE_INDEX
                    print(
                        f"WARNING: tokenization mismatch: {cur_len} vs. {total_len}."
                        f" (ignored)"
                    )

        return dict(
            input_ids=input_ids,
            labels=targets,
        )

    def __getitem__(self, i) -> Dict[str, torch.Tensor]:
        # data = self.list_data_dict[i]
        data = copy.deepcopy(self.list_data_dict[i])

        if isinstance(data, dict):
            if 'image' in data:
                image_path = self.list_image_path[i]
                image_file = data['image']

                try:
                    image = Image.open(image_path + image_file).convert('RGB')
                except Exception:
                    print(f'cannot identify image file {image_path + image_file}.')
                    return self.__getitem__(0)

                try:
                    image, image_1 = self.image_processor(image)
                except Exception:
                    print(f'image {image_file} are broken or grayscale! we thus select 0-th sample instead!')
                    return self.__getitem__(0)

            conversations = self.multimodal_processor([data["conversations"]])

        else:
            conversations = [data]

        # align with fastchat & llava here, put the conversation into a list for tokenization
        data_dict = self.token_processor(conversations)
        data_dict = dict(input_ids=data_dict["input_ids"][0], labels=data_dict["labels"][0])

        if isinstance(data, dict) and 'image' in data:
            data_dict['image'] = [image]
            data_dict['image_high'] = [image_1]
        else:
            crop_size = self.multimodal_cfg['image_processor'].crop_size
            data_dict['image'] = [torch.zeros(3, crop_size['height'], crop_size['width'])]
            data_dict['image_high'] = [torch.zeros(3, 1024, 1024)]
        return data_dict

--------------------------------------------------------------------------------
/Vary-master/vary/demo/run_opt.py:
--------------------------------------------------------------------------------
import argparse
from transformers import AutoTokenizer, AutoModelForCausalLM
import torch
import os
from vary.utils.conversation import conv_templates, SeparatorStyle
from vary.utils.utils import disable_torch_init
from transformers import CLIPVisionModel, CLIPImageProcessor, StoppingCriteria
from vary.model import *
from vary.utils.utils import KeywordsStoppingCriteria

from PIL import Image

import requests
from io import BytesIO

from transformers import TextStreamer

from vary.model.plug.blip_process import BlipImageEvalProcessor

from vary.model.vision_encoder.sam import build_sam_vit_b
from vary.model.plug.transforms import train_transform, test_transform
DEFAULT_IMAGE_TOKEN = ""
DEFAULT_IMAGE_PATCH_TOKEN = ''
DEFAULT_IM_START_TOKEN = ''
DEFAULT_IM_END_TOKEN = ''


def load_image(image_file):
    if image_file.startswith('http') or image_file.startswith('https'):
        response = requests.get(image_file)
        image = Image.open(BytesIO(response.content)).convert('RGB')
    else:
        image = Image.open(image_file).convert('RGB')
    return image


def eval_model(args):
    # Model
    disable_torch_init()
    model_name = os.path.expanduser(args.model_name)

    tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=False)

    model = varyOPTForCausalLM.from_pretrained(model_name)

    model.to(device='cuda', dtype=torch.bfloat16)

    image_processor_high = test_transform

    image_token_len = 256

    # Vary-tiny (OPT) takes only the image tokens as its prompt.
    prompt = DEFAULT_IM_START_TOKEN + DEFAULT_IMAGE_PATCH_TOKEN * image_token_len + DEFAULT_IM_END_TOKEN
    inputs = tokenizer([prompt])

    image = load_image(args.image_file)
    image_1 = image.copy()

    image_tensor_1 = image_processor_high(image_1).to(torch.bfloat16)

    input_ids = torch.as_tensor(inputs.input_ids).cuda()

    stop_str = ''
    keywords = [stop_str]
    stopping_criteria = KeywordsStoppingCriteria(keywords, tokenizer, input_ids)

    streamer = TextStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)

    with torch.autocast("cuda", dtype=torch.bfloat16):
        output_ids = model.generate(
            input_ids,
            images=[(image_tensor_1.unsqueeze(0).half().cuda(), image_tensor_1.unsqueeze(0).cuda())],
            do_sample=True,
            num_beams=1,
            streamer=streamer,
            max_new_tokens=4096,
            stopping_criteria=[stopping_criteria]
            )

    # input_token_len = input_ids.shape[1]
    # outputs = tokenizer.decode(output_ids[0, input_ids.shape[1]:]).strip()

    # if outputs.endswith(stop_str):
    #     outputs = outputs[:-len(stop_str)]
    # outputs = outputs.strip()

    # print(outputs)


if __name__ == "__main__":
parser = argparse.ArgumentParser() 105 | parser.add_argument("--model-name", type=str, default="facebook/opt-350m") 106 | parser.add_argument("--image-file", type=str, required=True) 107 | # parser.add_argument("--query", type=str, required=True) 108 | parser.add_argument("--conv-mode", type=str, default=None) 109 | args = parser.parse_args() 110 | 111 | eval_model(args) 112 | -------------------------------------------------------------------------------- /Vary-master/vary/demo/run_qwen_vary.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | from transformers import AutoTokenizer, AutoModelForCausalLM 3 | import torch 4 | import os 5 | from vary.utils.conversation import conv_templates, SeparatorStyle 6 | from vary.utils.utils import disable_torch_init 7 | from transformers import CLIPVisionModel, CLIPImageProcessor, StoppingCriteria 8 | from vary.model import * 9 | from vary.utils.utils import KeywordsStoppingCriteria 10 | 11 | from PIL import Image 12 | 13 | import os 14 | import requests 15 | from PIL import Image 16 | from io import BytesIO 17 | from vary.model.plug.blip_process import BlipImageEvalProcessor 18 | from transformers import TextStreamer 19 | from vary.model.plug.transforms import train_transform, test_transform 20 | 21 | DEFAULT_IMAGE_TOKEN = "" 22 | DEFAULT_IMAGE_PATCH_TOKEN = '' 23 | DEFAULT_IM_START_TOKEN = '' 24 | DEFAULT_IM_END_TOKEN = '' 25 | 26 | 27 | def load_image(image_file): 28 | if image_file.startswith('http') or image_file.startswith('https'): 29 | response = requests.get(image_file) 30 | image = Image.open(BytesIO(response.content)).convert('RGB') 31 | else: 32 | image = Image.open(image_file).convert('RGB') 33 | return image 34 | 35 | 36 | def eval_model(args): 37 | # Model 38 | disable_torch_init() 39 | model_name = os.path.expanduser(args.model_name) 40 | 41 | tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True) 42 | 43 | model = 
varyQwenForCausalLM.from_pretrained(model_name, low_cpu_mem_usage=True, device_map='cuda', trust_remote_code=True) 44 | 45 | 46 | model.to(device='cuda', dtype=torch.bfloat16) 47 | 48 | 49 | image_processor = CLIPImageProcessor.from_pretrained("/data/hypertext/ucaswei/cache/vit-large-patch14/vit-large-patch14/", torch_dtype=torch.float16) 50 | 51 | image_processor_high = BlipImageEvalProcessor(image_size=1024) 52 | 53 | use_im_start_end = True 54 | 55 | image_token_len = 256 56 | 57 | qs = 'Provide the ocr results of this image.' 58 | 59 | if use_im_start_end: 60 | qs = DEFAULT_IM_START_TOKEN + DEFAULT_IMAGE_PATCH_TOKEN*image_token_len + DEFAULT_IM_END_TOKEN + '\n' + qs 61 | else: 62 | qs = DEFAULT_IMAGE_TOKEN + '\n' + qs 63 | 64 | 65 | 66 | 67 | conv_mode = "mpt" 68 | args.conv_mode = conv_mode 69 | 70 | conv = conv_templates[args.conv_mode].copy() 71 | conv.append_message(conv.roles[0], qs) 72 | conv.append_message(conv.roles[1], None) 73 | prompt = conv.get_prompt() 74 | 75 | 76 | inputs = tokenizer([prompt]) 77 | 78 | 79 | image = load_image(args.image_file) 80 | image_1 = image.copy() 81 | image_tensor = image_processor.preprocess(image, return_tensors='pt')['pixel_values'][0] 82 | 83 | image_tensor_1 = image_processor_high(image_1) 84 | 85 | input_ids = torch.as_tensor(inputs.input_ids).cuda() 86 | 87 | # stop_str = conv.sep if conv.sep_style != SeparatorStyle.TWO else conv.sep2 88 | stop_str = conv.sep if conv.sep_style != SeparatorStyle.TWO else conv.sep2 89 | keywords = [stop_str] 90 | stopping_criteria = KeywordsStoppingCriteria(keywords, tokenizer, input_ids) 91 | streamer = TextStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True) 92 | 93 | 94 | with torch.autocast("cuda", dtype=torch.bfloat16): 95 | output_ids = model.generate( 96 | input_ids, 97 | images=[(image_tensor.unsqueeze(0).half().cuda(), image_tensor_1.unsqueeze(0).half().cuda())], 98 | do_sample=True, 99 | num_beams = 1, 100 | # temperature=0.2, 101 | streamer=streamer, 102 | 
max_new_tokens=2048, 103 | stopping_criteria=[stopping_criteria] 104 | ) 105 | 106 | # print(output_ids) 107 | 108 | # outputs = tokenizer.decode(output_ids[0, input_ids.shape[1]:]).strip() 109 | 110 | # # conv.messages[-1][-1] = outputs 111 | # if outputs.endswith(stop_str): 112 | # outputs = outputs[:-len(stop_str)] 113 | # outputs = outputs.strip() 114 | 115 | # print(outputs) 116 | 117 | 118 | 119 | if __name__ == "__main__": 120 | parser = argparse.ArgumentParser() 121 | parser.add_argument("--model-name", type=str, default="facebook/opt-350m") 122 | parser.add_argument("--image-file", type=str, required=True) 123 | parser.add_argument("--conv-mode", type=str, default=None) 124 | args = parser.parse_args() 125 | 126 | eval_model(args) 127 | -------------------------------------------------------------------------------- /Vary-master/vary/model/__init__.py: -------------------------------------------------------------------------------- 1 | from .vary_opt import varyOPTModel, varyOPTForCausalLM 2 | # from .vary_qwen_vary import varyQwenModel, varyQwenForCausalLM, varyConfig 3 | from .vary_toy_qwen1_8 import varyQwenModel, varyQwenForCausalLM, varyConfig 4 | 5 | -------------------------------------------------------------------------------- /Vary-master/vary/model/llm/qwen/configuration_qwen.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Alibaba Cloud. 2 | # 3 | # This source code is licensed under the license found in the 4 | # LICENSE file in the root directory of this source tree. 
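`QWenConfig` below only stores hyperparameters. The sketch checks how its defaults fit together. It assumes, as in the upstream Qwen modeling code (not included in this dump), that `kv_channels` is the per-head width, that the q/k/v projection is `kv_channels * num_attention_heads` wide, and that rotary embeddings cover `kv_channels * rotary_pct` dimensions. `QwenShapeSketch` and `derived_shapes` are hypothetical names, and the sketch does not import `QWenConfig` itself.

```python
from dataclasses import dataclass


@dataclass
class QwenShapeSketch:
    # Defaults copied from QWenConfig.__init__ in configuration_qwen.py.
    hidden_size: int = 4096
    num_attention_heads: int = 32
    kv_channels: int = 128
    rotary_pct: float = 1.0


def derived_shapes(cfg: QwenShapeSketch) -> dict:
    """Hypothetical helper: derive per-head and rotary widths from the config."""
    # Assumption: each attention head is hidden_size / num_attention_heads wide.
    head_dim = cfg.hidden_size // cfg.num_attention_heads
    # Assumption: the q/k/v projection width is kv_channels * num_attention_heads.
    projection_size = cfg.kv_channels * cfg.num_attention_heads
    # Assumption: rotary embeddings are applied to kv_channels * rotary_pct dims.
    rotary_dims = int(cfg.kv_channels * cfg.rotary_pct)
    return {
        "head_dim": head_dim,
        "projection_size": projection_size,
        "rotary_dims": rotary_dims,
        "consistent": head_dim == cfg.kv_channels and projection_size == cfg.hidden_size,
    }


shapes = derived_shapes(QwenShapeSketch())
print(shapes)
# {'head_dim': 128, 'projection_size': 4096, 'rotary_dims': 128, 'consistent': True}
```

With the defaults, `kv_channels=128` equals `hidden_size / num_attention_heads`, so the projection width matches `hidden_size`. If you change one of these fields, you would normally have to change the others with it.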
5 | 6 | from transformers import PretrainedConfig 7 | 8 | 9 | class QWenConfig(PretrainedConfig): 10 | model_type = "qwen" 11 | keys_to_ignore_at_inference = ["past_key_values"] 12 | 13 | def __init__( 14 | self, 15 | vocab_size=151936, 16 | hidden_size=4096, 17 | num_hidden_layers=32, 18 | num_attention_heads=32, 19 | emb_dropout_prob=0.0, 20 | attn_dropout_prob=0.0, 21 | layer_norm_epsilon=1e-6, 22 | initializer_range=0.02, 23 | max_position_embeddings=8192, 24 | scale_attn_weights=True, 25 | use_cache=True, 26 | bf16=False, 27 | fp16=False, 28 | fp32=False, 29 | kv_channels=128, 30 | rotary_pct=1.0, 31 | rotary_emb_base=10000, 32 | use_dynamic_ntk=True, 33 | use_logn_attn=True, 34 | use_flash_attn="auto", 35 | intermediate_size=22016, 36 | no_bias=True, 37 | tie_word_embeddings=False, 38 | **kwargs, 39 | ): 40 | self.vocab_size = vocab_size 41 | self.hidden_size = hidden_size 42 | self.intermediate_size = intermediate_size 43 | self.num_hidden_layers = num_hidden_layers 44 | self.num_attention_heads = num_attention_heads 45 | self.emb_dropout_prob = emb_dropout_prob 46 | self.attn_dropout_prob = attn_dropout_prob 47 | self.layer_norm_epsilon = layer_norm_epsilon 48 | self.initializer_range = initializer_range 49 | self.scale_attn_weights = scale_attn_weights 50 | self.use_cache = use_cache 51 | self.max_position_embeddings = max_position_embeddings 52 | self.bf16 = bf16 53 | self.fp16 = fp16 54 | self.fp32 = fp32 55 | self.kv_channels = kv_channels 56 | self.rotary_pct = rotary_pct 57 | self.rotary_emb_base = rotary_emb_base 58 | self.use_dynamic_ntk = use_dynamic_ntk 59 | self.use_logn_attn = use_logn_attn 60 | self.use_flash_attn = use_flash_attn 61 | self.no_bias = no_bias 62 | super().__init__( 63 | tie_word_embeddings=tie_word_embeddings, 64 | **kwargs 65 | ) 66 | -------------------------------------------------------------------------------- /Vary-master/vary/model/llm/qwen/qwen_generation_utils.py: 
-------------------------------------------------------------------------------- 1 | # Copyright (c) Alibaba Cloud. 2 | # 3 | # This source code is licensed under the license found in the 4 | # LICENSE file in the root directory of this source tree. 5 | 6 | """Generation support.""" 7 | 8 | from typing import Tuple, List, Union, Iterable 9 | 10 | import numpy as np 11 | import torch 12 | import torch.nn.functional as F 13 | from transformers import PreTrainedTokenizer 14 | from transformers import logging 15 | from transformers.generation import LogitsProcessor 16 | 17 | logger = logging.get_logger(__name__) 18 | 19 | # Types. 20 | HistoryType = List[Tuple[str, str]] 21 | TokensType = List[int] 22 | BatchTokensType = List[List[int]] 23 | 24 | 25 | def pad_batch(batch: BatchTokensType, pad_id: int, seq_length: int) -> BatchTokensType: 26 | for tokens in batch: 27 | context_length = len(tokens) 28 | if context_length < seq_length: 29 | tokens.extend([pad_id] * (seq_length - context_length)) 30 | return batch 31 | 32 | 33 | def get_ltor_masks_and_position_ids( 34 | data, 35 | eod_token, 36 | reset_position_ids, 37 | reset_attention_mask, 38 | eod_mask_loss, 39 | ): 40 | """Build masks and position id for left to right model.""" 41 | 42 | # Extract batch size and sequence length. 43 | micro_batch_size, seq_length = data.size() 44 | 45 | # Attention mask (lower triangular). 46 | if reset_attention_mask: 47 | att_mask_batch = micro_batch_size 48 | else: 49 | att_mask_batch = 1 50 | attention_mask = torch.tril( 51 | torch.ones((att_mask_batch, seq_length, seq_length), device=data.device) 52 | ).view(att_mask_batch, 1, seq_length, seq_length) 53 | 54 | # Loss mask. 55 | loss_mask = torch.ones(data.size(), dtype=torch.float, device=data.device) 56 | if eod_mask_loss: 57 | loss_mask[data == eod_token] = 0.0 58 | 59 | # Position ids. 
60 | position_ids = torch.arange(seq_length, dtype=torch.long, device=data.device) 61 | position_ids = position_ids.unsqueeze(0).expand_as(data) 62 | # We need to clone as the ids will be modified based on batch index. 63 | if reset_position_ids: 64 | position_ids = position_ids.clone() 65 | 66 | if reset_position_ids or reset_attention_mask: 67 | # Loop through the batches: 68 | for b in range(micro_batch_size): 69 | 70 | # Find indices where EOD token is. 71 | eod_index = position_ids[b, data[b] == eod_token] 72 | # Detach indices from positions if going to modify positions. 73 | if reset_position_ids: 74 | eod_index = eod_index.clone() 75 | 76 | # Loop through EOD indices: 77 | prev_index = 0 78 | for j in range(eod_index.size()[0]): 79 | i = eod_index[j] 80 | # Block attention from tokens after the EOD to tokens up to and including it. 81 | if reset_attention_mask: 82 | attention_mask[b, 0, (i + 1) :, : (i + 1)] = 0 83 | # Reset positions. 84 | if reset_position_ids: 85 | position_ids[b, (i + 1) :] -= i + 1 - prev_index 86 | prev_index = i + 1 87 | 88 | # Convert to a boolean mask: True marks positions that must NOT be attended. 89 | attention_mask = attention_mask < 0.5 90 | 91 | return attention_mask, loss_mask, position_ids 92 | 93 | 94 | def get_batch(context_tokens: torch.LongTensor, eod_id: int): 95 | """Generate batch from context tokens.""" 96 | # Make the tokens contiguous; they stay on their current device. 97 | tokens = context_tokens.contiguous().to(context_tokens.device) 98 | # Get the attention mask and position ids.
99 | attention_mask, _, position_ids = get_ltor_masks_and_position_ids( 100 | tokens, 101 | eod_id, 102 | reset_position_ids=False, 103 | reset_attention_mask=False, 104 | eod_mask_loss=False, 105 | ) 106 | return tokens, attention_mask, position_ids 107 | 108 | 109 | def get_stop_words_ids(chat_format, tokenizer): 110 | if chat_format == "raw": 111 | stop_words_ids = [tokenizer.encode("Human:"), [tokenizer.eod_id]] 112 | elif chat_format == "chatml": 113 | stop_words_ids = [[tokenizer.im_end_id], [tokenizer.im_start_id]] 114 | else: 115 | raise NotImplementedError(f"Unknown chat format {chat_format!r}") 116 | return stop_words_ids 117 | 118 | 119 | def make_context( 120 | tokenizer: PreTrainedTokenizer, 121 | query: str, 122 | history: List[Tuple[str, str]] = None, 123 | system: str = "", 124 | max_window_size: int = 6144, 125 | chat_format: str = "chatml", 126 | ): 127 | if history is None: 128 | history = [] 129 | 130 | if chat_format == "chatml": 131 | im_start, im_end = "<|im_start|>", "<|im_end|>" 132 | im_start_tokens = [tokenizer.im_start_id] 133 | im_end_tokens = [tokenizer.im_end_id] 134 | nl_tokens = tokenizer.encode("\n") 135 | 136 | def _tokenize_str(role, content): 137 | return f"{role}\n{content}", tokenizer.encode( 138 | role, allowed_special=set(tokenizer.IMAGE_ST) 139 | ) + nl_tokens + tokenizer.encode(content, allowed_special=set(tokenizer.IMAGE_ST)) 140 | 141 | system_text, system_tokens_part = _tokenize_str("system", system) 142 | system_tokens = im_start_tokens + system_tokens_part + im_end_tokens 143 | 144 | raw_text = "" 145 | context_tokens = [] 146 | 147 | for turn_query, turn_response in reversed(history): 148 | query_text, query_tokens_part = _tokenize_str("user", turn_query) 149 | query_tokens = im_start_tokens + query_tokens_part + im_end_tokens 150 | if turn_response is not None: 151 | response_text, response_tokens_part = _tokenize_str( 152 | "assistant", turn_response 153 | ) 154 | response_tokens = im_start_tokens + 
response_tokens_part + im_end_tokens 155 | 156 | next_context_tokens = nl_tokens + query_tokens + nl_tokens + response_tokens 157 | prev_chat = ( 158 | f"\n{im_start}{query_text}{im_end}\n{im_start}{response_text}{im_end}" 159 | ) 160 | else: 161 | next_context_tokens = nl_tokens + query_tokens + nl_tokens 162 | prev_chat = f"\n{im_start}{query_text}{im_end}\n" 163 | 164 | current_context_size = ( 165 | len(system_tokens) + len(next_context_tokens) + len(context_tokens) 166 | ) 167 | if current_context_size < max_window_size: 168 | context_tokens = next_context_tokens + context_tokens 169 | raw_text = prev_chat + raw_text 170 | else: 171 | break 172 | 173 | context_tokens = system_tokens + context_tokens 174 | raw_text = f"{im_start}{system_text}{im_end}" + raw_text 175 | context_tokens += ( 176 | nl_tokens 177 | + im_start_tokens 178 | + _tokenize_str("user", query)[1] 179 | + im_end_tokens 180 | + nl_tokens 181 | + im_start_tokens 182 | + tokenizer.encode("assistant") 183 | + nl_tokens 184 | ) 185 | raw_text += f"\n{im_start}user\n{query}{im_end}\n{im_start}assistant\n" 186 | 187 | elif chat_format == "raw": 188 | raw_text = query 189 | context_tokens = tokenizer.encode(raw_text) 190 | else: 191 | raise NotImplementedError(f"Unknown chat format {chat_format!r}") 192 | 193 | return raw_text, context_tokens 194 | 195 | 196 | def _decode_default( 197 | tokens: List[int], 198 | *, 199 | stop_words: List[str], 200 | eod_words: List[str], 201 | tokenizer: PreTrainedTokenizer, 202 | raw_text_len: int, 203 | verbose: bool = False, 204 | return_end_reason: bool = False, 205 | errors: str='replace', 206 | ): 207 | trim_decode_tokens = tokenizer.decode(tokens, errors=errors)[raw_text_len:] 208 | if verbose: 209 | print("\nRaw Generate: ", trim_decode_tokens) 210 | 211 | end_reason = f"Gen length {len(tokens)}" 212 | for stop_word in stop_words: 213 | trim_decode_tokens = trim_decode_tokens.replace(stop_word, "").strip() 214 | for eod_word in eod_words: 215 | if eod_word in 
trim_decode_tokens: 216 | end_reason = f"Gen {eod_word!r}" 217 | trim_decode_tokens = trim_decode_tokens.split(eod_word)[0] 218 | trim_decode_tokens = trim_decode_tokens.strip() 219 | if verbose: 220 | print("\nEnd Reason:", end_reason) 221 | print("\nGenerate: ", trim_decode_tokens) 222 | 223 | if return_end_reason: 224 | return trim_decode_tokens, end_reason 225 | else: 226 | return trim_decode_tokens 227 | 228 | 229 | def _decode_chatml( 230 | tokens: List[int], 231 | *, 232 | stop_words: List[str], 233 | eod_token_ids: List[int], 234 | tokenizer: PreTrainedTokenizer, 235 | raw_text_len: int, 236 | context_length: int, 237 | verbose: bool = False, 238 | return_end_reason: bool = False, 239 | errors: str='replace' 240 | ): 241 | end_reason = f"Gen length {len(tokens)}" 242 | eod_token_idx = context_length 243 | for eod_token_idx in range(context_length, len(tokens)): 244 | if tokens[eod_token_idx] in eod_token_ids: 245 | end_reason = f"Gen {tokenizer.decode([tokens[eod_token_idx]])!r}" 246 | break 247 | 248 | trim_decode_tokens = tokenizer.decode(tokens[:eod_token_idx], errors=errors)[raw_text_len:] 249 | if verbose: 250 | print("\nRaw Generate w/o EOD:", tokenizer.decode(tokens, errors=errors)[raw_text_len:]) 251 | print("\nRaw Generate:", trim_decode_tokens) 252 | print("\nEnd Reason:", end_reason) 253 | for stop_word in stop_words: 254 | trim_decode_tokens = trim_decode_tokens.replace(stop_word, "").strip() 255 | trim_decode_tokens = trim_decode_tokens.strip() 256 | if verbose: 257 | print("\nGenerate:", trim_decode_tokens) 258 | 259 | if return_end_reason: 260 | return trim_decode_tokens, end_reason 261 | else: 262 | return trim_decode_tokens 263 | 264 | 265 | def decode_tokens( 266 | tokens: Union[torch.LongTensor, TokensType], 267 | tokenizer: PreTrainedTokenizer, 268 | raw_text_len: int, 269 | context_length: int, 270 | chat_format: str, 271 | verbose: bool = False, 272 | return_end_reason: bool = False, 273 | errors: str="replace", 274 | ) -> str: 275 | 
if torch.is_tensor(tokens): 276 | tokens = tokens.cpu().numpy().tolist() 277 | 278 | if chat_format == "chatml": 279 | return _decode_chatml( 280 | tokens, 281 | stop_words=[], 282 | eod_token_ids=[tokenizer.im_start_id, tokenizer.im_end_id], 283 | tokenizer=tokenizer, 284 | raw_text_len=raw_text_len, 285 | context_length=context_length, 286 | verbose=verbose, 287 | return_end_reason=return_end_reason, 288 | errors=errors, 289 | ) 290 | elif chat_format == "raw": 291 | return _decode_default( 292 | tokens, 293 | stop_words=["<|endoftext|>"], 294 | eod_words=["<|endoftext|>"], 295 | tokenizer=tokenizer, 296 | raw_text_len=raw_text_len, 297 | verbose=verbose, 298 | return_end_reason=return_end_reason, 299 | errors=errors, 300 | ) 301 | else: 302 | raise NotImplementedError(f"Unknown chat format {chat_format!r}") 303 | 304 | 305 | class StopWordsLogitsProcessor(LogitsProcessor): 306 | """ 307 | :class:`transformers.LogitsProcessor` that forces the end-of-sequence token once any of the specified sequences appears. 308 | 309 | Args: 310 | stop_words_ids (:obj:`List[List[int]]`): 311 | List of token-id lists; generation stops once the output ends with any of them. To get the 312 | token ids of a stop word, use :obj:`tokenizer(stop_word, 313 | add_prefix_space=True).input_ids`. 314 | eos_token_id (:obj:`int`): 315 | The id of the `end-of-sequence` token. 316 | """ 317 | 318 | def __init__(self, stop_words_ids: Iterable[Iterable[int]], eos_token_id: int): 319 | 320 | if not isinstance(stop_words_ids, List) or len(stop_words_ids) == 0: 321 | raise ValueError( 322 | f"`stop_words_ids` has to be a non-empty list, but is {stop_words_ids}." 323 | ) 324 | if any(not isinstance(bad_word_ids, list) for bad_word_ids in stop_words_ids): 325 | raise ValueError( 326 | f"`stop_words_ids` has to be a list of lists, but is {stop_words_ids}."
327 | ) 328 | if any( 329 | any( 330 | (not isinstance(token_id, (int, np.integer)) or token_id < 0) 331 | for token_id in stop_word_ids 332 | ) 333 | for stop_word_ids in stop_words_ids 334 | ): 335 | raise ValueError( 336 | f"Each list in `stop_words_ids` has to be a list of non-negative integers, but is {stop_words_ids}." 337 | ) 338 | 339 | self.stop_words_ids = list( 340 | filter( 341 | lambda bad_token_seq: bad_token_seq != [eos_token_id], stop_words_ids 342 | ) 343 | ) 344 | self.eos_token_id = eos_token_id 345 | for stop_token_seq in self.stop_words_ids: 346 | assert ( 347 | len(stop_token_seq) > 0 348 | ), "Stop words token sequences {} cannot have an empty list".format( 349 | stop_words_ids 350 | ) 351 | 352 | def __call__( 353 | self, input_ids: torch.LongTensor, scores: torch.FloatTensor 354 | ) -> torch.FloatTensor: 355 | stopped_samples = self._calc_stopped_samples(input_ids) 356 | for i, should_stop in enumerate(stopped_samples): 357 | if should_stop: 358 | scores[i, self.eos_token_id] = float(2**15) 359 | return scores 360 | 361 | def _tokens_match(self, prev_tokens: torch.LongTensor, tokens: List[int]) -> bool: 362 | if len(tokens) == 0: 363 | # an empty stop sequence trivially matches 364 | return True 365 | elif len(tokens) > len(prev_tokens): 366 | # if the stop sequence is longer than prev input_ids they can't be equal 367 | return False 368 | elif prev_tokens[-len(tokens) :].tolist() == tokens: 369 | # if tokens match 370 | return True 371 | else: 372 | return False 373 | 374 | def _calc_stopped_samples(self, prev_input_ids: Iterable[int]) -> Iterable[int]: 375 | stopped_samples = [] 376 | for prev_input_ids_slice in prev_input_ids: 377 | match = False 378 | for stop_token_seq in self.stop_words_ids: 379 | if self._tokens_match(prev_input_ids_slice, stop_token_seq): 380 | # the tail of this sample matches a stop sequence 381 | match = True 382 | break 383 | stopped_samples.append(match) 384 | 385 | return stopped_samples 386 | 387 | 388 | def
top_k_logits(logits, top_k=0, top_p=0.0, filter_value=-float("Inf")): 389 | """This function has been mostly taken from huggingface conversational 390 | ai code at 391 | https://medium.com/huggingface/how-to-build-a-state-of-the-art- 392 | conversational-ai-with-transfer-learning-2d818ac26313""" 393 | 394 | if top_k > 0: 395 | # Remove all tokens with a probability less than the 396 | # last token of the top-k 397 | indices_to_remove = logits < torch.topk(logits, top_k)[0][..., -1, None] 398 | logits[indices_to_remove] = filter_value 399 | 400 | if top_p > 0.0: 401 | # Sort each row's logits in descending order 402 | sorted_logits, sorted_indices = torch.sort(logits, descending=True, dim=-1) 403 | cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1) 404 | 405 | # Remove tokens with cumulative probability above the threshold 406 | sorted_indices_to_remove = cumulative_probs > top_p 407 | # Shift the indices to the right to keep also the first token 408 | # above the threshold 409 | sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone() 410 | sorted_indices_to_remove[..., 0] = 0 411 | for i in range(sorted_indices.size(0)): 412 | indices_to_remove = sorted_indices[i][sorted_indices_to_remove[i]] 413 | logits[i][indices_to_remove] = filter_value 414 | 415 | return logits 416 | 417 | 418 | def switch(val1, val2, boolean): 419 | boolean = boolean.type_as(val1) 420 | return (1 - boolean) * val1 + boolean * val2 421 | -------------------------------------------------------------------------------- /Vary-master/vary/model/llm/qwen/tokenization_qwen.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Alibaba Cloud. 2 | # 3 | # This source code is licensed under the license found in the 4 | # LICENSE file in the root directory of this source tree.
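`QWenTokenizer` below numbers its special tokens immediately after the BPE ranks (`enumerate(SPECIAL_TOKENS, start=len(self.mergeable_ranks))`). As a result, `<|endoftext|>` is the first special id (`eod_id`), and `_decode(..., skip_special_tokens=True)` can drop every special token with a single `id < eod_id` comparison. The sketch mirrors that id layout with a hypothetical four-entry rank table instead of the real `qwen.tiktoken` file. It keeps only two `<|extra_i|>` tokens and does not use tiktoken. `toy_ranks`, `build_special_ids` and `skip_specials` are illustrative names.

```python
from typing import Dict, List

ENDOFTEXT = "<|endoftext|>"
IMSTART = "<|im_start|>"
IMEND = "<|im_end|>"
# Same surface forms as tokenization_qwen.py, but only 2 extras instead of 205.
EXTRAS = tuple(f"<|extra_{i}|>" for i in range(2))
SPECIAL_TOKENS = (ENDOFTEXT, IMSTART, IMEND) + EXTRAS

# Hypothetical toy BPE table: byte string -> rank (the rank doubles as the token id).
toy_ranks: Dict[bytes, int] = {b"a": 0, b"b": 1, b"ab": 2, b" ": 3}


def build_special_ids(ranks: Dict[bytes, int]) -> Dict[str, int]:
    # Special tokens are numbered right after the last mergeable rank.
    return {tok: idx for idx, tok in enumerate(SPECIAL_TOKENS, start=len(ranks))}


def skip_specials(token_ids: List[int], eod_id: int) -> List[int]:
    # Mirrors _decode(skip_special_tokens=True): every id >= eod_id is special.
    return [i for i in token_ids if i < eod_id]


special_ids = build_special_ids(toy_ranks)
eod_id = special_ids[ENDOFTEXT]
n_vocab = len(toy_ranks) + len(special_ids)
print(eod_id, special_ids[IMSTART], special_ids[IMEND], n_vocab)  # 4 5 6 9
print(skip_specials([special_ids[IMSTART], 2, 3, 0, special_ids[IMEND]], eod_id))  # [2, 3, 0]
```

The same ordering is why the real tokenizer asserts `len(self.mergeable_ranks) + len(self.special_tokens) == enc.n_vocab`: the special ids must follow the BPE ranks without gaps.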
5 | 6 | """Tokenization classes for QWen.""" 7 | 8 | import base64 9 | import logging 10 | import os 11 | import unicodedata 12 | from typing import Collection, Dict, List, Set, Tuple, Union 13 | 14 | import tiktoken 15 | from transformers import PreTrainedTokenizer, AddedToken 16 | 17 | logger = logging.getLogger(__name__) 18 | 19 | 20 | VOCAB_FILES_NAMES = {"vocab_file": "qwen.tiktoken"} 21 | 22 | PAT_STR = r"""(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+""" 23 | ENDOFTEXT = "<|endoftext|>" 24 | IMSTART = "<|im_start|>" 25 | IMEND = "<|im_end|>" 26 | # as the default behavior is changed to allow special tokens in 27 | # regular texts, the surface forms of special tokens need to be 28 | # as different as possible to minimize the impact 29 | EXTRAS = tuple((f"<|extra_{i}|>" for i in range(205))) 30 | SPECIAL_TOKENS = ( 31 | ENDOFTEXT, 32 | IMSTART, 33 | IMEND, 34 | ) + EXTRAS 35 | 36 | 37 | def _load_tiktoken_bpe(tiktoken_bpe_file: str) -> Dict[bytes, int]: 38 | with open(tiktoken_bpe_file, "rb") as f: 39 | contents = f.read() 40 | return { 41 | base64.b64decode(token): int(rank) 42 | for token, rank in (line.split() for line in contents.splitlines() if line) 43 | } 44 | 45 | class QWenTokenizer(PreTrainedTokenizer): 46 | """QWen tokenizer.""" 47 | 48 | vocab_files_names = VOCAB_FILES_NAMES 49 | 50 | def __init__( 51 | self, 52 | vocab_file, 53 | errors="replace", 54 | **kwargs, 55 | ): 56 | super().__init__(**kwargs) 57 | 58 | self.errors = errors # how to handle errors in decoding 59 | 60 | self.mergeable_ranks = _load_tiktoken_bpe(vocab_file) # type: dict[bytes, int] 61 | self.special_tokens = { 62 | token: index 63 | for index, token in enumerate( 64 | SPECIAL_TOKENS, start=len(self.mergeable_ranks) 65 | ) 66 | } 67 | 68 | enc = tiktoken.Encoding( 69 | "Qwen", 70 | pat_str=PAT_STR, 71 | mergeable_ranks=self.mergeable_ranks, 72 | special_tokens=self.special_tokens, 73 | ) 74 | assert ( 75 | 
len(self.mergeable_ranks) + len(self.special_tokens) == enc.n_vocab 76 | ), f"{len(self.mergeable_ranks) + len(self.special_tokens)} != {enc.n_vocab} in encoding" 77 | 78 | self.decoder = { 79 | v: k for k, v in self.mergeable_ranks.items() 80 | } # type: dict[int, bytes|str] 81 | self.decoder.update({v: k for k, v in self.special_tokens.items()}) 82 | 83 | self.tokenizer = enc # type: tiktoken.Encoding 84 | 85 | self.eod_id = self.tokenizer.eot_token 86 | self.im_start_id = self.special_tokens[IMSTART] 87 | self.im_end_id = self.special_tokens[IMEND] 88 | 89 | def __len__(self) -> int: 90 | return self.tokenizer.n_vocab 91 | 92 | def get_vocab(self) -> Dict[bytes, int]: 93 | return self.mergeable_ranks 94 | 95 | def convert_tokens_to_ids( 96 | self, tokens: Union[bytes, str, List[Union[bytes, str]]] 97 | ) -> List[int]: 98 | ids = [] 99 | if isinstance(tokens, (str, bytes)): 100 | if tokens in self.special_tokens: 101 | return self.special_tokens[tokens] 102 | else: 103 | return self.mergeable_ranks.get(tokens) 104 | for token in tokens: 105 | if token in self.special_tokens: 106 | ids.append(self.special_tokens[token]) 107 | else: 108 | ids.append(self.mergeable_ranks.get(token)) 109 | return ids 110 | 111 | def _add_tokens(self, new_tokens: Union[List[str], List[AddedToken]], special_tokens: bool = False) -> int: 112 | if not special_tokens and new_tokens: 113 | raise ValueError('Adding regular tokens is not supported') 114 | for token in new_tokens: 115 | surface_form = token.content if isinstance(token, AddedToken) else token 116 | if surface_form not in SPECIAL_TOKENS: 117 | raise ValueError('Adding unknown special tokens is not supported') 118 | return 0 119 | 120 | def save_vocabulary(self, save_directory: str, **kwargs) -> Tuple[str]: 121 | """ 122 | Save only the vocabulary of the tokenizer (vocabulary). 123 | 124 | Returns: 125 | `Tuple(str)`: Paths to the files saved. 
126 | """ 127 | file_path = os.path.join(save_directory, "qwen.tiktoken") 128 | with open(file_path, "w", encoding="utf8") as w: 129 | for k, v in self.mergeable_ranks.items(): 130 | line = base64.b64encode(k).decode("utf8") + " " + str(v) + "\n" 131 | w.write(line) 132 | return (file_path,) 133 | 134 | def tokenize( 135 | self, 136 | text: str, 137 | allowed_special: Union[Set, str] = "all", 138 | disallowed_special: Union[Collection, str] = (), 139 | **kwargs, 140 | ) -> List[Union[bytes, str]]: 141 | """ 142 | Converts a string into a sequence of tokens. 143 | 144 | Args: 145 | text (`str`): 146 | The sequence to be encoded. 147 | allowed_special (`Literal["all"]` or `set`): 148 | The surface forms of the tokens to be encoded as special tokens in regular texts. 149 | Defaults to "all". 150 | disallowed_special (`Literal["all"]` or `Collection`): 151 | The surface forms of the tokens that should not be in regular texts and trigger errors. 152 | Defaults to an empty tuple. 153 | 154 | kwargs (additional keyword arguments, *optional*): 155 | Accepted for API compatibility; this implementation ignores them. 156 | 157 | Returns: 158 | `List[bytes|str]`: The list of tokens. 159 | """ 160 | tokens = [] 161 | text = unicodedata.normalize("NFC", text) 162 | 163 | # this implementation takes a detour: text -> token id -> token surface forms 164 | for t in self.tokenizer.encode( 165 | text, allowed_special=allowed_special, disallowed_special=disallowed_special 166 | ): 167 | tokens.append(self.decoder[t]) 168 | return tokens 169 | 170 | def convert_tokens_to_string(self, tokens: List[Union[bytes, str]]) -> str: 171 | """ 172 | Converts a sequence of tokens into a single string.
173 | """ 174 | text = "" 175 | temp = b"" 176 | for t in tokens: 177 | if isinstance(t, str): 178 | if temp: 179 | text += temp.decode("utf-8", errors=self.errors) 180 | temp = b"" 181 | text += t 182 | elif isinstance(t, bytes): 183 | temp += t 184 | else: 185 | raise TypeError("token should only be of type bytes or str") 186 | if temp: 187 | text += temp.decode("utf-8", errors=self.errors) 188 | return text 189 | 190 | @property 191 | def vocab_size(self): 192 | return self.tokenizer.n_vocab 193 | 194 | def _convert_id_to_token(self, index: int) -> Union[bytes, str]: 195 | """Converts an id to a token, special tokens included""" 196 | if index in self.decoder: 197 | return self.decoder[index] 198 | raise ValueError("unknown ids") 199 | 200 | def _convert_token_to_id(self, token: Union[bytes, str]) -> int: 201 | """Converts a token to an id using the vocab, special tokens included""" 202 | if token in self.special_tokens: 203 | return self.special_tokens[token] 204 | if token in self.mergeable_ranks: 205 | return self.mergeable_ranks[token] 206 | raise ValueError("unknown token") 207 | 208 | def _tokenize(self, text: str, **kwargs): 209 | """ 210 | Converts a string into a sequence of tokens (strings), using the tokenizer. Splits into words for word-based 211 | vocabularies or sub-words for sub-word-based vocabularies (BPE/SentencePieces/WordPieces). 212 | 213 | Does NOT handle added tokens.
214 | """ 215 | raise NotImplementedError 216 | 217 | def _decode( 218 | self, 219 | token_ids: Union[int, List[int]], 220 | skip_special_tokens: bool = False, 221 | errors: str = None, 222 | **kwargs, 223 | ) -> str: 224 | if isinstance(token_ids, int): 225 | token_ids = [token_ids] 226 | if skip_special_tokens: 227 | token_ids = [i for i in token_ids if i < self.eod_id] 228 | return self.tokenizer.decode(token_ids, errors=errors or self.errors) 229 | -------------------------------------------------------------------------------- /Vary-master/vary/model/plug/blip_process.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (c) 2022, salesforce.com, inc. 3 | All rights reserved. 4 | SPDX-License-Identifier: BSD-3-Clause 5 | For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause 6 | """ 7 | 8 | import cv2 9 | import numpy as np 10 | 11 | import torch 12 | 13 | # from omegaconf import OmegaConf 14 | from torchvision import transforms 15 | from torchvision.transforms.functional import InterpolationMode 16 | from PIL import Image 17 | 18 | class BaseProcessor: 19 | def __init__(self): 20 | self.transform = lambda x: x 21 | return 22 | 23 | def __call__(self, item): 24 | return self.transform(item) 25 | 26 | # @classmethod 27 | # def from_config(cls, cfg=None): 28 | # return cls() 29 | 30 | # def build(self, **kwargs): 31 | # cfg = OmegaConf.create(kwargs) 32 | 33 | # return self.from_config(cfg) 34 | 35 | class BlipImageBaseProcessor(BaseProcessor): 36 | def __init__(self, mean=None, std=None): 37 | if mean is None: 38 | mean = (0.48145466, 0.4578275, 0.40821073) 39 | if std is None: 40 | std = (0.26862954, 0.26130258, 0.27577711) 41 | 42 | self.normalize = transforms.Normalize(mean, std) 43 | 44 | 45 | ## aug functions 46 | def identity_func(img): 47 | return img 48 | 49 | 50 | def autocontrast_func(img, cutoff=0): 51 | """ 52 | same output as 
PIL.ImageOps.autocontrast 53 | """ 54 | n_bins = 256 55 | 56 | def tune_channel(ch): 57 | n = ch.size 58 | cut = cutoff * n // 100 59 | if cut == 0: 60 | high, low = ch.max(), ch.min() 61 | else: 62 | hist = cv2.calcHist([ch], [0], None, [n_bins], [0, n_bins]) 63 | low = np.argwhere(np.cumsum(hist) > cut) 64 | low = 0 if low.shape[0] == 0 else low[0] 65 | high = np.argwhere(np.cumsum(hist[::-1]) > cut) 66 | high = n_bins - 1 if high.shape[0] == 0 else n_bins - 1 - high[0] 67 | if high <= low: 68 | table = np.arange(n_bins) 69 | else: 70 | scale = (n_bins - 1) / (high - low) 71 | offset = -low * scale 72 | table = np.arange(n_bins) * scale + offset 73 | table[table < 0] = 0 74 | table[table > n_bins - 1] = n_bins - 1 75 | table = table.clip(0, 255).astype(np.uint8) 76 | return table[ch] 77 | 78 | channels = [tune_channel(ch) for ch in cv2.split(img)] 79 | out = cv2.merge(channels) 80 | return out 81 | 82 | 83 | def equalize_func(img): 84 | """ 85 | same output as PIL.ImageOps.equalize 86 | PIL's implementation is different from cv2.equalize 87 | """ 88 | n_bins = 256 89 | 90 | def tune_channel(ch): 91 | hist = cv2.calcHist([ch], [0], None, [n_bins], [0, n_bins]) 92 | non_zero_hist = hist[hist != 0].reshape(-1) 93 | step = np.sum(non_zero_hist[:-1]) // (n_bins - 1) 94 | if step == 0: 95 | return ch 96 | n = np.empty_like(hist) 97 | n[0] = step // 2 98 | n[1:] = hist[:-1] 99 | table = (np.cumsum(n) // step).clip(0, 255).astype(np.uint8) 100 | return table[ch] 101 | 102 | channels = [tune_channel(ch) for ch in cv2.split(img)] 103 | out = cv2.merge(channels) 104 | return out 105 | 106 | 107 | def rotate_func(img, degree, fill=(0, 0, 0)): 108 | """ 109 | like PIL, rotate by degree, not radians 110 | """ 111 | H, W = img.shape[0], img.shape[1] 112 | center = W / 2, H / 2 113 | M = cv2.getRotationMatrix2D(center, degree, 1) 114 | out = cv2.warpAffine(img, M, (W, H), borderValue=fill) 115 | return out 116 | 117 | 118 | def solarize_func(img, thresh=128): 119 | """ 120 | 
same output as PIL.ImageOps.solarize 121 | """ 122 | table = np.array([el if el < thresh else 255 - el for el in range(256)]) 123 | table = table.clip(0, 255).astype(np.uint8) 124 | out = table[img] 125 | return out 126 | 127 | 128 | def color_func(img, factor): 129 | """ 130 | same output as PIL.ImageEnhance.Color 131 | """ 132 | ## implementation according to PIL definition, quite slow 133 | # degenerate = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)[:, :, np.newaxis] 134 | # out = blend(degenerate, img, factor) 135 | # M = ( 136 | # np.eye(3) * factor 137 | # + np.float32([0.114, 0.587, 0.299]).reshape(3, 1) * (1. - factor) 138 | # )[np.newaxis, np.newaxis, :] 139 | M = np.float32( 140 | [[0.886, -0.114, -0.114], [-0.587, 0.413, -0.587], [-0.299, -0.299, 0.701]] 141 | ) * factor + np.float32([[0.114], [0.587], [0.299]]) 142 | out = np.matmul(img, M).clip(0, 255).astype(np.uint8) 143 | return out 144 | 145 | 146 | def contrast_func(img, factor): 147 | """ 148 | same output as PIL.ImageEnhance.Contrast 149 | """ 150 | mean = np.sum(np.mean(img, axis=(0, 1)) * np.array([0.114, 0.587, 0.299])) 151 | table = ( 152 | np.array([(el - mean) * factor + mean for el in range(256)]) 153 | .clip(0, 255) 154 | .astype(np.uint8) 155 | ) 156 | out = table[img] 157 | return out 158 | 159 | 160 | def brightness_func(img, factor): 161 | """ 162 | same output as PIL.ImageEnhance.Brightness 163 | """ 164 | table = (np.arange(256, dtype=np.float32) * factor).clip(0, 255).astype(np.uint8) 165 | out = table[img] 166 | return out 167 | 168 | 169 | def sharpness_func(img, factor): 170 | """ 171 | The differences between this result and PIL's are all on the 4 boundaries; the center 172 | areas are the same 173 | """ 174 | kernel = np.ones((3, 3), dtype=np.float32) 175 | kernel[1][1] = 5 176 | kernel /= 13 177 | degenerate = cv2.filter2D(img, -1, kernel) 178 | if factor == 0.0: 179 | out = degenerate 180 | elif factor == 1.0: 181 | out = img 182 | else: 183 | out = img.astype(np.float32) 184 | degenerate =
degenerate.astype(np.float32)[1:-1, 1:-1, :] 185 | out[1:-1, 1:-1, :] = degenerate + factor * (out[1:-1, 1:-1, :] - degenerate) 186 | out = out.astype(np.uint8) 187 | return out 188 | 189 | 190 | def shear_x_func(img, factor, fill=(0, 0, 0)): 191 | H, W = img.shape[0], img.shape[1] 192 | M = np.float32([[1, factor, 0], [0, 1, 0]]) 193 | out = cv2.warpAffine( 194 | img, M, (W, H), borderValue=fill, flags=cv2.INTER_LINEAR 195 | ).astype(np.uint8) 196 | return out 197 | 198 | 199 | def translate_x_func(img, offset, fill=(0, 0, 0)): 200 | """ 201 | same output as PIL.Image.transform 202 | """ 203 | H, W = img.shape[0], img.shape[1] 204 | M = np.float32([[1, 0, -offset], [0, 1, 0]]) 205 | out = cv2.warpAffine( 206 | img, M, (W, H), borderValue=fill, flags=cv2.INTER_LINEAR 207 | ).astype(np.uint8) 208 | return out 209 | 210 | 211 | def translate_y_func(img, offset, fill=(0, 0, 0)): 212 | """ 213 | same output as PIL.Image.transform 214 | """ 215 | H, W = img.shape[0], img.shape[1] 216 | M = np.float32([[1, 0, 0], [0, 1, -offset]]) 217 | out = cv2.warpAffine( 218 | img, M, (W, H), borderValue=fill, flags=cv2.INTER_LINEAR 219 | ).astype(np.uint8) 220 | return out 221 | 222 | 223 | def posterize_func(img, bits): 224 | """ 225 | same output as PIL.ImageOps.posterize 226 | """ 227 | out = np.bitwise_and(img, np.uint8(255 << (8 - bits))) 228 | return out 229 | 230 | 231 | def shear_y_func(img, factor, fill=(0, 0, 0)): 232 | H, W = img.shape[0], img.shape[1] 233 | M = np.float32([[1, 0, 0], [factor, 1, 0]]) 234 | out = cv2.warpAffine( 235 | img, M, (W, H), borderValue=fill, flags=cv2.INTER_LINEAR 236 | ).astype(np.uint8) 237 | return out 238 | 239 | 240 | def cutout_func(img, pad_size, replace=(0, 0, 0)): 241 | replace = np.array(replace, dtype=np.uint8) 242 | H, W = img.shape[0], img.shape[1] 243 | rh, rw = np.random.random(2) 244 | pad_size = pad_size // 2 245 | ch, cw = int(rh * H), int(rw * W) 246 | x1, x2 = max(ch - pad_size, 0), min(ch + pad_size, H) 247 | y1, y2 = max(cw 
- pad_size, 0), min(cw + pad_size, W) 248 | out = img.copy() 249 | out[x1:x2, y1:y2, :] = replace 250 | return out 251 | 252 | 253 | ### level to args 254 | def enhance_level_to_args(MAX_LEVEL): 255 | def level_to_args(level): 256 | return ((level / MAX_LEVEL) * 1.8 + 0.1,) 257 | 258 | return level_to_args 259 | 260 | 261 | def shear_level_to_args(MAX_LEVEL, replace_value): 262 | def level_to_args(level): 263 | level = (level / MAX_LEVEL) * 0.3 264 | if np.random.random() > 0.5: 265 | level = -level 266 | return (level, replace_value) 267 | 268 | return level_to_args 269 | 270 | 271 | def translate_level_to_args(translate_const, MAX_LEVEL, replace_value): 272 | def level_to_args(level): 273 | level = (level / MAX_LEVEL) * float(translate_const) 274 | if np.random.random() > 0.5: 275 | level = -level 276 | return (level, replace_value) 277 | 278 | return level_to_args 279 | 280 | 281 | def cutout_level_to_args(cutout_const, MAX_LEVEL, replace_value): 282 | def level_to_args(level): 283 | level = int((level / MAX_LEVEL) * cutout_const) 284 | return (level, replace_value) 285 | 286 | return level_to_args 287 | 288 | 289 | def solarize_level_to_args(MAX_LEVEL): 290 | def level_to_args(level): 291 | level = int((level / MAX_LEVEL) * 256) 292 | return (level,) 293 | 294 | return level_to_args 295 | 296 | 297 | def none_level_to_args(level): 298 | return () 299 | 300 | 301 | def posterize_level_to_args(MAX_LEVEL): 302 | def level_to_args(level): 303 | level = int((level / MAX_LEVEL) * 4) 304 | return (level,) 305 | 306 | return level_to_args 307 | 308 | 309 | def rotate_level_to_args(MAX_LEVEL, replace_value): 310 | def level_to_args(level): 311 | level = (level / MAX_LEVEL) * 30 312 | if np.random.random() < 0.5: 313 | level = -level 314 | return (level, replace_value) 315 | 316 | return level_to_args 317 | 318 | 319 | func_dict = { 320 | "Identity": identity_func, 321 | "AutoContrast": autocontrast_func, 322 | "Equalize": equalize_func, 323 | "Rotate": rotate_func, 324 
| "Solarize": solarize_func, 325 | "Color": color_func, 326 | "Contrast": contrast_func, 327 | "Brightness": brightness_func, 328 | "Sharpness": sharpness_func, 329 | "ShearX": shear_x_func, 330 | "TranslateX": translate_x_func, 331 | "TranslateY": translate_y_func, 332 | "Posterize": posterize_func, 333 | "ShearY": shear_y_func, 334 | } 335 | 336 | translate_const = 10 337 | MAX_LEVEL = 10 338 | replace_value = (128, 128, 128) 339 | arg_dict = { 340 | "Identity": none_level_to_args, 341 | "AutoContrast": none_level_to_args, 342 | "Equalize": none_level_to_args, 343 | "Rotate": rotate_level_to_args(MAX_LEVEL, replace_value), 344 | "Solarize": solarize_level_to_args(MAX_LEVEL), 345 | "Color": enhance_level_to_args(MAX_LEVEL), 346 | "Contrast": enhance_level_to_args(MAX_LEVEL), 347 | "Brightness": enhance_level_to_args(MAX_LEVEL), 348 | "Sharpness": enhance_level_to_args(MAX_LEVEL), 349 | "ShearX": shear_level_to_args(MAX_LEVEL, replace_value), 350 | "TranslateX": translate_level_to_args(translate_const, MAX_LEVEL, replace_value), 351 | "TranslateY": translate_level_to_args(translate_const, MAX_LEVEL, replace_value), 352 | "Posterize": posterize_level_to_args(MAX_LEVEL), 353 | "ShearY": shear_level_to_args(MAX_LEVEL, replace_value), 354 | } 355 | 356 | 357 | class RandomAugment(object): 358 | def __init__(self, N=2, M=10, isPIL=False, augs=[]): 359 | self.N = N 360 | self.M = M 361 | self.isPIL = isPIL 362 | if augs: 363 | self.augs = augs 364 | else: 365 | self.augs = list(arg_dict.keys()) 366 | 367 | def get_random_ops(self): 368 | sampled_ops = np.random.choice(self.augs, self.N) 369 | return [(op, 0.5, self.M) for op in sampled_ops] 370 | 371 | def __call__(self, img): 372 | if self.isPIL: 373 | img = np.array(img) 374 | ops = self.get_random_ops() 375 | for name, prob, level in ops: 376 | if np.random.random() > prob: 377 | continue 378 | args = arg_dict[name](level) 379 | img = func_dict[name](img, *args) 380 | return img 381 | 382 | 383 | class 
VideoRandomAugment(object): 384 | def __init__(self, N=2, M=10, p=0.0, tensor_in_tensor_out=True, augs=[]): 385 | self.N = N 386 | self.M = M 387 | self.p = p 388 | self.tensor_in_tensor_out = tensor_in_tensor_out 389 | if augs: 390 | self.augs = augs 391 | else: 392 | self.augs = list(arg_dict.keys()) 393 | 394 | def get_random_ops(self): 395 | sampled_ops = np.random.choice(self.augs, self.N, replace=False) 396 | return [(op, self.M) for op in sampled_ops] 397 | 398 | def __call__(self, frames): 399 | assert ( 400 | frames.shape[-1] == 3 401 | ), "Expecting last dimension for 3-channels RGB (b, h, w, c)." 402 | 403 | if self.tensor_in_tensor_out: 404 | frames = frames.numpy().astype(np.uint8) 405 | 406 | num_frames = frames.shape[0] 407 | 408 | ops = num_frames * [self.get_random_ops()] 409 | apply_or_not = num_frames * [np.random.random(size=self.N) > self.p] 410 | 411 | frames = torch.stack( 412 | list(map(self._aug, frames, ops, apply_or_not)), dim=0 413 | ).float() 414 | 415 | return frames 416 | 417 | def _aug(self, img, ops, apply_or_not): 418 | for i, (name, level) in enumerate(ops): 419 | if not apply_or_not[i]: 420 | continue 421 | args = arg_dict[name](level) 422 | img = func_dict[name](img, *args) 423 | return torch.from_numpy(img) 424 | 425 | 426 | # if __name__ == "__main__": 427 | # a = RandomAugment() 428 | # img = np.random.randn(32, 32, 3) 429 | # a(img) 430 | 431 | 432 | 433 | 434 | 435 | 436 | class BlipImageTrainProcessor(BlipImageBaseProcessor): 437 | def __init__( 438 | self, image_size=384, mean=None, std=None, min_scale=0.5, max_scale=1.0 439 | ): 440 | super().__init__(mean=mean, std=std) 441 | 442 | self.transform = transforms.Compose( 443 | [ 444 | transforms.RandomResizedCrop( 445 | image_size, 446 | scale=(min_scale, max_scale), 447 | interpolation=InterpolationMode.BICUBIC, 448 | ), 449 | # transforms.RandomHorizontalFlip(), 450 | RandomAugment( 451 | 2, 452 | 5, 453 | isPIL=True, 454 | augs=[ 455 | "Identity", 456 | # 
"AutoContrast", 457 | "Brightness", 458 | "Sharpness", 459 | "Equalize", 460 | # "ShearX", 461 | # "ShearY", 462 | # "TranslateX", 463 | # "TranslateY", 464 | # "Rotate", 465 | ], 466 | ), 467 | transforms.ToTensor(), 468 | self.normalize, 469 | ] 470 | ) 471 | 472 | def __call__(self, item): 473 | return self.transform(item) 474 | 475 | 476 | class BlipImageEvalProcessor(BlipImageBaseProcessor): 477 | def __init__(self, image_size=384, mean=None, std=None): 478 | super().__init__(mean=mean, std=std) 479 | 480 | self.transform = transforms.Compose( 481 | [ 482 | transforms.Resize( 483 | (image_size, image_size), interpolation=InterpolationMode.BICUBIC 484 | ), 485 | transforms.ToTensor(), 486 | self.normalize, 487 | ] 488 | ) 489 | 490 | def __call__(self, item): 491 | return self.transform(item) 492 | 493 | 494 | # if __name__ == "__main__": 495 | # a = BlipImageTrainProcessor(image_size=1024) 496 | # # img = np.random.randn(1024, 1024, 3) 497 | # # x = torch.zeros(1024, 1024, 3) 498 | # x = Image.open("/data/codes/vary-main/log/serve_images/2023-05-23/a2a783d89ede819cdeae943a2199ad3d.jpg").convert("RGB") 499 | # print(x.size) 500 | # y = a(x) 501 | 502 | # print(y.size()) 503 | -------------------------------------------------------------------------------- /Vary-master/vary/model/plug/transforms.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (c) Meta Platforms, Inc. and affiliates. 3 | 4 | This source code is licensed under the MIT license found in the 5 | LICENSE file in the root directory of this source tree. 
6 | """ 7 | # Implements image augmentation 8 | 9 | import albumentations as alb 10 | from albumentations.pytorch import ToTensorV2 11 | import cv2 12 | import numpy as np 13 | from timm.data.constants import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD 14 | 15 | 16 | def alb_wrapper(transform): 17 | def f(im): 18 | return transform(image=np.asarray(im))["image"] 19 | 20 | return f 21 | 22 | 23 | class Erosion(alb.ImageOnlyTransform): 24 | """ 25 | Apply erosion operation to an image. 26 | 27 | Erosion is a morphological operation that shrinks the white regions in a binary image. 28 | 29 | Args: 30 | scale (int or tuple/list of int): The scale or range for the size of the erosion kernel. 31 | If an integer is provided, a square kernel of that size will be used. 32 | If a tuple or list is provided, it should contain two integers representing the minimum 33 | (inclusive) and maximum (exclusive) sizes for the erosion kernel. 34 | always_apply (bool, optional): Whether to always apply this transformation. Default is False. 35 | p (float, optional): The probability of applying this transformation. Default is 0.5. 36 | 37 | Returns: 38 | numpy.ndarray: The transformed image. 39 | """ 40 | 41 | def __init__(self, scale, always_apply=False, p=0.5): 42 | super().__init__(always_apply=always_apply, p=p) 43 | if isinstance(scale, (tuple, list)): 44 | assert len(scale) == 2 45 | self.scale = scale 46 | else: 47 | self.scale = (scale, scale + 1) # np.random.randint excludes the upper bound 48 | 49 | def apply(self, img, **params): 50 | kernel = cv2.getStructuringElement( 51 | cv2.MORPH_ELLIPSE, tuple(np.random.randint(self.scale[0], self.scale[1], 2)) 52 | ) 53 | img = cv2.erode(img, kernel, iterations=1) 54 | return img 55 | 56 | 57 | class Dilation(alb.ImageOnlyTransform): 58 | """ 59 | Apply dilation operation to an image. 60 | 61 | Dilation is a morphological operation that expands the white regions in a binary image. 62 | 63 | Args: 64 | scale (int or tuple/list of int): The scale or range for the size of the dilation kernel.
65 | If an integer is provided, a square kernel of that size will be used. 66 | If a tuple or list is provided, it should contain two integers representing the minimum 67 | (inclusive) and maximum (exclusive) sizes for the dilation kernel. 68 | always_apply (bool, optional): Whether to always apply this transformation. Default is False. 69 | p (float, optional): The probability of applying this transformation. Default is 0.5. 70 | 71 | Returns: 72 | numpy.ndarray: The transformed image. 73 | """ 74 | 75 | def __init__(self, scale, always_apply=False, p=0.5): 76 | super().__init__(always_apply=always_apply, p=p) 77 | if isinstance(scale, (tuple, list)): 78 | assert len(scale) == 2 79 | self.scale = scale 80 | else: 81 | self.scale = (scale, scale + 1) # np.random.randint excludes the upper bound 82 | 83 | def apply(self, img, **params): 84 | kernel = cv2.getStructuringElement( 85 | cv2.MORPH_ELLIPSE, tuple(np.random.randint(self.scale[0], self.scale[1], 2)) 86 | ) 87 | img = cv2.dilate(img, kernel, iterations=1) 88 | return img 89 | 90 | 91 | class Bitmap(alb.ImageOnlyTransform): 92 | """ 93 | Apply a bitmap-style transformation to an image. 94 | 95 | This transformation replaces all pixel values below a certain threshold with a specified value. 96 | 97 | Args: 98 | value (int, optional): The value to replace pixels below the threshold with. Default is 0. 99 | lower (int, optional): The threshold value below which pixels will be replaced. Default is 200. 100 | always_apply (bool, optional): Whether to always apply this transformation. Default is False. 101 | p (float, optional): The probability of applying this transformation. Default is 0.5. 102 | 103 | Returns: 104 | numpy.ndarray: The transformed image.
105 | """ 106 | 107 | def __init__(self, value=0, lower=200, always_apply=False, p=0.5): 108 | super().__init__(always_apply=always_apply, p=p) 109 | self.lower = lower 110 | self.value = value 111 | 112 | def apply(self, img, **params): 113 | img = img.copy() 114 | img[img < self.lower] = self.value 115 | return img 116 | 117 | 118 | train_transform = alb_wrapper( 119 | alb.Compose( 120 | [ 121 | Bitmap(p=0), 122 | alb.OneOf([Erosion((2, 3)), Dilation((2, 3))], p=0.02), 123 | alb.Affine(shear={"x": (0, 3), "y": (-3, 0)}, cval=(255, 255, 255), p=0.03), 124 | alb.ShiftScaleRotate( 125 | shift_limit_x=(0, 0.04), 126 | shift_limit_y=(0, 0.03), 127 | scale_limit=(-0.15, 0.03), 128 | rotate_limit=2, 129 | border_mode=0, 130 | interpolation=2, 131 | value=(255, 255, 255), 132 | p=0.03, 133 | ), 134 | alb.GridDistortion( 135 | distort_limit=0.05, 136 | border_mode=0, 137 | interpolation=2, 138 | value=(255, 255, 255), 139 | p=0.04, 140 | ), 141 | alb.Compose( 142 | [ 143 | alb.Affine( 144 | translate_px=(0, 5), always_apply=True, cval=(255, 255, 255) 145 | ), 146 | alb.ElasticTransform( 147 | p=1, 148 | alpha=50, 149 | sigma=120 * 0.1, 150 | alpha_affine=120 * 0.01, 151 | border_mode=0, 152 | value=(255, 255, 255), 153 | ), 154 | ], 155 | p=0.04, 156 | ), 157 | alb.RandomBrightnessContrast(0.1, 0.1, True, p=0.03), 158 | alb.ImageCompression(95, p=0.07), 159 | alb.GaussNoise(20, p=0.08), 160 | alb.GaussianBlur((3, 3), p=0.03), 161 | alb.Resize(1024, 1024), 162 | alb.Normalize(IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD), 163 | ToTensorV2(), 164 | ] 165 | ) 166 | ) 167 | test_transform = alb_wrapper( 168 | alb.Compose( 169 | [ 170 | alb.Resize(1024, 1024), 171 | alb.Normalize(IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD), 172 | ToTensorV2(), 173 | ] 174 | ) 175 | ) 176 | 177 | 178 | 179 | 180 | # if __name__ == '__main__': 181 | # from PIL import Image 182 | # image = Image.open('/data/hypertext/ucaswei/codes_new/show/49.jpg').convert('RGB') 183 | # # image = np.array(image) 
184 | # for i in range(100): 185 | # image1 = train_transform(image) 186 | # image1 = Image.fromarray(np.uint8(image1)) 187 | # image1.save('/data/hypertext/ucaswei/codes_new/aug/' + str(i) + '.jpg') 188 | # mm_projector_1 = nn.Linear(1024, 256) 189 | 190 | # x = torch.zeros(2, 3, 1024, 1024) 191 | 192 | # with torch.no_grad(): 193 | # y = model(x) 194 | # print(y.shape) 195 | # y = mm_projector_1(y.permute(0,2,1)) 196 | # print(y.shape) 197 | # print(y.permute(0,2,1).shape) -------------------------------------------------------------------------------- /Vary-master/vary/model/vary_opt.py: -------------------------------------------------------------------------------- 1 | # Copyright 2023 Haotian Liu 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
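Both `train_transform` and `test_transform` in `transforms.py` above finish with `alb.Normalize(IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD)` followed by `ToTensorV2()`. As a worked example of what that normalization does to a single RGB pixel, here is a plain-Python sketch. It assumes albumentations' default `max_pixel_value=255.0`, and the mean/std tuples are the ImageNet values exported by `timm.data.constants`; `normalize_pixel` is a hypothetical helper, not part of the repository.

```python
from typing import Sequence, Tuple

# ImageNet statistics (RGB order), same values as timm.data.constants.
IMAGENET_DEFAULT_MEAN = (0.485, 0.456, 0.406)
IMAGENET_DEFAULT_STD = (0.229, 0.224, 0.225)


def normalize_pixel(
    rgb: Sequence[float],
    mean: Sequence[float] = IMAGENET_DEFAULT_MEAN,
    std: Sequence[float] = IMAGENET_DEFAULT_STD,
    max_pixel_value: float = 255.0,
) -> Tuple[float, ...]:
    """Per channel: (x - mean * max_pixel_value) / (std * max_pixel_value)."""
    return tuple(
        (x - m * max_pixel_value) / (s * max_pixel_value)
        for x, m, s in zip(rgb, mean, std)
    )


# A white page pixel and a black text pixel (uint8 values 255 and 0).
white = normalize_pixel((255, 255, 255))
black = normalize_pixel((0, 0, 0))
print([round(v, 4) for v in white])  # [2.2489, 2.4286, 2.64]
print([round(v, 4) for v in black])  # [-2.1179, -2.0357, -1.8044]
```

`ToTensorV2` then only transposes the normalized HWC array to CHW; unlike torchvision's `ToTensor`, it does not rescale by 255, which is why the scaling has to happen inside `Normalize`.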
14 | 15 | 16 | from typing import List, Optional, Tuple, Union 17 | 18 | import torch 19 | import torch.nn as nn 20 | import torch.nn.functional as F 21 | from torch.nn import CrossEntropyLoss 22 | 23 | from transformers import AutoConfig, AutoModelForCausalLM, \ 24 | LlamaConfig, LlamaModel, LlamaForCausalLM, \ 25 | CLIPVisionModel, CLIPImageProcessor 26 | from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast 27 | 28 | from vary.utils.constants import * 29 | 30 | from vary.model.plug.blip_process import BlipImageEvalProcessor 31 | 32 | from vary.model.vision_encoder.sam import build_sam_vit_b 33 | 34 | 35 | from transformers import OPTConfig, OPTModel, OPTForCausalLM 36 | 37 | from vary.model.plug.transforms import train_transform, test_transform 38 | 39 | 40 | 41 | class varyConfig(OPTConfig): 42 | model_type = "vary" 43 | 44 | 45 | class varyOPTModel(OPTModel): 46 | config_class = varyConfig 47 | 48 | def __init__(self, config: OPTConfig): 49 | super(varyOPTModel, self).__init__(config) 50 | 51 | 52 | self.vision_tower = build_sam_vit_b() 53 | 54 | self.mm_projector = nn.Linear(1024, 768) 55 | 56 | def initialize_vision_modules( 57 | self, 58 | vision_tower, 59 | pretrained_stage1_model=None, 60 | freeze_vision_tower=False, 61 | use_im_start_end=False, 62 | vision_select_layer=-1, 63 | dtype=torch.float16, 64 | device="cuda" 65 | ): 66 | 67 | # 224*224 68 | # image_processor is not used by the OPT model 69 | image_processor = CLIPImageProcessor.from_pretrained('/cache/vit-large-patch14') 70 | # 1024*1024 71 | 72 | image_processor_high = train_transform 73 | 74 | image_token_len = 256 75 | 76 | self.config.vision_tower = vision_tower 77 | self.config.image_token_len = image_token_len 78 | self.config.use_im_start_end = True 79 | 80 | self.config.vision_select_layer = vision_select_layer 81 | self.config.freeze_vision_tower = freeze_vision_tower 82 | 83 | return dict( 84 | image_processor=image_processor, 85 | image_processor_high=image_processor_high,
image_processor_high=image_processor_high, 86 | image_token_len=image_token_len, 87 | ) 88 | 89 | def embed_tokens(self, x): 90 | return self.get_input_embeddings()(x) 91 | 92 | def forward( 93 | self, 94 | input_ids: torch.LongTensor = None, 95 | attention_mask: Optional[torch.Tensor] = None, 96 | past_key_values: Optional[List[torch.FloatTensor]] = None, 97 | inputs_embeds: Optional[torch.FloatTensor] = None, 98 | use_cache: Optional[bool] = None, 99 | output_attentions: Optional[bool] = None, 100 | output_hidden_states: Optional[bool] = None, 101 | images: Optional[torch.FloatTensor] = None, 102 | return_dict: Optional[bool] = None, 103 | ) -> Union[Tuple, BaseModelOutputWithPast]: 104 | 105 | # HACK: replace back original embeddings for LLaVA pretraining 106 | # orig_embeds_params = getattr(self, 'orig_embeds_params', None) 107 | # if orig_embeds_params is not None: 108 | # with torch.no_grad(): 109 | # self.get_input_embeddings().weight[:-self.num_new_tokens] = orig_embeds_params[:-self.num_new_tokens].data 110 | 111 | if inputs_embeds is None: 112 | inputs_embeds = self.embed_tokens(input_ids) 113 | # inputs_embeds = self.wte(input_ids) 114 | 115 | 116 | vision_tower = getattr(self, 'vision_tower', None) 117 | 118 | 119 | if vision_tower is not None and (input_ids.shape[1] != 1 or self.training) and images is not None: 120 | 121 | use_im_start_end = getattr(self.config, "use_im_start_end", -1) 122 | 123 | vision_select_layer = getattr(self.config, "vision_select_layer", -1) 124 | im_patch_token = getattr(self.config, "im_patch_token", -1) 125 | im_start_token = getattr(self.config, "im_start_token", -1) 126 | im_end_token = getattr(self.config, "im_end_token", -1) 127 | freeze_vision_tower = getattr(self.config, "freeze_vision_tower", False) 128 | 129 | 130 | image_features = [] 131 | for image in images: 132 | 133 | with torch.set_grad_enabled(True): 134 | cnn_feature = vision_tower(image[1]) 135 | cnn_feature = cnn_feature.flatten(2).permute(0, 2, 1) 136 | 
image_feature_final = cnn_feature 137 | 138 | image_features.append(image_feature_final) 139 | 140 | if type(images) is list: 141 | image_features = [self.mm_projector(image_feature) for image_feature in image_features] 142 | else: 143 | # image_features = self.mm_projector(image_features) 144 | raise NotImplementedError 145 | 146 | # dummy_image_features = torch.zeros(1024, 1664, device=inputs_embeds.device, dtype=inputs_embeds.dtype).permute(0, 2, 1).reshape(dummy_image_features.shape[0], -1, 32, 32) 147 | # VIT 1024; CNN:1024 148 | dummy_image_features = torch.zeros(256, 1024, device=inputs_embeds.device, dtype=inputs_embeds.dtype) 149 | dummy_image_features = self.mm_projector(dummy_image_features) 150 | 151 | use_im_start_end = True 152 | new_input_embeds = [] 153 | for cur_input_ids, cur_input_embeds, cur_image_features in zip(input_ids, inputs_embeds, image_features): 154 | if (cur_input_ids == im_patch_token).sum() == 0: 155 | # multimodal LLM, but the current sample is not multimodal 156 | cur_input_embeds = cur_input_embeds + (0. 
* dummy_image_features).sum() 157 | new_input_embeds.append(cur_input_embeds) 158 | continue 159 | 160 | if use_im_start_end: 161 | if (cur_input_ids == im_start_token).sum() != (cur_input_ids == im_end_token).sum(): 162 | raise ValueError("The number of image start tokens and image end tokens should be the same.") 163 | 164 | image_start_tokens = torch.where(cur_input_ids == im_start_token)[0] 165 | for image_start_token_pos, per_cur_image_features in zip(image_start_tokens, cur_image_features): 166 | per_cur_image_features = per_cur_image_features.to(device=cur_input_embeds.device) 167 | num_patches = per_cur_image_features.shape[0] 168 | # print(cur_input_ids) 169 | if cur_input_ids[image_start_token_pos + num_patches + 1] != im_end_token: 170 | raise ValueError("The image end token should follow the image start token.") 171 | 172 | cur_input_embeds = torch.cat( 173 | ( 174 | cur_input_embeds[:image_start_token_pos+1], 175 | per_cur_image_features, 176 | cur_input_embeds[image_start_token_pos + num_patches + 1:] 177 | ), 178 | dim=0 179 | ) 180 | 181 | 182 | new_input_embeds.append(cur_input_embeds) 183 | else: 184 | raise NotImplementedError 185 | 186 | inputs_embeds = torch.stack(new_input_embeds, dim=0) 187 | 188 | 189 | return super(varyOPTModel, self).forward( 190 | input_ids=None, attention_mask=attention_mask, past_key_values=past_key_values, 191 | inputs_embeds=inputs_embeds, use_cache=use_cache, 192 | output_attentions=output_attentions, output_hidden_states=output_hidden_states, 193 | return_dict=return_dict 194 | ) 195 | 196 | 197 | class varyOPTForCausalLM(OPTForCausalLM): 198 | config_class = varyConfig 199 | # supports_gradient_checkpointing = True 200 | 201 | def __init__(self, config): 202 | super(OPTForCausalLM, self).__init__(config) 203 | self.model = varyOPTModel(config) 204 | 205 | self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) 206 | 207 | # Initialize weights and apply final processing 208 | self.post_init() 209 
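The per-sample splicing loop in `varyOPTModel.forward` above keeps every embedding up to and including each image start token, swaps the `num_patches` placeholder rows that follow it for the projected image features, and resumes at the image end token. Below is a minimal plain-Python sketch of that index arithmetic, with lists standing in for the embedding tensors. `splice_image_features` and the token ids in the example are hypothetical and are not part of the repository.

```python
from typing import List, Sequence, TypeVar

T = TypeVar("T")


def splice_image_features(
    input_ids: Sequence[int],
    embeds: List[T],
    image_features: List[List[T]],
    im_start: int,
    im_end: int,
) -> List[T]:
    """Replace the placeholder rows after each im_start token with image feature rows.

    Mirrors torch.cat((embeds[:start + 1], feats, embeds[start + n + 1:])) from the
    model code, using list concatenation instead of tensors (sketch only).
    """
    ids = list(input_ids)
    if ids.count(im_start) != ids.count(im_end):
        raise ValueError("The number of image start and end tokens should be the same.")

    out = list(embeds)
    starts = [pos for pos, tok in enumerate(ids) if tok == im_start]
    for start, feats in zip(starts, image_features):
        n = len(feats)
        end = start + n + 1  # index where the image end token must sit
        if end >= len(ids) or ids[end] != im_end:
            raise ValueError("The image end token should follow the image patch tokens.")
        # Length is unchanged: n placeholders are swapped for n feature rows,
        # so the positions of later start tokens stay valid.
        out = out[: start + 1] + list(feats) + out[end:]
    return out


# Hypothetical ids: 1 = image start, 2 = image end, 9 = patch placeholder.
ids = [10, 1, 9, 9, 2, 11]
embeds = ["e10", "e1", "p", "p", "e2", "e11"]
print(splice_image_features(ids, embeds, [["f0", "f1"]], im_start=1, im_end=2))
# ['e10', 'e1', 'f0', 'f1', 'e2', 'e11']
```

The model code additionally adds `(0. * dummy_image_features).sum()` to text-only samples. This presumably keeps `mm_projector` in the autograd graph when a batch item has no image. That detail is left out of this sketch.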
| 210 | def get_model(self): 211 | return self.model 212 | 213 | # def _set_gradient_checkpointing(self, module, value=False): 214 | # if isinstance(module, varyQwenModel): 215 | # module.gradient_checkpointing = value 216 | 217 | 218 | 219 | def forward( 220 | self, 221 | input_ids: Optional[torch.LongTensor] = None, 222 | past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None, 223 | attention_mask: Optional[torch.FloatTensor] = None, 224 | token_type_ids: Optional[torch.LongTensor] = None, 225 | position_ids: Optional[torch.LongTensor] = None, 226 | head_mask: Optional[torch.FloatTensor] = None, 227 | inputs_embeds: Optional[torch.FloatTensor] = None, 228 | encoder_hidden_states: Optional[torch.Tensor] = None, 229 | encoder_attention_mask: Optional[torch.FloatTensor] = None, 230 | labels: Optional[torch.LongTensor] = None, 231 | use_cache: Optional[bool] = None, 232 | output_attentions: Optional[bool] = None, 233 | output_hidden_states: Optional[bool] = None, 234 | images: Optional[torch.FloatTensor] = None, 235 | return_dict: Optional[bool] = None, 236 | 237 | ) -> Union[Tuple, CausalLMOutputWithPast]: 238 | output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions 239 | output_hidden_states = ( 240 | output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states 241 | ) 242 | return_dict = return_dict if return_dict is not None else self.config.use_return_dict 243 | 244 | # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn) 245 | # print(input_ids) 246 | # print(len(images)) 247 | 248 | outputs = self.model( 249 | input_ids=input_ids, 250 | past_key_values=past_key_values, 251 | attention_mask=attention_mask, 252 | inputs_embeds=inputs_embeds, 253 | use_cache=use_cache, 254 | output_attentions=output_attentions, 255 | output_hidden_states=output_hidden_states, 256 | images=images, 257 | return_dict=return_dict 258 | 259 | ) 260 | 261 | 262 | 263 | 
264 | hidden_states = outputs[0] 265 | logits = self.lm_head(hidden_states).contiguous() 266 | 267 | # logits 268 | 269 | loss = None 270 | if labels is not None: 271 | # move labels to correct device to enable model parallelism 272 | labels = labels.to(logits.device) 273 | # Shift so that tokens < n predict n 274 | shift_logits = logits[..., :-1, :].contiguous() 275 | shift_labels = labels[..., 1:].contiguous() 276 | # Flatten the tokens 277 | loss_fct = CrossEntropyLoss() 278 | loss = loss_fct(shift_logits.view(-1, self.config.vocab_size), shift_labels.view(-1)) 279 | 280 | if not return_dict: 281 | output = (logits,) + outputs[1:] 282 | return (loss,) + output if loss is not None else output 283 | 284 | return CausalLMOutputWithPast( 285 | loss=loss, 286 | logits=logits, 287 | past_key_values=outputs.past_key_values, 288 | hidden_states=outputs.hidden_states, 289 | attentions=outputs.attentions, 290 | ) 291 | 292 | 293 | def prepare_inputs_for_generation( 294 | self, input_ids, past_key_values=None, inputs_embeds=None, **kwargs 295 | ): 296 | token_type_ids = kwargs.get("token_type_ids", None) 297 | if past_key_values: 298 | input_ids = input_ids[:, -1].unsqueeze(-1) 299 | if token_type_ids is not None: 300 | token_type_ids = token_type_ids[:, -1].unsqueeze(-1) 301 | 302 | attention_mask = kwargs.get("attention_mask", None) 303 | position_ids = kwargs.get("position_ids", None) 304 | 305 | if attention_mask is not None and position_ids is None: 306 | position_ids = attention_mask.long().cumsum(-1) - 1 307 | position_ids.masked_fill_(attention_mask == 0, 1) 308 | if past_key_values: 309 | position_ids = position_ids[:, -1].unsqueeze(-1) 310 | else: 311 | position_ids = None 312 | 313 | if inputs_embeds is not None and past_key_values is None: 314 | model_inputs = {"inputs_embeds": inputs_embeds} 315 | else: 316 | model_inputs = {"input_ids": input_ids} 317 | 318 | model_inputs.update( 319 | { 320 | "past_key_values": past_key_values, 321 | "use_cache": 
kwargs.get("use_cache"), 322 | "position_ids": position_ids, 323 | "attention_mask": attention_mask, 324 | "token_type_ids": token_type_ids, 325 | "images": kwargs.get("images", None), 326 | } 327 | ) 328 | return model_inputs 329 | 330 | def initialize_vision_tokenizer( 331 | self, 332 | tokenizer, 333 | freeze_lm_model=False, 334 | pretrained_stage1_model=None, 335 | device="cuda" 336 | ): 337 | config = self.get_model().config 338 | 339 | # add image patch token 340 | tokenizer.add_tokens("", special_tokens=True) 341 | self.resize_token_embeddings(len(tokenizer)) 342 | 343 | tokenizer.add_tokens(DEFAULT_IMAGE_PATCH_TOKEN, special_tokens=True) 344 | self.resize_token_embeddings(len(tokenizer)) 345 | config.im_patch_token = tokenizer.convert_tokens_to_ids(DEFAULT_IMAGE_PATCH_TOKEN) 346 | 347 | 348 | config.use_im_start_end = True 349 | 350 | # add image start token and end token 351 | if config.use_im_start_end: 352 | num_new_tokens = 2 353 | tokenizer.add_tokens(DEFAULT_IM_START_TOKEN , special_tokens=True) 354 | tokenizer.add_tokens(DEFAULT_IM_END_TOKEN , special_tokens=True) 355 | self.resize_token_embeddings(len(tokenizer)) 356 | config.im_start_token = tokenizer.convert_tokens_to_ids(DEFAULT_IM_START_TOKEN) 357 | config.im_end_token = tokenizer.convert_tokens_to_ids(DEFAULT_IM_END_TOKEN) 358 | 359 | # config.im_start_token, config.im_end_token = 151857, 151858 360 | 361 | if num_new_tokens > 0: 362 | input_embeddings = self.get_input_embeddings().weight.data 363 | output_embeddings = self.get_output_embeddings().weight.data 364 | 365 | input_embeddings_avg = input_embeddings[:-num_new_tokens].mean(dim=0, keepdim=True) 366 | output_embeddings_avg = output_embeddings[:-num_new_tokens].mean(dim=0, keepdim=True) 367 | 368 | input_embeddings[-num_new_tokens:] = input_embeddings_avg 369 | output_embeddings[-num_new_tokens:] = output_embeddings_avg 370 | 371 | 372 | 373 | AutoConfig.register("vary", varyConfig) 374 | AutoModelForCausalLM.register(varyConfig, 
varyOPTForCausalLM) 375 | -------------------------------------------------------------------------------- /Vary-master/vary/model/vary_qwen_vary.py: -------------------------------------------------------------------------------- 1 | # Copyright 2023 Haotian Liu 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | 16 | from typing import List, Optional, Tuple, Union 17 | 18 | import torch 19 | import torch.nn as nn 20 | import torch.nn.functional as F 21 | from torch.nn import CrossEntropyLoss 22 | 23 | from transformers import AutoConfig, AutoModelForCausalLM, \ 24 | CLIPVisionModel, CLIPImageProcessor 25 | from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast 26 | 27 | from vary.utils.constants import * 28 | 29 | from vary.model.plug.blip_process import BlipImageEvalProcessor 30 | 31 | from vary.model.llm.qwen.modeling_qwen import QWenLMHeadModel, QWenModel 32 | 33 | from vary.model.llm.qwen.configuration_qwen import QWenConfig 34 | from vary.model.vision_encoder.sam import build_sam_vit_b 35 | from vary.model.plug.transforms import train_transform, test_transform 36 | 37 | 38 | class varyConfig(QWenConfig): 39 | model_type = "vary" 40 | 41 | 42 | class varyQwenModel(QWenModel): 43 | config_class = varyConfig 44 | 45 | def __init__(self, config: QWenConfig): 46 | super(varyQwenModel, self).__init__(config) 47 | # TODO download the clip-vit in huggingface 48 | self.vision_tower = 
CLIPVisionModel.from_pretrained('/cache/vit-large-patch14/') 49 | 50 | self.vision_tower_high = build_sam_vit_b() # build_sam_vit_b(checkpoint = 'xxxx') for train 51 | 52 | self.mm_projector = nn.Linear(1024, 2048) 53 | self.mm_projector_vary = nn.Linear(1024, 2048) 54 | 55 | def initialize_vision_modules( 56 | self, 57 | vision_tower, 58 | pretrained_stage1_model=None, 59 | freeze_vision_tower=False, 60 | use_im_start_end=False, 61 | vision_select_layer=-1, 62 | dtype=torch.float16, 63 | device="cuda" 64 | ): 65 | 66 | # CLIP branch: 224x224 input 67 | # TODO: download clip-vit-large-patch14 from Hugging Face 68 | image_processor = CLIPImageProcessor.from_pretrained('/cache/vit-large-patch14/') 69 | # SAM branch: 1024x1024 input 70 | image_processor_high = train_transform 71 | 72 | self.vision_tower = self.vision_tower.to(dtype=dtype, device=device) 73 | 74 | self.vision_tower_high = self.vision_tower_high.to(dtype=dtype, device=device) 75 | 76 | self.mm_projector = self.mm_projector.to(dtype=dtype, device=device) 77 | self.mm_projector_vary = self.mm_projector_vary.to(dtype=dtype, device=device) 78 | 79 | image_token_len = 256 80 | 81 | self.config.vision_tower = vision_tower 82 | self.config.image_token_len = image_token_len 83 | 84 | self.config.use_im_start_end = True 85 | 86 | self.config.vision_select_layer = vision_select_layer 87 | self.config.freeze_vision_tower = freeze_vision_tower 88 | 89 | return dict( 90 | image_processor=image_processor, 91 | image_processor_high=image_processor_high, 92 | image_token_len=image_token_len, 93 | 94 | ) 95 | 96 | def embed_tokens(self, x): 97 | return self.wte(x) 98 | 99 | def forward( 100 | self, 101 | input_ids: torch.LongTensor = None, 102 | attention_mask: Optional[torch.Tensor] = None, 103 | past_key_values: Optional[List[torch.FloatTensor]] = None, 104 | inputs_embeds: Optional[torch.FloatTensor] = None, 105 | use_cache: Optional[bool] = None, 106 | output_attentions: Optional[bool] = None, 107 | output_hidden_states: Optional[bool] = None, 108 | images:
Optional[torch.FloatTensor] = None, 109 | return_dict: Optional[bool] = None, 110 | ) -> Union[Tuple, BaseModelOutputWithPast]: 111 | 112 | # HACK: replace back original embeddings for LLaVA pretraining 113 | # orig_embeds_params = getattr(self, 'orig_embeds_params', None) 114 | # if orig_embeds_params is not None: 115 | # with torch.no_grad(): 116 | # self.get_input_embeddings().weight[:-self.num_new_tokens] = orig_embeds_params[:-self.num_new_tokens].data 117 | 118 | if inputs_embeds is None: 119 | inputs_embeds = self.embed_tokens(input_ids) 120 | # inputs_embeds = self.wte(input_ids) 121 | 122 | 123 | vision_tower = getattr(self, 'vision_tower', None) 124 | vision_tower_high = getattr(self, 'vision_tower_high', None) 125 | 126 | 127 | if vision_tower is not None and (input_ids.shape[1] != 1 or self.training) and images is not None: 128 | 129 | use_im_start_end = getattr(self.config, "use_im_start_end", -1) 130 | 131 | vision_select_layer = getattr(self.config, "vision_select_layer", -1) 132 | # im_patch_token = getattr(self.config, "im_patch_token", -1) 133 | # im_start_token = getattr(self.config, "im_start_token", -1) 134 | # im_end_token = getattr(self.config, "im_end_token", -1) 135 | # freeze_vision_tower = getattr(self.config, "freeze_vision_tower", False) 136 | 137 | im_patch_token = 151859 138 | 139 | im_start_token = 151857 140 | 141 | im_end_token = 151858 142 | 143 | 144 | image_features_1 = [] 145 | image_features_2 = [] 146 | for image in images: 147 | 148 | with torch.set_grad_enabled(False): 149 | image_forward_out = vision_tower(image[0], output_hidden_states=True) 150 | select_hidden_state = image_forward_out.hidden_states[vision_select_layer] 151 | image_feature = select_hidden_state[:, 1:] # 256*1024 152 | with torch.set_grad_enabled(False): 153 | cnn_feature = vision_tower_high(image[1]) 154 | cnn_feature = cnn_feature.flatten(2).permute(0, 2, 1) # 256*1024 155 | 156 | image_features_1.append(image_feature) 157 | 
image_features_2.append(cnn_feature) 158 | 159 | 160 | if type(images) is list: 161 | image_features_1 = [self.mm_projector(image_feature) for image_feature in image_features_1] 162 | image_features_2 = [self.mm_projector_vary(image_feature) for image_feature in image_features_2] 163 | image_features = [torch.cat((image_feature[0], image_feature[1]), dim=-1) for image_feature in zip(image_features_1, image_features_2)] 164 | else: 165 | 166 | raise NotImplementedError 167 | 168 | 169 | # dummy_image_features = torch.zeros(256, 4096, device=inputs_embeds.device, dtype=inputs_embeds.dtype) 170 | dummy_image_features_1 = torch.zeros(256, 1024, device=inputs_embeds.device, dtype=inputs_embeds.dtype) 171 | dummy_image_features_2 = torch.zeros(256, 1024, device=inputs_embeds.device, dtype=inputs_embeds.dtype) 172 | dummy_image_features_1 = self.mm_projector(dummy_image_features_1) 173 | dummy_image_features_2 = self.mm_projector_vary(dummy_image_features_2) 174 | dummy_image_features = torch.cat((dummy_image_features_1, dummy_image_features_2), dim=-1) 175 | use_im_start_end = True 176 | new_input_embeds = [] 177 | for cur_input_ids, cur_input_embeds, cur_image_features in zip(input_ids, inputs_embeds, image_features): 178 | if (cur_input_ids == im_patch_token).sum() == 0: 179 | # multimodal LLM, but the current sample is not multimodal 180 | cur_input_embeds = cur_input_embeds + (0. 
* dummy_image_features).sum() 181 | new_input_embeds.append(cur_input_embeds) 182 | continue 183 | 184 | if use_im_start_end: 185 | if (cur_input_ids == im_start_token).sum() != (cur_input_ids == im_end_token).sum(): 186 | raise ValueError("The number of image start tokens and image end tokens should be the same.") 187 | 188 | image_start_tokens = torch.where(cur_input_ids == im_start_token)[0] 189 | for image_start_token_pos, per_cur_image_features in zip(image_start_tokens, cur_image_features): 190 | per_cur_image_features = per_cur_image_features.to(device=cur_input_embeds.device) 191 | num_patches = per_cur_image_features.shape[0] 192 | 193 | if cur_input_ids[image_start_token_pos + num_patches + 1] != im_end_token: 194 | raise ValueError("The image end token should follow the image start token.") 195 | 196 | # if orig_embeds_params is not None: 197 | # cur_new_input_embeds = torch.cat( 198 | # ( 199 | # cur_input_embeds[:image_start_token_pos].detach(), 200 | # cur_input_embeds[image_start_token_pos:image_start_token_pos+1], 201 | # per_cur_image_features, 202 | # cur_input_embeds[image_start_token_pos + num_patches + 1:image_start_token_pos + num_patches + 2], 203 | # cur_input_embeds[image_start_token_pos + num_patches + 2:].detach() 204 | # ), 205 | # dim=0 206 | # ) 207 | # else: 208 | cur_input_embeds = torch.cat( 209 | ( 210 | cur_input_embeds[:image_start_token_pos+1], 211 | per_cur_image_features, 212 | cur_input_embeds[image_start_token_pos + num_patches + 1:] 213 | ), 214 | dim=0 215 | ) 216 | 217 | 218 | new_input_embeds.append(cur_input_embeds) 219 | else: 220 | raise NotImplementedError 221 | 222 | inputs_embeds = torch.stack(new_input_embeds, dim=0) 223 | 224 | return super(varyQwenModel, self).forward( 225 | input_ids=None, attention_mask=attention_mask, past_key_values=past_key_values, 226 | inputs_embeds=inputs_embeds, use_cache=use_cache, 227 | output_attentions=output_attentions, output_hidden_states=output_hidden_states, 228 | 
return_dict=return_dict 229 | ) 230 | 231 | 232 | class varyQwenForCausalLM(QWenLMHeadModel): 233 | config_class = varyConfig 234 | # supports_gradient_checkpointing = True 235 | 236 | def __init__(self, config): 237 | super(QWenLMHeadModel, self).__init__(config) 238 | self.transformer = varyQwenModel(config) 239 | 240 | self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) 241 | 242 | # Initialize weights and apply final processing 243 | self.post_init() 244 | 245 | def get_model(self): 246 | return self.transformer 247 | 248 | # def _set_gradient_checkpointing(self, module, value=False): 249 | # if isinstance(module, varyQwenModel): 250 | # module.gradient_checkpointing = value 251 | 252 | def forward( 253 | self, 254 | input_ids: Optional[torch.LongTensor] = None, 255 | past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None, 256 | attention_mask: Optional[torch.FloatTensor] = None, 257 | token_type_ids: Optional[torch.LongTensor] = None, 258 | position_ids: Optional[torch.LongTensor] = None, 259 | head_mask: Optional[torch.FloatTensor] = None, 260 | inputs_embeds: Optional[torch.FloatTensor] = None, 261 | encoder_hidden_states: Optional[torch.Tensor] = None, 262 | encoder_attention_mask: Optional[torch.FloatTensor] = None, 263 | labels: Optional[torch.LongTensor] = None, 264 | use_cache: Optional[bool] = None, 265 | output_attentions: Optional[bool] = None, 266 | output_hidden_states: Optional[bool] = None, 267 | images: Optional[torch.FloatTensor] = None, 268 | return_dict: Optional[bool] = None, 269 | 270 | ) -> Union[Tuple, CausalLMOutputWithPast]: 271 | output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions 272 | output_hidden_states = ( 273 | output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states 274 | ) 275 | return_dict = return_dict if return_dict is not None else self.config.use_return_dict 276 | 277 | # decoder outputs consists of 
(dec_features, layer_state, dec_hidden, dec_attn) 278 | 279 | 280 | transformer_outputs = self.transformer( 281 | input_ids=input_ids, 282 | past_key_values=past_key_values, 283 | attention_mask=attention_mask, 284 | inputs_embeds=inputs_embeds, 285 | use_cache=use_cache, 286 | output_attentions=output_attentions, 287 | output_hidden_states=output_hidden_states, 288 | images=images, 289 | return_dict=return_dict 290 | 291 | ) 292 | 293 | 294 | 295 | hidden_states = transformer_outputs[0] 296 | lm_logits = self.lm_head(hidden_states) 297 | 298 | # Next-token loss: the logits at position i are scored against label i + 1 299 | 300 | loss = None 301 | if labels is not None: 302 | labels = labels.to(lm_logits.device) 303 | shift_logits = lm_logits[..., :-1, :].contiguous() 304 | shift_labels = labels[..., 1:].contiguous() 305 | loss_fct = CrossEntropyLoss() 306 | loss = loss_fct( 307 | shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1) 308 | ) 309 | 310 | if not return_dict: 311 | output = (lm_logits,) + transformer_outputs[1:] 312 | return ((loss,) + output) if loss is not None else output 313 | 314 | 315 | 316 | 317 | 318 | 319 | 320 | return CausalLMOutputWithPast( 321 | loss=loss, 322 | logits=lm_logits, 323 | past_key_values=transformer_outputs.past_key_values, 324 | hidden_states=transformer_outputs.hidden_states, 325 | attentions=transformer_outputs.attentions, 326 | ) 327 | 328 | def prepare_inputs_for_generation( 329 | self, input_ids, past_key_values=None, inputs_embeds=None, **kwargs 330 | ): 331 | token_type_ids = kwargs.get("token_type_ids", None) 332 | if past_key_values: 333 | input_ids = input_ids[:, -1].unsqueeze(-1) 334 | if token_type_ids is not None: 335 | token_type_ids = token_type_ids[:, -1].unsqueeze(-1) 336 | 337 | attention_mask = kwargs.get("attention_mask", None) 338 | position_ids = kwargs.get("position_ids", None) 339 | 340 | if attention_mask is not
None and position_ids is None: 341 | position_ids = attention_mask.long().cumsum(-1) - 1 342 | position_ids.masked_fill_(attention_mask == 0, 1) 343 | if past_key_values: 344 | position_ids = position_ids[:, -1].unsqueeze(-1) 345 | else: 346 | position_ids = None 347 | 348 | if inputs_embeds is not None and past_key_values is None: 349 | model_inputs = {"inputs_embeds": inputs_embeds} 350 | else: 351 | model_inputs = {"input_ids": input_ids} 352 | 353 | model_inputs.update( 354 | { 355 | "past_key_values": past_key_values, 356 | "use_cache": kwargs.get("use_cache"), 357 | "position_ids": position_ids, 358 | "attention_mask": attention_mask, 359 | "token_type_ids": token_type_ids, 360 | "images": kwargs.get("images", None), 361 | } 362 | ) 363 | return model_inputs 364 | 365 | def initialize_vision_tokenizer( 366 | self, 367 | tokenizer, 368 | freeze_lm_model=False, 369 | pretrained_stage1_model=None, 370 | device="cuda" 371 | ): 372 | config = self.get_model().config 373 | 374 | 375 | self.resize_token_embeddings(len(tokenizer)) 376 | 377 | 378 | config.im_patch_token = 151859 379 | 380 | config.use_im_start_end = True 381 | 382 | # add image start token and end token 383 | if config.use_im_start_end: 384 | # num_new_tokens = tokenizer.add_tokens([DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN], special_tokens=True) 385 | self.resize_token_embeddings(len(tokenizer)) 386 | # config.im_start_token, config.im_end_token = tokenizer.convert_tokens_to_ids([DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN]) 387 | 388 | config.im_start_token, config.im_end_token = 151857, 151858 389 | 390 | 391 | AutoConfig.register("vary", varyConfig) 392 | AutoModelForCausalLM.register(varyConfig, varyQwenForCausalLM) 393 | -------------------------------------------------------------------------------- /Vary-master/vary/model/vary_toy_qwen1_8.py: -------------------------------------------------------------------------------- 1 | from typing import List, Optional, Tuple, Union 2 | import 
torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | from torch.nn import CrossEntropyLoss 6 | from transformers import AutoConfig, AutoModelForCausalLM, \ 7 | CLIPVisionModel, CLIPImageProcessor 8 | from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast 9 | from vary.utils.constants import * 10 | from vary.model.plug.blip_process import BlipImageEvalProcessor 11 | from vary.model.llm.qwen.modeling_qwen import QWenLMHeadModel, QWenModel 12 | from vary.model.llm.qwen.configuration_qwen import QWenConfig 13 | from vary.model.vision_encoder.sam import build_sam_vit_b 14 | 15 | 16 | class varyConfig(QWenConfig): 17 | model_type = "vary" 18 | 19 | 20 | class varyQwenModel(QWenModel): 21 | config_class = varyConfig 22 | 23 | def __init__(self, config: QWenConfig): 24 | super(varyQwenModel, self).__init__(config) 25 | 26 | self.vision_tower = CLIPVisionModel.from_pretrained('/data/hypertext/ucaswei/cache/vit-large-patch14/vit-large-patch14/') 27 | self.vision_tower_high = build_sam_vit_b() 28 | 29 | self.mm_projector = nn.Linear(1024, 1024) 30 | self.mm_projector_vary = nn.Linear(1024, 1024) 31 | 32 | def initialize_vision_modules( 33 | self, 34 | vision_tower, 35 | pretrained_stage1_model=None, 36 | freeze_vision_tower=False, 37 | use_im_start_end=False, 38 | vision_select_layer=-1, 39 | dtype=torch.float16, 40 | device="cuda" 41 | ): 42 | 43 | # 224*224 44 | image_processor = CLIPImageProcessor.from_pretrained('/data/hypertext/ucaswei/cache/vit-large-patch14/vit-large-patch14/') 45 | # 1024*1024 46 | image_processor_high = BlipImageEvalProcessor(image_size=1024) 47 | 48 | self.vision_tower = self.vision_tower.to(dtype=dtype, device=device) 49 | self.vision_tower_high = self.vision_tower_high.to(dtype=dtype, device=device) 50 | 51 | self.mm_projector = self.mm_projector.to(dtype=dtype, device=device) 52 | self.mm_projector_vary = self.mm_projector_vary.to(dtype=dtype, device=device) 53 | 54 | image_token_len = 256 55 
| 56 | self.config.vision_tower = vision_tower 57 | self.config.image_token_len = image_token_len 58 | self.config.use_im_start_end = True 59 | self.config.vision_select_layer = vision_select_layer 60 | self.config.freeze_vision_tower = freeze_vision_tower 61 | 62 | return dict( 63 | image_processor=image_processor, 64 | image_processor_high=image_processor_high, 65 | image_token_len=image_token_len, 66 | # vision_config=vision_config 67 | ) 68 | 69 | def embed_tokens(self, x): 70 | return self.wte(x) 71 | 72 | def forward( 73 | self, 74 | input_ids: torch.LongTensor = None, 75 | attention_mask: Optional[torch.Tensor] = None, 76 | past_key_values: Optional[List[torch.FloatTensor]] = None, 77 | inputs_embeds: Optional[torch.FloatTensor] = None, 78 | use_cache: Optional[bool] = None, 79 | output_attentions: Optional[bool] = None, 80 | output_hidden_states: Optional[bool] = None, 81 | images: Optional[torch.FloatTensor] = None, 82 | return_dict: Optional[bool] = None, 83 | ) -> Union[Tuple, BaseModelOutputWithPast]: 84 | 85 | 86 | 87 | if inputs_embeds is None: 88 | inputs_embeds = self.embed_tokens(input_ids) 89 | 90 | 91 | vision_tower = getattr(self, 'vision_tower', None) 92 | vision_tower_high = getattr(self, 'vision_tower_high', None) 93 | 94 | 95 | if vision_tower is not None and (input_ids.shape[1] != 1 or self.training) and images is not None: 96 | 97 | use_im_start_end = getattr(self.config, "use_im_start_end", -1) 98 | 99 | vision_select_layer = getattr(self.config, "vision_select_layer", -1) 100 | im_patch_token = getattr(self.config, "im_patch_token", -1) 101 | im_start_token = getattr(self.config, "im_start_token", -1) 102 | im_end_token = getattr(self.config, "im_end_token", -1) 103 | freeze_vision_tower = getattr(self.config, "freeze_vision_tower", False) 104 | 105 | im_patch_token = 151859 106 | im_start_token = 151857 107 | im_end_token = 151858 108 | 109 | image_features = [] 110 | image_features_1 = [] 111 | image_features_2 = [] 112 | for image in 
images: 113 | with torch.set_grad_enabled(False): 114 | image_forward_out = vision_tower(image[0], output_hidden_states=True) 115 | select_hidden_state = image_forward_out.hidden_states[vision_select_layer] 116 | image_feature = select_hidden_state[:, 1:] # 256*1024 117 | with torch.set_grad_enabled(False): 118 | cnn_feature = vision_tower_high(image[1]) 119 | cnn_feature = cnn_feature.flatten(2).permute(0, 2, 1) # 256*1024 120 | image_features_1.append(image_feature) 121 | image_features_2.append(cnn_feature) 122 | 123 | 124 | if type(images) is list: 125 | image_features_1 = [self.mm_projector(image_feature) for image_feature in image_features_1] 126 | image_features_2 = [self.mm_projector_vary(image_feature) for image_feature in image_features_2] 127 | image_features = [torch.cat((image_feature[0], image_feature[1]), dim=-1) for image_feature in zip(image_features_1, image_features_2)] 128 | else: 129 | raise NotImplementedError 130 | 131 | dummy_image_features_1 = torch.zeros(256, 1024, device=inputs_embeds.device, dtype=inputs_embeds.dtype) 132 | dummy_image_features_2 = torch.zeros(256, 1024, device=inputs_embeds.device, dtype=inputs_embeds.dtype) 133 | dummy_image_features_1 = self.mm_projector(dummy_image_features_1) 134 | dummy_image_features_2 = self.mm_projector_vary(dummy_image_features_2) 135 | dummy_image_features = torch.cat((dummy_image_features_1, dummy_image_features_2), dim=-1) 136 | use_im_start_end = True 137 | new_input_embeds = [] 138 | for cur_input_ids, cur_input_embeds, cur_image_features in zip(input_ids, inputs_embeds, image_features): 139 | if (cur_input_ids == im_patch_token).sum() == 0: 140 | # multimodal LLM, but the current sample is not multimodal 141 | cur_input_embeds = cur_input_embeds + (0. 
* dummy_image_features).sum() 142 | new_input_embeds.append(cur_input_embeds) 143 | continue 144 | 145 | if use_im_start_end: 146 | if (cur_input_ids == im_start_token).sum() != (cur_input_ids == im_end_token).sum(): 147 | raise ValueError("The number of image start tokens and image end tokens should be the same.") 148 | 149 | image_start_tokens = torch.where(cur_input_ids == im_start_token)[0] 150 | for image_start_token_pos, per_cur_image_features in zip(image_start_tokens, cur_image_features): 151 | per_cur_image_features = per_cur_image_features.to(device=cur_input_embeds.device) 152 | num_patches = per_cur_image_features.shape[0] 153 | 154 | if cur_input_ids[image_start_token_pos + num_patches + 1] != im_end_token: 155 | raise ValueError("The image end token should follow the image start token.") 156 | 157 | cur_input_embeds = torch.cat( 158 | ( 159 | cur_input_embeds[:image_start_token_pos+1], 160 | per_cur_image_features, 161 | cur_input_embeds[image_start_token_pos + num_patches + 1:] 162 | ), 163 | dim=0 164 | ) 165 | 166 | 167 | new_input_embeds.append(cur_input_embeds) 168 | else: 169 | raise NotImplementedError 170 | 171 | inputs_embeds = torch.stack(new_input_embeds, dim=0) 172 | 173 | return super(varyQwenModel, self).forward( 174 | input_ids=None, attention_mask=attention_mask, past_key_values=past_key_values, 175 | inputs_embeds=inputs_embeds, use_cache=use_cache, 176 | output_attentions=output_attentions, output_hidden_states=output_hidden_states, 177 | return_dict=return_dict 178 | ) 179 | 180 | 181 | class varyQwenForCausalLM(QWenLMHeadModel): 182 | config_class = varyConfig 183 | # supports_gradient_checkpointing = True 184 | 185 | def __init__(self, config): 186 | super(QWenLMHeadModel, self).__init__(config) 187 | self.transformer = varyQwenModel(config) 188 | 189 | self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) 190 | 191 | # Initialize weights and apply final processing 192 | self.post_init() 193 | 194 | def 
get_model(self): 195 | return self.transformer 196 | 197 | 198 | def forward( 199 | self, 200 | input_ids: Optional[torch.LongTensor] = None, 201 | past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None, 202 | attention_mask: Optional[torch.FloatTensor] = None, 203 | token_type_ids: Optional[torch.LongTensor] = None, 204 | position_ids: Optional[torch.LongTensor] = None, 205 | head_mask: Optional[torch.FloatTensor] = None, 206 | inputs_embeds: Optional[torch.FloatTensor] = None, 207 | encoder_hidden_states: Optional[torch.Tensor] = None, 208 | encoder_attention_mask: Optional[torch.FloatTensor] = None, 209 | labels: Optional[torch.LongTensor] = None, 210 | use_cache: Optional[bool] = None, 211 | output_attentions: Optional[bool] = None, 212 | output_hidden_states: Optional[bool] = None, 213 | images: Optional[torch.FloatTensor] = None, 214 | return_dict: Optional[bool] = None, 215 | 216 | ) -> Union[Tuple, CausalLMOutputWithPast]: 217 | output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions 218 | output_hidden_states = ( 219 | output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states 220 | ) 221 | return_dict = return_dict if return_dict is not None else self.config.use_return_dict 222 | 223 | 224 | transformer_outputs = self.transformer( 225 | input_ids=input_ids, 226 | past_key_values=past_key_values, 227 | attention_mask=attention_mask, 228 | inputs_embeds=inputs_embeds, 229 | use_cache=use_cache, 230 | output_attentions=output_attentions, 231 | output_hidden_states=output_hidden_states, 232 | images=images, 233 | return_dict=return_dict 234 | 235 | ) 236 | 237 | hidden_states = transformer_outputs[0] 238 | lm_logits = self.lm_head(hidden_states) 239 | 240 | # logits 241 | 242 | loss = None 243 | if labels is not None: 244 | labels = labels.to(lm_logits.device) 245 | shift_logits = lm_logits[..., :-1, :].contiguous() 246 | shift_labels = labels[..., 1:].contiguous() 247 
| loss_fct = CrossEntropyLoss() 248 | loss = loss_fct( 249 | shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1) 250 | ) 251 | 252 | if not return_dict: 253 | output = (lm_logits,) + transformer_outputs[1:] 254 | return ((loss,) + output) if loss is not None else output 255 | 256 | 257 | 258 | 259 | 260 | return CausalLMOutputWithPast( 261 | loss=loss, 262 | logits=lm_logits, 263 | past_key_values=transformer_outputs.past_key_values, 264 | hidden_states=transformer_outputs.hidden_states, 265 | attentions=transformer_outputs.attentions, 266 | ) 267 | 268 | def prepare_inputs_for_generation( 269 | self, input_ids, past_key_values=None, inputs_embeds=None, **kwargs 270 | ): 271 | token_type_ids = kwargs.get("token_type_ids", None) 272 | if past_key_values: 273 | input_ids = input_ids[:, -1].unsqueeze(-1) 274 | if token_type_ids is not None: 275 | token_type_ids = token_type_ids[:, -1].unsqueeze(-1) 276 | 277 | attention_mask = kwargs.get("attention_mask", None) 278 | position_ids = kwargs.get("position_ids", None) 279 | 280 | if attention_mask is not None and position_ids is None: 281 | position_ids = attention_mask.long().cumsum(-1) - 1 282 | position_ids.masked_fill_(attention_mask == 0, 1) 283 | if past_key_values: 284 | position_ids = position_ids[:, -1].unsqueeze(-1) 285 | else: 286 | position_ids = None 287 | 288 | if inputs_embeds is not None and past_key_values is None: 289 | model_inputs = {"inputs_embeds": inputs_embeds} 290 | else: 291 | model_inputs = {"input_ids": input_ids} 292 | 293 | model_inputs.update( 294 | { 295 | "past_key_values": past_key_values, 296 | "use_cache": kwargs.get("use_cache"), 297 | "position_ids": position_ids, 298 | "attention_mask": attention_mask, 299 | "token_type_ids": token_type_ids, 300 | "images": kwargs.get("images", None), 301 | } 302 | ) 303 | return model_inputs 304 | 305 | def
initialize_vision_tokenizer( 306 | self, 307 | tokenizer, 308 | freeze_lm_model=False, 309 | pretrained_stage1_model=None, 310 | device="cuda" 311 | ): 312 | config = self.get_model().config 313 | 314 | self.resize_token_embeddings(len(tokenizer)) 315 | 316 | config.im_patch_token = 151859 317 | 318 | config.use_im_start_end = True 319 | 320 | if config.use_im_start_end: 321 | self.resize_token_embeddings(len(tokenizer)) 322 | 323 | config.im_start_token, config.im_end_token = 151857, 151858 324 | 325 | 326 | 327 | AutoConfig.register("vary", varyConfig) 328 | AutoModelForCausalLM.register(varyConfig, varyQwenForCausalLM) 329 | -------------------------------------------------------------------------------- /Vary-master/vary/model/vision_encoder/__init__.py: -------------------------------------------------------------------------------- 1 | 2 | -------------------------------------------------------------------------------- /Vary-master/vary/model/vision_encoder/sam.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 
6 | 7 | import torch 8 | import torch.nn as nn 9 | import torch.nn.functional as F 10 | 11 | from typing import Optional, Tuple, Type 12 | 13 | from functools import partial 14 | 15 | import torch 16 | import torch.nn as nn 17 | 18 | from typing import Type 19 | 20 | import math 21 | 22 | 23 | 24 | 25 | class MLPBlock(nn.Module): 26 | def __init__( 27 | self, 28 | embedding_dim: int, 29 | mlp_dim: int, 30 | act: Type[nn.Module] = nn.GELU, 31 | ) -> None: 32 | super().__init__() 33 | self.lin1 = nn.Linear(embedding_dim, mlp_dim) 34 | self.lin2 = nn.Linear(mlp_dim, embedding_dim) 35 | self.act = act() 36 | 37 | def forward(self, x: torch.Tensor) -> torch.Tensor: 38 | return self.lin2(self.act(self.lin1(x))) 39 | 40 | 41 | # From https://github.com/facebookresearch/detectron2/blob/main/detectron2/layers/batch_norm.py # noqa 42 | # Itself from https://github.com/facebookresearch/ConvNeXt/blob/d1fa8f6fef0a165b27399986cc2bdacc92777e40/models/convnext.py#L119 # noqa 43 | class LayerNorm2d(nn.Module): 44 | def __init__(self, num_channels: int, eps: float = 1e-6) -> None: 45 | super().__init__() 46 | self.weight = nn.Parameter(torch.ones(num_channels)) 47 | self.bias = nn.Parameter(torch.zeros(num_channels)) 48 | self.eps = eps 49 | 50 | def forward(self, x: torch.Tensor) -> torch.Tensor: 51 | u = x.mean(1, keepdim=True) 52 | s = (x - u).pow(2).mean(1, keepdim=True) 53 | x = (x - u) / torch.sqrt(s + self.eps) 54 | x = self.weight[:, None, None] * x + self.bias[:, None, None] 55 | return x 56 | 57 | 58 | # This class and its supporting functions below lightly adapted from the ViTDet backbone available at: https://github.com/facebookresearch/detectron2/blob/main/detectron2/modeling/backbone/vit.py # noqa 59 | class ImageEncoderViT(nn.Module): 60 | def __init__( 61 | self, 62 | img_size: int = 1024, 63 | patch_size: int = 16, 64 | in_chans: int = 3, 65 | embed_dim: int = 768, 66 | depth: int = 12, 67 | num_heads: int = 12, 68 | mlp_ratio: float = 4.0, 69 | out_chans: int = 
256, 70 | qkv_bias: bool = True, 71 | norm_layer: Type[nn.Module] = nn.LayerNorm, 72 | act_layer: Type[nn.Module] = nn.GELU, 73 | use_abs_pos: bool = True, 74 | use_rel_pos: bool = False, 75 | rel_pos_zero_init: bool = True, 76 | window_size: int = 0, 77 | global_attn_indexes: Tuple[int, ...] = (), 78 | ) -> None: 79 | """ 80 | Args: 81 | img_size (int): Input image size. 82 | patch_size (int): Patch size. 83 | in_chans (int): Number of input image channels. 84 | embed_dim (int): Patch embedding dimension. 85 | depth (int): Depth of ViT. 86 | num_heads (int): Number of attention heads in each ViT block. 87 | mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. 88 | qkv_bias (bool): If True, add a learnable bias to query, key, value. 89 | norm_layer (nn.Module): Normalization layer. 90 | act_layer (nn.Module): Activation layer. 91 | use_abs_pos (bool): If True, use absolute positional embeddings. 92 | use_rel_pos (bool): If True, add relative positional embeddings to the attention map. 93 | rel_pos_zero_init (bool): If True, zero initialize relative positional parameters. 94 | window_size (int): Window size for window attention blocks. 95 | global_attn_indexes (list): Indexes for blocks using global attention. 96 | """ 97 | super().__init__() 98 | self.img_size = img_size 99 | 100 | self.patch_embed = PatchEmbed( 101 | kernel_size=(patch_size, patch_size), 102 | stride=(patch_size, patch_size), 103 | in_chans=in_chans, 104 | embed_dim=embed_dim, 105 | ) 106 | 107 | self.pos_embed: Optional[nn.Parameter] = None 108 | if use_abs_pos: 109 | # Initialize absolute positional embedding with pretrain image size. 
110 | self.pos_embed = nn.Parameter( 111 | torch.zeros(1, img_size // patch_size, img_size // patch_size, embed_dim) 112 | ) 113 | 114 | self.blocks = nn.ModuleList() 115 | for i in range(depth): 116 | block = Block( 117 | dim=embed_dim, 118 | num_heads=num_heads, 119 | mlp_ratio=mlp_ratio, 120 | qkv_bias=qkv_bias, 121 | norm_layer=norm_layer, 122 | act_layer=act_layer, 123 | use_rel_pos=use_rel_pos, 124 | rel_pos_zero_init=rel_pos_zero_init, 125 | window_size=window_size if i not in global_attn_indexes else 0, 126 | input_size=(img_size // patch_size, img_size // patch_size), 127 | ) 128 | self.blocks.append(block) 129 | 130 | self.neck = nn.Sequential( 131 | nn.Conv2d( 132 | embed_dim, 133 | out_chans, 134 | kernel_size=1, 135 | bias=False, 136 | ), 137 | LayerNorm2d(out_chans), 138 | nn.Conv2d( 139 | out_chans, 140 | out_chans, 141 | kernel_size=3, 142 | padding=1, 143 | bias=False, 144 | ), 145 | LayerNorm2d(out_chans), 146 | ) 147 | 148 | 149 | self.net_2 = nn.Conv2d(256, 512, kernel_size=3, stride=2, padding=1, bias=False) 150 | self.net_3 = nn.Conv2d(512, 1024, kernel_size=3, stride=2, padding=1, bias=False) 151 | 152 | def forward(self, x: torch.Tensor) -> torch.Tensor: 153 | x = self.patch_embed(x) 154 | if self.pos_embed is not None: 155 | x = x + self.pos_embed 156 | 157 | for blk in self.blocks: 158 | x = blk(x) 159 | 160 | x = self.neck(x.permute(0, 3, 1, 2)) 161 | x = self.net_2(x) 162 | x = self.net_3(x) 163 | 164 | 165 | return x 166 | 167 | 168 | class Block(nn.Module): 169 | """Transformer blocks with support of window attention and residual propagation blocks""" 170 | 171 | def __init__( 172 | self, 173 | dim: int, 174 | num_heads: int, 175 | mlp_ratio: float = 4.0, 176 | qkv_bias: bool = True, 177 | norm_layer: Type[nn.Module] = nn.LayerNorm, 178 | act_layer: Type[nn.Module] = nn.GELU, 179 | use_rel_pos: bool = False, 180 | rel_pos_zero_init: bool = True, 181 | window_size: int = 0, 182 | input_size: Optional[Tuple[int, int]] = None, 183 | ) -> 
None: 184 | """ 185 | Args: 186 | dim (int): Number of input channels. 187 | num_heads (int): Number of attention heads in each ViT block. 188 | mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. 189 | qkv_bias (bool): If True, add a learnable bias to query, key, value. 190 | norm_layer (nn.Module): Normalization layer. 191 | act_layer (nn.Module): Activation layer. 192 | use_rel_pos (bool): If True, add relative positional embeddings to the attention map. 193 | rel_pos_zero_init (bool): If True, zero initialize relative positional parameters. 194 | window_size (int): Window size for window attention blocks. If it equals 0, then 195 | use global attention. 196 | input_size (tuple(int, int) or None): Input resolution for calculating the relative 197 | positional parameter size. 198 | """ 199 | super().__init__() 200 | self.norm1 = norm_layer(dim) 201 | self.attn = Attention( 202 | dim, 203 | num_heads=num_heads, 204 | qkv_bias=qkv_bias, 205 | use_rel_pos=use_rel_pos, 206 | rel_pos_zero_init=rel_pos_zero_init, 207 | input_size=input_size if window_size == 0 else (window_size, window_size), 208 | ) 209 | 210 | self.norm2 = norm_layer(dim) 211 | self.mlp = MLPBlock(embedding_dim=dim, mlp_dim=int(dim * mlp_ratio), act=act_layer) 212 | 213 | self.window_size = window_size 214 | 215 | def forward(self, x: torch.Tensor) -> torch.Tensor: 216 | shortcut = x 217 | x = self.norm1(x) 218 | # Window partition 219 | if self.window_size > 0: 220 | H, W = x.shape[1], x.shape[2] 221 | x, pad_hw = window_partition(x, self.window_size) 222 | 223 | x = self.attn(x) 224 | # Reverse window partition 225 | if self.window_size > 0: 226 | x = window_unpartition(x, self.window_size, pad_hw, (H, W)) 227 | 228 | x = shortcut + x 229 | x = x + self.mlp(self.norm2(x)) 230 | 231 | return x 232 | 233 | 234 | class Attention(nn.Module): 235 | """Multi-head Attention block with relative position embeddings.""" 236 | 237 | def __init__( 238 | self, 239 | dim: int, 240 | num_heads: int = 8, 
241 | qkv_bias: bool = True, 242 | use_rel_pos: bool = False, 243 | rel_pos_zero_init: bool = True, 244 | input_size: Optional[Tuple[int, int]] = None, 245 | ) -> None: 246 | """ 247 | Args: 248 | dim (int): Number of input channels. 249 | num_heads (int): Number of attention heads. 250 | qkv_bias (bool): If True, add a learnable bias to query, key, value. 251 | use_rel_pos (bool): If True, add relative positional embeddings to the attention map. 252 | rel_pos_zero_init (bool): If True, zero initialize relative positional parameters. 253 | input_size (tuple(int, int) or None): Input resolution for calculating the relative 254 | positional parameter size. 255 | """ 256 | super().__init__() 257 | self.num_heads = num_heads 258 | head_dim = dim // num_heads 259 | self.scale = head_dim**-0.5 260 | 261 | self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) 262 | self.proj = nn.Linear(dim, dim) 263 | 264 | self.use_rel_pos = use_rel_pos 265 | if self.use_rel_pos: 266 | assert ( 267 | input_size is not None 268 | ), "Input size must be provided if using relative positional encoding."
269 | # initialize relative positional embeddings 270 | self.rel_pos_h = nn.Parameter(torch.zeros(2 * input_size[0] - 1, head_dim)) 271 | self.rel_pos_w = nn.Parameter(torch.zeros(2 * input_size[1] - 1, head_dim)) 272 | 273 | def forward(self, x: torch.Tensor) -> torch.Tensor: 274 | B, H, W, _ = x.shape 275 | # qkv with shape (3, B, nHead, H * W, C) 276 | qkv = self.qkv(x).reshape(B, H * W, 3, self.num_heads, -1).permute(2, 0, 3, 1, 4) 277 | # q, k, v with shape (B * nHead, H * W, C) 278 | q, k, v = qkv.reshape(3, B * self.num_heads, H * W, -1).unbind(0) 279 | 280 | attn = (q * self.scale) @ k.transpose(-2, -1) 281 | 282 | if self.use_rel_pos: 283 | attn = add_decomposed_rel_pos(attn, q, self.rel_pos_h, self.rel_pos_w, (H, W), (H, W)) 284 | 285 | attn = attn.softmax(dim=-1) 286 | x = (attn @ v).view(B, self.num_heads, H, W, -1).permute(0, 2, 3, 1, 4).reshape(B, H, W, -1) 287 | x = self.proj(x) 288 | 289 | return x 290 | 291 | 292 | def window_partition(x: torch.Tensor, window_size: int) -> Tuple[torch.Tensor, Tuple[int, int]]: 293 | """ 294 | Partition into non-overlapping windows with padding if needed. 295 | Args: 296 | x (tensor): input tokens with [B, H, W, C]. 297 | window_size (int): window size. 298 | 299 | Returns: 300 | windows: windows after partition with [B * num_windows, window_size, window_size, C]. 
301 | (Hp, Wp): padded height and width before partition 302 | """ 303 | B, H, W, C = x.shape 304 | 305 | pad_h = (window_size - H % window_size) % window_size 306 | pad_w = (window_size - W % window_size) % window_size 307 | if pad_h > 0 or pad_w > 0: 308 | x = F.pad(x, (0, 0, 0, pad_w, 0, pad_h)) 309 | Hp, Wp = H + pad_h, W + pad_w 310 | 311 | x = x.view(B, Hp // window_size, window_size, Wp // window_size, window_size, C) 312 | windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C) 313 | return windows, (Hp, Wp) 314 | 315 | 316 | def window_unpartition( 317 | windows: torch.Tensor, window_size: int, pad_hw: Tuple[int, int], hw: Tuple[int, int] 318 | ) -> torch.Tensor: 319 | """ 320 | Window unpartition into original sequences and removing padding. 321 | Args: 322 | windows (tensor): input tokens with [B * num_windows, window_size, window_size, C]. 323 | window_size (int): window size. 324 | pad_hw (Tuple): padded height and width (Hp, Wp). 325 | hw (Tuple): original height and width (H, W) before padding. 326 | 327 | Returns: 328 | x: unpartitioned sequences with [B, H, W, C]. 329 | """ 330 | Hp, Wp = pad_hw 331 | H, W = hw 332 | B = windows.shape[0] // (Hp * Wp // window_size // window_size) 333 | x = windows.view(B, Hp // window_size, Wp // window_size, window_size, window_size, -1) 334 | x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, Hp, Wp, -1) 335 | 336 | if Hp > H or Wp > W: 337 | x = x[:, :H, :W, :].contiguous() 338 | return x 339 | 340 | 341 | def get_rel_pos(q_size: int, k_size: int, rel_pos: torch.Tensor) -> torch.Tensor: 342 | """ 343 | Get relative positional embeddings according to the relative positions of 344 | query and key sizes. 345 | Args: 346 | q_size (int): size of query q. 347 | k_size (int): size of key k. 348 | rel_pos (Tensor): relative position embeddings (L, C). 349 | 350 | Returns: 351 | Extracted positional embeddings according to relative positions. 
352 | """ 353 | max_rel_dist = int(2 * max(q_size, k_size) - 1) 354 | # Interpolate rel pos if needed. 355 | if rel_pos.shape[0] != max_rel_dist: 356 | # Interpolate rel pos. 357 | rel_pos_resized = F.interpolate( 358 | rel_pos.reshape(1, rel_pos.shape[0], -1).permute(0, 2, 1), 359 | size=max_rel_dist, 360 | mode="linear", 361 | ) 362 | rel_pos_resized = rel_pos_resized.reshape(-1, max_rel_dist).permute(1, 0) 363 | else: 364 | rel_pos_resized = rel_pos 365 | 366 | # Scale the coords with short length if shapes for q and k are different. 367 | q_coords = torch.arange(q_size)[:, None] * max(k_size / q_size, 1.0) 368 | k_coords = torch.arange(k_size)[None, :] * max(q_size / k_size, 1.0) 369 | relative_coords = (q_coords - k_coords) + (k_size - 1) * max(q_size / k_size, 1.0) 370 | 371 | return rel_pos_resized[relative_coords.long()] 372 | 373 | 374 | def add_decomposed_rel_pos( 375 | attn: torch.Tensor, 376 | q: torch.Tensor, 377 | rel_pos_h: torch.Tensor, 378 | rel_pos_w: torch.Tensor, 379 | q_size: Tuple[int, int], 380 | k_size: Tuple[int, int], 381 | ) -> torch.Tensor: 382 | """ 383 | Calculate decomposed Relative Positional Embeddings from :paper:`mvitv2`. 384 | https://github.com/facebookresearch/mvit/blob/19786631e330df9f3622e5402b4a419a263a2c80/mvit/models/attention.py # noqa B950 385 | Args: 386 | attn (Tensor): attention map. 387 | q (Tensor): query q in the attention layer with shape (B, q_h * q_w, C). 388 | rel_pos_h (Tensor): relative position embeddings (Lh, C) for height axis. 389 | rel_pos_w (Tensor): relative position embeddings (Lw, C) for width axis. 390 | q_size (Tuple): spatial sequence size of query q with (q_h, q_w). 391 | k_size (Tuple): spatial sequence size of key k with (k_h, k_w). 392 | 393 | Returns: 394 | attn (Tensor): attention map with added relative positional embeddings. 
395 | """ 396 | q_h, q_w = q_size 397 | k_h, k_w = k_size 398 | Rh = get_rel_pos(q_h, k_h, rel_pos_h) 399 | Rw = get_rel_pos(q_w, k_w, rel_pos_w) 400 | 401 | B, _, dim = q.shape 402 | r_q = q.reshape(B, q_h, q_w, dim) 403 | rel_h = torch.einsum("bhwc,hkc->bhwk", r_q, Rh) 404 | rel_w = torch.einsum("bhwc,wkc->bhwk", r_q, Rw) 405 | 406 | attn = ( 407 | attn.view(B, q_h, q_w, k_h, k_w) + rel_h[:, :, :, :, None] + rel_w[:, :, :, None, :] 408 | ).view(B, q_h * q_w, k_h * k_w) 409 | 410 | return attn 411 | 412 | 413 | class PatchEmbed(nn.Module): 414 | """ 415 | Image to Patch Embedding. 416 | """ 417 | 418 | def __init__( 419 | self, 420 | kernel_size: Tuple[int, int] = (16, 16), 421 | stride: Tuple[int, int] = (16, 16), 422 | padding: Tuple[int, int] = (0, 0), 423 | in_chans: int = 3, 424 | embed_dim: int = 768, 425 | ) -> None: 426 | """ 427 | Args: 428 | kernel_size (Tuple): kernel size of the projection layer. 429 | stride (Tuple): stride of the projection layer. 430 | padding (Tuple): padding size of the projection layer. 431 | in_chans (int): Number of input image channels. 432 | embed_dim (int): Patch embedding dimension. 
433 | """ 434 | super().__init__() 435 | 436 | self.proj = nn.Conv2d( 437 | in_chans, embed_dim, kernel_size=kernel_size, stride=stride, padding=padding 438 | ) 439 | 440 | def forward(self, x: torch.Tensor) -> torch.Tensor: 441 | x = self.proj(x) 442 | # B C H W -> B H W C 443 | x = x.permute(0, 2, 3, 1) 444 | return x 445 | 446 | 447 | 448 | def build_sam_vit_b(checkpoint=None): 449 | return _build_sam( 450 | encoder_embed_dim=768, 451 | encoder_depth=12, 452 | encoder_num_heads=12, 453 | encoder_global_attn_indexes=[2, 5, 8, 11], 454 | checkpoint=checkpoint, 455 | ) 456 | 457 | 458 | def _build_sam( 459 | encoder_embed_dim, 460 | encoder_depth, 461 | encoder_num_heads, 462 | encoder_global_attn_indexes, 463 | checkpoint=None, 464 | ): 465 | prompt_embed_dim = 256 466 | image_size = 1024 467 | vit_patch_size = 16 468 | image_embedding_size = image_size // vit_patch_size 469 | image_encoder=ImageEncoderViT( 470 | depth=encoder_depth, 471 | embed_dim=encoder_embed_dim, 472 | img_size=image_size, 473 | mlp_ratio=4, 474 | norm_layer=partial(torch.nn.LayerNorm, eps=1e-6), 475 | num_heads=encoder_num_heads, 476 | patch_size=vit_patch_size, 477 | qkv_bias=True, 478 | use_rel_pos=True, 479 | global_attn_indexes=encoder_global_attn_indexes, 480 | window_size=14, 481 | out_chans=prompt_embed_dim, 482 | ) 483 | 484 | if checkpoint is not None: 485 | # with open(checkpoint, "rb") as f: 486 | state_dict = torch.load(checkpoint) 487 | 488 | 489 | image_encoder.load_state_dict(state_dict, strict=True) 490 | # image_encoder.load_state_dict({k[19:]: v for k, v in state_dict.items() if 'vision_tower' in k}, strict=True) 491 | print(checkpoint) 492 | return image_encoder 493 | 494 | 495 | 496 | 497 | if __name__ == '__main__': 498 | 499 | x = torch.zeros(2, 3, 1024, 1024) 500 | 501 | # x.permute(0, 3, 1, 2) 502 | 503 | net = build_sam_vit_b(checkpoint=None) # checkpoint='' is not None, so it would reach torch.load('') and fail 504 | 505 | 
-------------------------------------------------------------------------------- 
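The shape bookkeeping in `ImageEncoderViT` is easy to lose across the patch embedding, the window partitioning and the extra `net_2`/`net_3` convolutions. The following is a minimal pure-Python sketch (no torch needed) that recomputes those sizes from the settings hard-coded in `_build_sam` (`image_size=1024`, `vit_patch_size=16`, `window_size=14`) and from the conv layers in `ImageEncoderViT.__init__`. The helper names (`conv_out`, `encoder_token_grid`, `window_padding`, `rel_pos_index_range`) are hypothetical and only mirror the arithmetic; they are not part of the repository. Under these assumptions the encoder ends at a 16×16 grid of 1024-channel features, i.e. 256 tokens, which agrees with `data_args.image_token_len = 256` in the training scripts.

```python
# Hypothetical helpers that mirror the size arithmetic of the SAM ViT-B
# encoder above; assumes img_size=1024, patch_size=16, window_size=14.

def conv_out(size: int, kernel: int, stride: int = 1, padding: int = 0) -> int:
    """Output length of nn.Conv2d along one spatial axis (dilation = 1)."""
    return (size + 2 * padding - kernel) // stride + 1


def encoder_token_grid(img_size: int = 1024, patch_size: int = 16) -> int:
    """Side length of the feature map returned by ImageEncoderViT.forward."""
    g = conv_out(img_size, patch_size, stride=patch_size)  # PatchEmbed: 1024 -> 64
    g = conv_out(g, 1)                                      # neck 1x1 conv: unchanged
    g = conv_out(g, 3, stride=1, padding=1)                 # neck 3x3 conv: unchanged
    g = conv_out(g, 3, stride=2, padding=1)                 # net_2 (256 -> 512 ch): 64 -> 32
    g = conv_out(g, 3, stride=2, padding=1)                 # net_3 (512 -> 1024 ch): 32 -> 16
    return g


def window_padding(size: int, window_size: int = 14) -> tuple:
    """(padding, windows per side), as computed in window_partition."""
    pad = (window_size - size % window_size) % window_size
    return pad, (size + pad) // window_size


def rel_pos_index_range(q_size: int, k_size: int) -> tuple:
    """Smallest and largest row of rel_pos that get_rel_pos indexes."""
    q_scale = max(k_size / q_size, 1.0)
    k_scale = max(q_size / k_size, 1.0)
    idx = [int(i * q_scale - j * k_scale + (k_size - 1) * k_scale)
           for i in range(q_size) for j in range(k_size)]
    return min(idx), max(idx)


if __name__ == "__main__":
    side = encoder_token_grid()
    print(side, side * side)            # 16 256
    print(window_padding(64))           # (6, 5): 64x64 grid padded to 70x70, 5x5 windows
    print(rel_pos_index_range(14, 14))  # (0, 26): fits a 2 * 14 - 1 = 27-row table
```

The last check explains the `2 * input_size - 1` rows allocated for `rel_pos_h`/`rel_pos_w`: windowed blocks use `(window_size, window_size)` as their input size (27 rows), while the global-attention blocks (indices 2, 5, 8, 11 in `build_sam_vit_b`) see the full 64×64 grid and need 127 rows.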
/Vary-master/vary/train/train_flash_attn.py: -------------------------------------------------------------------------------- 1 | # Adopted from https://github.com/lm-sys/FastChat. Below is the original copyright: 2 | # Adopted from tatsu-lab@stanford_alpaca. Below is the original copyright: 3 | # Make it more memory efficient by monkey patching the LLaMA model with FlashAttn. 4 | 5 | # Need to call this before importing transformers. 6 | from vary.utils.llama_flash_attn_monkey_patch import replace_llama_attn_with_flash_attn 7 | 8 | replace_llama_attn_with_flash_attn() 9 | 10 | from vary.train.train import train # NOTE: vary/train/train.py does not appear in this repository tree 11 | 12 | if __name__ == "__main__": 13 | train() 14 | -------------------------------------------------------------------------------- /Vary-master/vary/train/train_lora.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Ucas-HaoranWei/Vary-toy/c6e405977aa6e88d4807a6b79f394299623fc6f3/Vary-master/vary/train/train_lora.py -------------------------------------------------------------------------------- /Vary-master/vary/train/train_lora_flash_attn.py: -------------------------------------------------------------------------------- 1 | # Adopted from https://github.com/lm-sys/FastChat. Below is the original copyright: 2 | # Adopted from tatsu-lab@stanford_alpaca. Below is the original copyright: 3 | # Make it more memory efficient by monkey patching the LLaMA model with FlashAttn. 4 | 5 | # Need to call this before importing transformers.
6 | from vary.utils.llama_flash_attn_monkey_patch import replace_llama_attn_with_flash_attn 7 | 8 | replace_llama_attn_with_flash_attn() 9 | 10 | # from vary.train.train import train 11 | from vary.train.train_lora import train 12 | 13 | if __name__ == "__main__": 14 | train() 15 | -------------------------------------------------------------------------------- /Vary-master/vary/train/train_opt.py: -------------------------------------------------------------------------------- 1 | # Adopted from https://github.com/lm-sys/FastChat. Below is the original copyright: 2 | # Adopted from tatsu-lab@stanford_alpaca. Below is the original copyright: 3 | # Copyright 2023 Rohan Taori, Ishaan Gulrajani, Tianyi Zhang, Yann Dubois, Xuechen Li 4 | # 5 | # Licensed under the Apache License, Version 2.0 (the "License"); 6 | # you may not use this file except in compliance with the License. 7 | # You may obtain a copy of the License at 8 | # 9 | # http://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | # See the License for the specific language governing permissions and 15 | # limitations under the License. 
16 | 17 | import logging 18 | import pathlib 19 | import torch 20 | import transformers 21 | 22 | from vary.train.trainer_vit_fixlr import varyTrainer 23 | from vary.model import * 24 | from vary.data import make_supervised_data_module 25 | from vary.utils.arguments import * 26 | from vary.utils.constants import * 27 | from vary.model.vision_encoder.sam import build_sam_vit_b 28 | 29 | def train(): 30 | parser = transformers.HfArgumentParser((ModelArguments, DataArguments, TrainingArguments)) 31 | model_args, data_args, training_args = parser.parse_args_into_dataclasses() 32 | 33 | 34 | tokenizer = transformers.AutoTokenizer.from_pretrained(model_args.model_name_or_path, use_fast=False, padding_side="right", model_max_length=training_args.model_max_length) 35 | 36 | 37 | model = varyOPTForCausalLM.from_pretrained(model_args.model_name_or_path) 38 | 39 | 40 | 41 | dtype = torch.float32 42 | if training_args.fp16: 43 | dtype = torch.float16 44 | if training_args.bf16: 45 | dtype = torch.bfloat16 46 | 47 | vision_tower_dict = model.get_model().initialize_vision_modules( 48 | vision_tower=model_args.vision_tower, 49 | pretrained_stage1_model=model_args.pretrained_stage1_model, 50 | freeze_vision_tower=model_args.freeze_vision_tower, 51 | use_im_start_end=model_args.use_im_start_end, 52 | vision_select_layer=model_args.vision_select_layer, 53 | dtype=dtype, 54 | device=training_args.device 55 | ) 56 | 57 | model.initialize_vision_tokenizer( 58 | tokenizer=tokenizer, 59 | freeze_lm_model=model_args.freeze_lm_model, 60 | pretrained_stage1_model=model_args.pretrained_stage1_model, 61 | device=training_args.device, 62 | ) 63 | 64 | 65 | 66 | model.to(dtype=dtype, device=training_args.device) 67 | 68 | data_args.image_token_len = 256 69 | data_args.image_processor = vision_tower_dict['image_processor'] 70 | data_args.image_processor_high = vision_tower_dict['image_processor_high'] 71 | data_args.use_im_start_end = model_args.use_im_start_end 72 | 73 | # mixed relation, to be 
fixed 74 | if model_args.freeze_lm_model: 75 | model.requires_grad_(False) 76 | for p in model.get_model().mm_projector.parameters(): 77 | p.requires_grad = True 78 | 79 | for p in model.get_input_embeddings().parameters(): 80 | p.requires_grad = True 81 | 82 | 83 | if not model_args.freeze_vision_tower: 84 | 85 | model.get_model().vision_tower.requires_grad_(True) 86 | 87 | 88 | params_grad = [p.numel() for n, p in model.named_parameters() if p.requires_grad] 89 | print(f"Number of Mapping Trainable Parameters: {sum(params_grad) / (1 << 20):.2f} M") 90 | 91 | # params_no_grad = [n for n, p in model.named_parameters() if not p.requires_grad] 92 | # if len(params_no_grad) > 0: 93 | # if training_args.fsdp is not None and len(training_args.fsdp) > 0: 94 | # if len(params_no_grad) < 10: 95 | # print('[WARNING] Attempting to use FSDP while {} parameters do not require gradients: {}'. format(len(params_no_grad), params_no_grad)) 96 | # else: 97 | # print('[WARNING] Attempting to use FSDP while {} parameters do not require gradients: {}...(omitted)'. format(len(params_no_grad), ', '.join(params_no_grad[:10]))) 98 | # print("[WARNING] Attempting to use FSDP with partially frozen parameters, this is experimental.") 99 | # print("[WARNING] As of 4/30/23, this feature requires PyTorch-nightly build.
See here for details: https://github.com/haotian-liu/LLaVA#experimental-use-fsdp-to-save-memory-in-pretraining") 100 | 101 | # from torch.distributed.fsdp.fully_sharded_data_parallel import FullyShardedDataParallel as FSDP 102 | # def patch_FSDP_use_orig_params(func): 103 | # def wrap_func(*args, **kwargs): 104 | # use_orig_params = kwargs.pop('use_orig_params', True) 105 | # return func(*args, **kwargs, use_orig_params=use_orig_params) 106 | # return wrap_func 107 | 108 | # FSDP.__init__ = patch_FSDP_use_orig_params(FSDP.__init__) 109 | 110 | # interleave = True 111 | data_module = make_supervised_data_module( 112 | interleave=training_args.interleave, 113 | with_box=training_args.with_box, 114 | tokenizer=tokenizer, 115 | data_args=data_args 116 | ) 117 | 118 | trainer = varyTrainer( 119 | model=model, 120 | tokenizer=tokenizer, 121 | args=training_args, 122 | **data_module) 123 | 124 | if list(pathlib.Path(training_args.output_dir).glob("checkpoint-*")): 125 | trainer.train(resume_from_checkpoint=True) 126 | else: 127 | trainer.train() 128 | trainer.save_state() 129 | trainer._safe_save(output_dir=training_args.output_dir) 130 | 131 | 132 | if __name__ == "__main__": 133 | train() 134 | 135 | -------------------------------------------------------------------------------- /Vary-master/vary/train/train_qwen_vary.py: -------------------------------------------------------------------------------- 1 | # Adopted from https://github.com/lm-sys/FastChat. Below is the original copyright: 2 | # Adopted from tatsu-lab@stanford_alpaca. Below is the original copyright: 3 | # Copyright 2023 Rohan Taori, Ishaan Gulrajani, Tianyi Zhang, Yann Dubois, Xuechen Li 4 | # 5 | # Licensed under the Apache License, Version 2.0 (the "License"); 6 | # you may not use this file except in compliance with the License. 
7 | # You may obtain a copy of the License at 8 | # 9 | # http://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | # See the License for the specific language governing permissions and 15 | # limitations under the License. 16 | 17 | import logging 18 | import pathlib 19 | import torch 20 | import transformers 21 | 22 | 23 | from vary.train.trainer_vit_fixlr import varyTrainer 24 | from vary.model import * 25 | from vary.data import make_supervised_data_module 26 | from vary.utils.arguments import * 27 | from vary.utils.utils import smart_tokenizer_and_embedding_resize 28 | from vary.model.vision_encoder.sam import build_sam_vit_b 29 | 30 | 31 | def train(): 32 | parser = transformers.HfArgumentParser((ModelArguments, DataArguments, TrainingArguments)) 33 | model_args, data_args, training_args = parser.parse_args_into_dataclasses() 34 | 35 | tokenizer = transformers.AutoTokenizer.from_pretrained(model_args.model_name_or_path, trust_remote_code=True, padding_side="right", model_max_length=training_args.model_max_length,) 36 | 37 | 38 | model = varyQwenForCausalLM.from_pretrained(model_args.model_name_or_path, low_cpu_mem_usage=True, device_map='cuda') 39 | 40 | 41 | 42 | smart_tokenizer_and_embedding_resize( 43 | special_tokens_dict=dict(pad_token='<|endoftext|>'), 44 | tokenizer=tokenizer, 45 | model=model, 46 | ) 47 | 48 | 49 | dtype = torch.float32 50 | if training_args.fp16: 51 | dtype = torch.float16 52 | if training_args.bf16: 53 | dtype = torch.bfloat16 54 | 55 | vision_tower_dict = model.get_model().initialize_vision_modules( 56 | vision_tower=model_args.vision_tower, 57 | pretrained_stage1_model=model_args.pretrained_stage1_model, 58 | freeze_vision_tower=model_args.freeze_vision_tower, 59 |
use_im_start_end=model_args.use_im_start_end, 60 | vision_select_layer=model_args.vision_select_layer, 61 | dtype=dtype, 62 | device=training_args.device 63 | ) 64 | 65 | model.initialize_vision_tokenizer( 66 | tokenizer=tokenizer, 67 | freeze_lm_model=model_args.freeze_lm_model, 68 | pretrained_stage1_model=model_args.pretrained_stage1_model, 69 | device=training_args.device, 70 | ) 71 | 72 | 73 | 74 | 75 | model.to(dtype=dtype, device=training_args.device) 76 | 77 | data_args.image_token_len = 256 78 | data_args.image_processor = vision_tower_dict['image_processor'] 79 | data_args.image_processor_high = vision_tower_dict['image_processor_high'] 80 | data_args.use_im_start_end = model_args.use_im_start_end 81 | 82 | # mixed relation, to be fixed 83 | if model_args.freeze_lm_model: 84 | model.requires_grad_(False) 85 | for p in model.get_model().mm_projector.parameters(): 86 | p.requires_grad = True 87 | for p in model.get_model().mm_projector_vary.parameters(): 88 | p.requires_grad = True 89 | for p in model.get_input_embeddings().parameters(): 90 | p.requires_grad = True 91 | 92 | 93 | if not model_args.freeze_vision_tower: 94 | model.get_model().vision_tower.requires_grad_(True) 95 | model.get_model().vision_tower_high.requires_grad_(True) 96 | 97 | 98 | params_grad = [p.numel() for n, p in model.named_parameters() if p.requires_grad] 99 | print(f"Number of Mapping Trainable Parameters: {sum(params_grad) / (1 << 20):.2f} M") 100 | 101 | # params_no_grad = [n for n, p in model.named_parameters() if not p.requires_grad] 102 | # if len(params_no_grad) > 0: 103 | # if training_args.fsdp is not None and len(training_args.fsdp) > 0: 104 | # if len(params_no_grad) < 10: 105 | # print('[WARNING] Attempting to use FSDP while {} parameters do not require gradients: {}'. format(len(params_no_grad), params_no_grad)) 106 | # else: 107 | # print('[WARNING] Attempting to use FSDP while {} parameters do not require gradients: {}...(omitted)'. 
format(len(params_no_grad), ', '.join(params_no_grad[:10]))) 108 | # print("[WARNING] Attempting to use FSDP with partially frozen parameters, this is experimental.") 109 | # print("[WARNING] As of 4/30/23, this feature requires PyTorch-nightly build. See here for details: https://github.com/haotian-liu/LLaVA#experimental-use-fsdp-to-save-memory-in-pretraining") 110 | 111 | # from torch.distributed.fsdp.fully_sharded_data_parallel import FullyShardedDataParallel as FSDP 112 | # def patch_FSDP_use_orig_params(func): 113 | # def wrap_func(*args, **kwargs): 114 | # use_orig_params = kwargs.pop('use_orig_params', True) 115 | # return func(*args, **kwargs, use_orig_params=use_orig_params) 116 | # return wrap_func 117 | 118 | # FSDP.__init__ = patch_FSDP_use_orig_params(FSDP.__init__) 119 | 120 | 121 | 122 | data_module = make_supervised_data_module( 123 | interleave=training_args.interleave, 124 | with_box=training_args.with_box, 125 | tokenizer=tokenizer, 126 | data_args=data_args 127 | ) 128 | 129 | trainer = varyTrainer( 130 | model=model, 131 | tokenizer=tokenizer, 132 | args=training_args, 133 | **data_module) 134 | 135 | if list(pathlib.Path(training_args.output_dir).glob("checkpoint-*")): 136 | trainer.train(resume_from_checkpoint=True) 137 | else: 138 | trainer.train() 139 | trainer.save_state() 140 | trainer._safe_save(output_dir=training_args.output_dir) 141 | 142 | 143 | if __name__ == "__main__": 144 | train() 145 | -------------------------------------------------------------------------------- /Vary-master/vary/train/trainer.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import torch.nn as nn 4 | 5 | from transformers import Trainer 6 | from typing import Dict, Optional, Sequence 7 | 8 | 9 | def unwrap_model(model: nn.Module) -> nn.Module: 10 | """ 11 | Recursively unwraps a model from potential containers (as used in distributed training).
12 | 13 | Args: 14 | model (`torch.nn.Module`): The model to unwrap. 15 | """ 16 | # since there could be multiple levels of wrapping, unwrap recursively 17 | if hasattr(model, "module"): 18 | return unwrap_model(model.module) 19 | else: 20 | return model 21 | 22 | 23 | class varyTrainer(Trainer): 24 | 25 | def _safe_save(self, output_dir: str): 26 | """Collects the state dict and dump to disk.""" 27 | if self.deepspeed: 28 | torch.cuda.synchronize() 29 | self.save_model(output_dir) 30 | return 31 | 32 | state_dict = self.model.state_dict() 33 | if self.args.should_save: 34 | cpu_state_dict = { 35 | key: value.cpu() 36 | for key, value in state_dict.items() 37 | } 38 | del state_dict 39 | self._save(output_dir, state_dict=cpu_state_dict) # noqa 40 | 41 | 42 | def _save(self, output_dir: Optional[str] = None, state_dict=None): 43 | if getattr(self.args, 'tune_mm_mlp_adapter', False): 44 | # Save the model 45 | _state_dict = state_dict 46 | if _state_dict is None: 47 | # Only save the model itself if we are using distributed training 48 | model_to_save = unwrap_model(self.model) 49 | _state_dict = model_to_save.state_dict() 50 | 51 | weight_to_save = {} 52 | keys_to_match = ['mm_projector', 'embed_tokens', 'embed_in'] 53 | for k, v in _state_dict.items(): 54 | if any(key_match in k for key_match in keys_to_match): 55 | weight_to_save[k] = v 56 | 57 | current_folder = output_dir.split('/')[-1] 58 | parent_folder = os.path.dirname(output_dir) 59 | if current_folder.startswith('checkpoint-'): 60 | mm_projector_folder = os.path.join(parent_folder, "mm_projector") 61 | os.makedirs(mm_projector_folder, exist_ok=True) 62 | torch.save(weight_to_save, os.path.join(mm_projector_folder, f'{current_folder}.bin')) 63 | else: 64 | torch.save(weight_to_save, os.path.join(output_dir, f'mm_projector.bin')) 65 | 66 | super(varyTrainer, self)._save(output_dir, state_dict) 67 | -------------------------------------------------------------------------------- 
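The adapter-only branch of `varyTrainer._save` does two things: it keeps only the state-dict entries whose names contain `mm_projector`, `embed_tokens` or `embed_in`, and it picks a destination file depending on whether `output_dir` is an intermediate `checkpoint-*` folder. Note that this branch only runs when `self.args.tune_mm_mlp_adapter` is truthy; as far as this dump shows, the `TrainingArguments` dataclass in `vary/utils/arguments.py` declares no such field, so `getattr` falls back to `False`. Below is a minimal pure-Python sketch of the filtering and path logic, assuming POSIX `/` separators (the original also splits on `'/'`). The function names are hypothetical, and unlike the original the sketch neither creates the folder nor calls `torch.save`.

```python
import os

# Same substrings as keys_to_match in varyTrainer._save.
KEYS_TO_MATCH = ("mm_projector", "embed_tokens", "embed_in")


def select_adapter_weights(state_dict: dict) -> dict:
    """Keep only the entries whose name contains one of KEYS_TO_MATCH."""
    return {
        k: v for k, v in state_dict.items()
        if any(key in k for key in KEYS_TO_MATCH)
    }


def adapter_save_path(output_dir: str) -> str:
    """File the adapter-only branch of _save would write to."""
    current_folder = output_dir.split("/")[-1]
    parent_folder = os.path.dirname(output_dir)
    if current_folder.startswith("checkpoint-"):
        # Intermediate checkpoints share one sibling "mm_projector" folder,
        # one file per checkpoint step.
        return os.path.join(parent_folder, "mm_projector", f"{current_folder}.bin")
    # The final save goes directly into output_dir.
    return os.path.join(output_dir, "mm_projector.bin")


if __name__ == "__main__":
    sd = {
        "model.mm_projector.weight": 0,
        "model.layers.0.self_attn.q_proj.weight": 1,
        "model.embed_tokens.weight": 2,
    }
    print(sorted(select_adapter_weights(sd)))
    print(adapter_save_path("runs/exp1/checkpoint-500"))
    print(adapter_save_path("runs/exp1"))
```

Because the match is a plain substring test, any parameter whose name merely contains one of these strings is saved as well; the sketch preserves that behavior rather than tightening it.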
/Vary-master/vary/train/trainer_vit_fixlr.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import torch.nn as nn 4 | 5 | from transformers import Trainer 6 | from transformers.trainer_pt_utils import get_parameter_names 7 | from transformers.pytorch_utils import ALL_LAYERNORM_LAYERS 8 | from typing import Dict, Optional, Sequence 9 | 10 | 11 | def unwrap_model(model: nn.Module) -> nn.Module: 12 | """ 13 | Recursively unwraps a model from potential containers (as used in distributed training). 14 | 15 | Args: 16 | model (`torch.nn.Module`): The model to unwrap. 17 | """ 18 | # since there could be multiple levels of wrapping, unwrap recursively 19 | if hasattr(model, "module"): 20 | return unwrap_model(model.module) 21 | else: 22 | return model 23 | 24 | 25 | class varyTrainer(Trainer): 26 | 27 | def _safe_save(self, output_dir: str): 28 | """Collects the state dict and dump to disk.""" 29 | state_dict = self.model.state_dict() 30 | if self.args.should_save: 31 | cpu_state_dict = { 32 | key: value.cpu() 33 | for key, value in state_dict.items() 34 | } 35 | del state_dict 36 | self._save(output_dir, state_dict=cpu_state_dict) # noqa 37 | 38 | 39 | def _save(self, output_dir: Optional[str] = None, state_dict=None): 40 | if getattr(self.args, 'tune_mm_mlp_adapter', False): 41 | # Save the model 42 | _state_dict = state_dict 43 | if _state_dict is None: 44 | # Only save the model itself if we are using distributed training 45 | model_to_save = unwrap_model(self.model) 46 | _state_dict = model_to_save.state_dict() 47 | 48 | weight_to_save = {} 49 | keys_to_match = ['mm_projector', 'embed_tokens', 'embed_in'] 50 | for k, v in _state_dict.items(): 51 | if any(key_match in k for key_match in keys_to_match): 52 | weight_to_save[k] = v 53 | 54 | current_folder = output_dir.split('/')[-1] 55 | parent_folder = os.path.dirname(output_dir) 56 | if current_folder.startswith('checkpoint-'): 57 | 
                mm_projector_folder = os.path.join(parent_folder, "mm_projector")
                os.makedirs(mm_projector_folder, exist_ok=True)
                torch.save(weight_to_save, os.path.join(mm_projector_folder, f'{current_folder}.bin'))
            else:
                torch.save(weight_to_save, os.path.join(output_dir, f'mm_projector.bin'))

        super(varyTrainer, self)._save(output_dir, state_dict)

    def create_optimizer(self):
        """
        Setup the optimizer.

        We provide a reasonable default that works well. If you want to use something else, you can pass a tuple in the
        Trainer's init through `optimizers`, or subclass and override this method in a subclass.
        """
        opt_model = self.model

        if self.optimizer is None:
            decay_parameters = get_parameter_names(opt_model, ALL_LAYERNORM_LAYERS)
            decay_parameters = [name for name in decay_parameters if "bias" not in name]
            optimizer_grouped_parameters = [
                {
                    "params": [
                        p for n, p in opt_model.named_parameters() if 'vision_encoder' in n and n in decay_parameters and p.requires_grad
                    ],
                    "weight_decay": self.args.weight_decay,
                    "lr": self.args.learning_rate,
                },
                {
                    "params": [
                        p for n, p in opt_model.named_parameters() if 'vision_encoder' in n and n not in decay_parameters and p.requires_grad
                    ],
                    "weight_decay": 0.0,
                    "lr": self.args.learning_rate,
                },
                {
                    "params": [
                        p for n, p in opt_model.named_parameters() if 'vision_encoder' not in n and n in decay_parameters and p.requires_grad
                    ],
                    "weight_decay": self.args.weight_decay,
                    "lr": self.args.learning_rate,
                },
                {
                    "params": [
                        p for n, p in opt_model.named_parameters() if 'vision_encoder' not in n and n not in decay_parameters and p.requires_grad
                    ],
                    "weight_decay": 0.0,
                    "lr": self.args.learning_rate,
                },
            ]
            for idx, group in enumerate(optimizer_grouped_parameters):
                print(idx, len(group['params']), group['lr'])
            optimizer_cls, optimizer_kwargs = Trainer.get_optimizer_cls_and_kwargs(self.args)
            self.optimizer = optimizer_cls(optimizer_grouped_parameters, **optimizer_kwargs)

        return self.optimizer
--------------------------------------------------------------------------------
/Vary-master/vary/utils/arguments.py:
--------------------------------------------------------------------------------
from dataclasses import dataclass, field
from typing import Dict, Optional, Sequence
import transformers


@dataclass
class ModelArguments:
    model_name_or_path: Optional[str] = field(default="facebook/opt-125m")
    use_cache: bool = field(default=False)
    vision_tower: Optional[str] = field(default="~/.cache/huggingface/hub/models--openai--clip-vit-large-patch14/snapshots/8d052a0f05efbaefbc9e8786ba291cfdf93e5bff/")
    freeze_vision_tower: bool = field(default=False)
    freeze_lm_model: bool = field(default=False)
    pretrained_stage1_model: Optional[str] = field(default=None)  # mlp &/ vision tower
    vision_select_layer: Optional[int] = field(default=-1)  # default to the last layer
    use_im_start_end: bool = field(default=False)


@dataclass
class DataArguments:
    datasets: str = field(default=None, metadata={"help": "combinations of the training data."})
    sep_image_conv_front: bool = False
    image_token_len: int = 256
    image_aspect_ratio: str = 'square'
    conversation_version: str = 'mpt'
    # conversation_version: str = 'v0'
    # conversation_version: str = 'v1'
    # conversation_version: str = 'opt'
    box_limit: int = 0


@dataclass
class TrainingArguments(transformers.TrainingArguments):
    cache_dir: Optional[str] = field(default=None)
    optim: str = field(default="adamw_torch")
    remove_unused_columns: bool = field(default=False)
    force_fsdp: bool = field(default=False)
    interleave: bool = field(default=False)
    with_box: bool = field(default=False)
    model_max_length: int = field(
        default=512,
        metadata={
            "help":
            "Maximum sequence length. Sequences will be right padded (and possibly truncated)."
        },
    )
    lora_enable: bool = False
    lora_r: int = 8
    lora_alpha: int = 16
    lora_dropout: float = 0.05
    lora_weight_path: str = ""
    lora_bias: str = "none"
--------------------------------------------------------------------------------
/Vary-master/vary/utils/constants.py:
--------------------------------------------------------------------------------
CONTROLLER_HEART_BEAT_EXPIRATION = 30
WORKER_HEART_BEAT_INTERVAL = 15

LOGDIR = "log"

IGNORE_INDEX = -100
# DEFAULT_PAD_TOKEN = "[PAD]"

DEFAULT_PAD_TOKEN = "<|endoftext|>"
DEFAULT_EOS_TOKEN = ""
DEFAULT_BOS_TOKEN = ""
DEFAULT_UNK_TOKEN = ""
DEFAULT_IMAGE_TOKEN = ""
DEFAULT_BOX_TOKEN = ""

DEFAULT_IMAGE_PATCH_TOKEN = ''

DEFAULT_IM_START_TOKEN = ''
DEFAULT_IM_END_TOKEN = ''


ROOT_PATH = '/data/public/ucaswei/data/'

CONVERSATION_DATA = {

    # pair 4m
    'laion-coco-4m': {
        'images': '',
        'annotations': '',
    },

    'cc665k': {
        'images': "/path_to/LLaVA1.5/images/",
        'annotations': "/path_to/LLaVA1.5/llava_v1_5_66k.json",
    },

    'pdf': {
        'images': "",
        'annotations': "",
    },

    'docvqa_train': {
        'images': "",
        'annotations': "",
    },

    'chartqa_train': {
        'images': "",
        'annotations': "",
    },

}
--------------------------------------------------------------------------------
/Vary-master/vary/utils/conversation.py:
--------------------------------------------------------------------------------
import dataclasses
from enum import auto, Enum
from typing import List, Tuple


class SeparatorStyle(Enum):
    """Different separator style."""
    SINGLE = auto()
    TWO = auto()
    MPT = auto()


@dataclasses.dataclass
class Conversation:
    """A class that keeps all conversation history."""
    system: str
    roles: List[str]
    messages: List[List[str]]
    offset: int
    sep_style: SeparatorStyle = SeparatorStyle.SINGLE
    sep: str = "<|im_end|>"
    sep2: str = None
    version: str = "Unknown"

    skip_next: bool = False

    def get_prompt(self):
        if self.sep_style == SeparatorStyle.SINGLE:
            ret = self.system + self.sep + '\n'
            for role, message in self.messages:
                if message:
                    if type(message) is tuple:
                        message, _, _ = message
                    ret += role + ": " + message + self.sep
                else:
                    ret += role + ":"
            return ret
        elif self.sep_style == SeparatorStyle.TWO:
            seps = [self.sep, self.sep2]
            ret = self.system + seps[0]
            for i, (role, message) in enumerate(self.messages):
                if message:
                    if type(message) is tuple:
                        message, _, _ = message
                    ret += role + ": " + message + seps[i % 2]
                else:
                    ret += role + ":"
            return ret
        elif self.sep_style == SeparatorStyle.MPT:
            if self.system:
                ret = self.system + self.sep
            else:
                ret = ''
            for role, message in self.messages:
                if message:
                    if type(message) is tuple:
                        message, _, _ = message
                    ret += role + message + self.sep
                else:
                    ret += role
            return ret
        else:
            raise ValueError(f"Invalid style: {self.sep_style}")

    def append_message(self, role, message):
        self.messages.append([role, message])

    def get_images(self, return_pil=False):
        images = []
        for i, (role, msg) in enumerate(self.messages[self.offset:]):
            if i % 2 == 0:
                if type(msg) is tuple:
                    import base64
                    from io import BytesIO
                    from PIL import Image
                    msg, image, image_process_mode = msg
                    if image_process_mode == "Pad":
                        def expand2square(pil_img, background_color=(122, 116, 104)):
                            width, height = pil_img.size
                            if width == height:
                                return pil_img
                            elif width > height:
                                result = Image.new(pil_img.mode, (width, width), background_color)
                                # result.paste(pil_img, (0, (width - height) // 2))
                                result.paste(pil_img)
                                return result
                            else:
                                result = Image.new(pil_img.mode, (height, height), background_color)
                                # result.paste(pil_img, ((height - width) // 2, 0))
                                result.paste(pil_img)
                                return result
                        image = expand2square(image)
                    elif image_process_mode == "Crop":
                        max_hw, min_hw = max(image.size), min(image.size)
                        aspect_ratio = max_hw / min_hw
                        max_len, min_len = 800, 400
                        shortest_edge = int(min(max_len / aspect_ratio, min_len, min_hw))
                        longest_edge = int(shortest_edge * aspect_ratio)
                        W, H = image.size
                        if H > W:
                            H, W = longest_edge, shortest_edge
                        else:
                            H, W = shortest_edge, longest_edge
                        image = image.resize((W, H))
                    elif image_process_mode == "Resize":
                        image = image.resize((224, 224))
                    else:
                        raise ValueError(f"Invalid image_process_mode: {image_process_mode}")

                    if return_pil:
                        images.append(image)
                    else:
                        buffered = BytesIO()
                        image.convert('RGB').save(buffered, format="JPEG")
                        img_b64_str = base64.b64encode(buffered.getvalue()).decode()
                        images.append(img_b64_str)
        return images

    def to_gradio_chatbot(self):
        ret = []
        for i, (role, msg) in enumerate(self.messages[self.offset:]):
            if i % 2 == 0:
                if type(msg) is tuple:
                    import base64
                    from io import BytesIO
                    msg, image, image_process_mode = msg
                    max_hw, min_hw = max(image.size), min(image.size)
                    aspect_ratio = max_hw / min_hw
                    max_len, min_len = 800, 400
                    shortest_edge = int(min(max_len / aspect_ratio, min_len, min_hw))
                    longest_edge = int(shortest_edge * aspect_ratio)
                    W, H = image.size
                    if H > W:
                        H, W = longest_edge, shortest_edge
                    else:
                        H, W = shortest_edge, longest_edge
                    image = image.resize((W, H))
                    # image = image.resize((224, 224))
                    buffered = BytesIO()
                    image.save(buffered, format="JPEG")
                    img_b64_str = base64.b64encode(buffered.getvalue()).decode()
                    img_str = f'user upload image'
                    msg = msg.replace('', img_str)
                ret.append([msg, None])
            else:
                ret[-1][-1] = msg
        return ret

    def copy(self):
        return Conversation(
            system=self.system,
            roles=self.roles,
            messages=[[x, y] for x, y in self.messages],
            offset=self.offset,
            sep_style=self.sep_style,
            sep=self.sep,
            sep2=self.sep2)

    def dict(self):
        if len(self.get_images()) > 0:
            return {
                "system": self.system,
                "roles": self.roles,
                "messages": [[x, y[0] if type(y) is tuple else y] for x, y in self.messages],
                "offset": self.offset,
                "sep": self.sep,
                "sep2": self.sep2,
            }
        return {
            "system": self.system,
            "roles": self.roles,
            "messages": self.messages,
            "offset": self.offset,
            "sep": self.sep,
            "sep2": self.sep2,
        }


conv_v1 = Conversation(
    system="A chat between a curious human and an artificial intelligence assistant. "
           "The assistant gives helpful, detailed, and polite answers to the human's questions.",
    roles=("Human", "Assistant"),
    messages=(
        ("Human", "Give three tips for staying healthy."),
        ("Assistant",
            "Sure, here are three tips for staying healthy:\n"
            "1. Exercise regularly: Regular physical activity can help improve your overall health and wellbeing. "
            "It can also help reduce your risk of chronic conditions such as obesity, diabetes, heart disease, "
            "and certain cancers. Aim for at least 150 minutes of moderate-intensity aerobic exercise or "
            "75 minutes of vigorous-intensity aerobic exercise per week, along with muscle-strengthening "
            "activities at least two days per week.\n"
            "2. Eat a balanced diet: Eating a balanced diet that is rich in fruits, "
            "vegetables, whole grains, lean proteins, and healthy fats can help support "
            "your overall health. Try to limit your intake of processed and high-sugar foods, "
            "and aim to drink plenty of water throughout the day.\n"
            "3. Get enough sleep: Getting enough quality sleep is essential for your physical "
            "and mental health. Adults should aim for seven to nine hours of sleep per night. "
            "Establish a regular sleep schedule and try to create a relaxing bedtime routine to "
            "help improve the quality of your sleep.")
    ),
    offset=2,
    sep_style=SeparatorStyle.SINGLE,
    sep="###",
)

conv_v1_2 = Conversation(
    system="A chat between a curious human and an artificial intelligence assistant. "
           "The assistant gives helpful, detailed, and polite answers to the human's questions.",
    roles=("Human", "Assistant"),
    messages=(
        ("Human", "What are the key differences between renewable and non-renewable energy sources?"),
        ("Assistant",
            "Renewable energy sources are those that can be replenished naturally in a relatively "
            "short amount of time, such as solar, wind, hydro, geothermal, and biomass. "
            "Non-renewable energy sources, on the other hand, are finite and will eventually be "
            "depleted, such as coal, oil, and natural gas. Here are some key differences between "
            "renewable and non-renewable energy sources:\n"
            "1. Availability: Renewable energy sources are virtually inexhaustible, while non-renewable "
            "energy sources are finite and will eventually run out.\n"
            "2. Environmental impact: Renewable energy sources have a much lower environmental impact "
            "than non-renewable sources, which can lead to air and water pollution, greenhouse gas emissions, "
            "and other negative effects.\n"
            "3. Cost: Renewable energy sources can be more expensive to initially set up, but they typically "
            "have lower operational costs than non-renewable sources.\n"
            "4. Reliability: Renewable energy sources are often more reliable and can be used in more remote "
            "locations than non-renewable sources.\n"
            "5. Flexibility: Renewable energy sources are often more flexible and can be adapted to different "
            "situations and needs, while non-renewable sources are more rigid and inflexible.\n"
            "6. Sustainability: Renewable energy sources are more sustainable over the long term, while "
            "non-renewable sources are not, and their depletion can lead to economic and social instability.\n")
    ),
    offset=2,
    sep_style=SeparatorStyle.SINGLE,
    sep="###",
)

conv_vicuna_v1_1 = Conversation(
    system="A chat between a curious user and an artificial intelligence assistant. "
           "The assistant gives helpful, detailed, and polite answers to the user's questions.",
    roles=("USER", "ASSISTANT"),
    version="v1",
    messages=(),
    offset=0,
    sep_style=SeparatorStyle.TWO,
    sep=" ",
    sep2="",
)


conv_mpt = Conversation(
    system="""<|im_start|>system
You should follow the instructions carefully and explain your answers in detail.""",
    # system = None,
    roles=("<|im_start|>user\n", "<|im_start|>assistant\n"),
    version="mpt",
    messages=(),
    offset=0,
    sep_style=SeparatorStyle.MPT,
    sep="<|im_end|>",
)

conv_mpt_eval = Conversation(
    system="",
    roles=("<|im_start|>user\n", "<|im_start|>assistant\n"),
    version="mpt",
    messages=(),
    offset=0,
    sep_style=SeparatorStyle.MPT,
    sep="<|im_end|>",
)

conv_mpt_text = Conversation(
    system="""<|im_start|>system
- You are a helpful assistant chatbot trained by MosaicML.
- You answer questions.
- You are excited to be able to help the user, but will refuse to do anything that could be considered harmful to the user.
- You are more than just an information source, you are also able to write poetry, short stories, and make jokes.""",
    roles=("<|im_start|>user\n", "<|im_start|>assistant\n"),
    version="mpt",
    messages=(),
    offset=0,
    sep_style=SeparatorStyle.MPT,
    sep="<|im_end|>",
)

conv_bair_v1 = Conversation(
    system="BEGINNING OF CONVERSATION:",
    roles=("USER", "GPT"),
    messages=(),
    offset=0,
    sep_style=SeparatorStyle.TWO,
    sep=" ",
    sep2="",
)


simple_conv = Conversation(
    system="",
    roles=("Human", "Assistant"),
    messages=(
    ),
    offset=0,
    sep_style=SeparatorStyle.SINGLE,
    sep="###",
)

simple_conv_multimodal = Conversation(
    system="You are vary, a large language and vision assistant trained by Foundation Model Group, Megvii Technology."
           "You are able to understand the visual content that the user provides, and assist the user with a variety of tasks using natural language."
           "Follow the instructions carefully and explain your answers in detail.",
    # system="",
    roles=("Human", "Assistant"),
    messages=(
        ("Human", "Hi!"),
        ("Assistant", "Hi there! How can I help you today?\n")
    ),
    offset=2,
    sep_style=SeparatorStyle.SINGLE,
    sep="###",
)

simple_conv_mpt_multimodal = Conversation(
    system="""<|im_start|>system
- You are vary, a large language and vision assistant trained by Foundation Model Group, Megvii Technology.
- You are able to understand the visual content that the user provides, and assist the user with a variety of tasks using natural language.
- You should follow the instructions carefully and explain your answers in detail.""",
    roles=("<|im_start|>user\n", "<|im_start|>assistant\n"),
    version="mpt",
    messages=(),
    offset=0,
    sep_style=SeparatorStyle.MPT,
    sep="<|im_end|>",
)

simple_conv_legacy = Conversation(
    system="You are vary, a large language model trained by Foundation Model Group, Megvii Technology."
           "You are designed to assist human with a variety of tasks using natural language."
           "Follow the instructions carefully.",
    roles=("Human", "Assistant"),
    messages=(
        ("Human", "Hi!\n\n### Response:"),
        ("Assistant", "Hi there! How can I help you today?\n")
    ),
    offset=2,
    sep_style=SeparatorStyle.SINGLE,
    sep="###",
)

conv_llava_v1 = Conversation(
    system="You are vary, a large language and vision assistant trained by Foundation Model Group, Megvii Technology."
           "You are able to understand the visual content that the user provides, and assist the user with a variety of tasks using natural language."
           "Follow the instructions carefully and explain your answers in detail.",
    roles=("USER", "ASSISTANT"),
    version="v1",
    messages=(),
    offset=0,
    sep_style=SeparatorStyle.TWO,
    sep=" ",
    sep2="",
)

default_conversation = conv_mpt
conv_templates = {
    "default": simple_conv_multimodal,
    "simple": simple_conv,
    "simple_legacy": simple_conv_legacy,
    "multimodal": simple_conv,
    "mpt_multimodal": simple_conv_mpt_multimodal,
    "llava_v1": conv_llava_v1,
    "mpt_eval": conv_mpt_eval,
    # fastchat
    "v1": conv_vicuna_v1_1,
    "baichuan": conv_vicuna_v1_1,
    "bair_v1": conv_bair_v1,
    "vicuna_v1_1": conv_vicuna_v1_1,
    "mpt": conv_mpt,
    "qwen": conv_mpt,
    "mpt_text": conv_mpt_text,
}


if __name__ == "__main__":
    print(default_conversation.get_prompt())

--------------------------------------------------------------------------------
/Vary-master/vary/utils/llama_flash_attn_monkey_patch.py:
--------------------------------------------------------------------------------
# Adopted from https://github.com/lm-sys/FastChat. Below is the original copyright:
from typing import List, Optional, Tuple

import torch
from torch import nn

import transformers
from transformers.models.llama.modeling_llama import apply_rotary_pos_emb

from einops import rearrange

# flash-attn 1.x API; flash-attn 2.x renamed this function to flash_attn_varlen_qkvpacked_func.
from flash_attn.flash_attn_interface import flash_attn_unpadded_qkvpacked_func
from flash_attn.bert_padding import unpad_input, pad_input


def forward(
    self,
    hidden_states: torch.Tensor,
    past_key_value: Optional[Tuple[torch.Tensor]] = None,
    attention_mask: Optional[torch.Tensor] = None,
    output_attentions: bool = False,
    use_cache: bool = False,
) -> Tuple[torch.Tensor, Optional[torch.Tensor],
           Optional[Tuple[torch.Tensor]]]:
    """Input shape: Batch x Time x Channel

    attention_mask: [bsz, q_len]
    """
    bsz, q_len, _ = hidden_states.size()

    query_states = self.q_proj(hidden_states).view(
        bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
    key_states = self.k_proj(hidden_states).view(
        bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
    value_states = self.v_proj(hidden_states).view(
        bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
    # [bsz, q_len, nh, hd]
    # [bsz, nh, q_len, hd]

    kv_seq_len = key_states.shape[-2]
    offset = 0
    if past_key_value is not None:
        offset = past_key_value[0].shape[-2]
        kv_seq_len += offset
    cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
    query_states, key_states = apply_rotary_pos_emb(query_states,
                                                    key_states,
                                                    cos,
                                                    sin,
                                                    offset=offset)
    # [bsz, nh, t, hd]
    assert not output_attentions, "output_attentions is not supported"
    assert not use_cache, "use_cache is not supported"
    assert past_key_value is None, "past_key_value is not supported"

    # Flash attention codes from
    # https://github.com/HazyResearch/flash-attention/blob/main/flash_attn/flash_attention.py

    # transform the data into the format required by flash attention
    qkv = torch.stack([query_states, key_states, value_states], dim=2)  # [bsz, nh, 3, q_len, hd]
    qkv = qkv.transpose(1, 3)  # [bsz, q_len, 3, nh, hd]
    # We have disabled _prepare_decoder_attention_mask in LlamaModel
    # the attention_mask should be the same as the key_padding_mask
    key_padding_mask = attention_mask

    if key_padding_mask is None:
        qkv = rearrange(qkv, 'b s ... -> (b s) ...')
        max_s = q_len
        cu_q_lens = torch.arange(0, (bsz + 1) * q_len, step=q_len, dtype=torch.int32,
                                 device=qkv.device)
        output = flash_attn_unpadded_qkvpacked_func(
            qkv, cu_q_lens, max_s, 0.0,
            softmax_scale=None, causal=True
        )
        output = rearrange(output, '(b s) ... -> b s ...', b=bsz)
    else:
        nheads = qkv.shape[-2]
        x = rearrange(qkv, 'b s three h d -> b s (three h d)')
        x_unpad, indices, cu_q_lens, max_s = unpad_input(x, key_padding_mask)
        x_unpad = rearrange(x_unpad, 'nnz (three h d) -> nnz three h d', three=3, h=nheads)
        output_unpad = flash_attn_unpadded_qkvpacked_func(
            x_unpad, cu_q_lens, max_s, 0.0,
            softmax_scale=None, causal=True
        )
        output = rearrange(pad_input(rearrange(output_unpad, 'nnz h d -> nnz (h d)'),
                                     indices, bsz, q_len),
                           'b s (h d) -> b s h d', h=nheads)
    return self.o_proj(rearrange(output,
                                 'b s h d -> b s (h d)')), None, None


# Disable the transformation of the attention mask in LlamaModel as the flash attention
# requires the attention mask to be the same as the key_padding_mask
def _prepare_decoder_attention_mask(self, attention_mask, input_shape,
                                    inputs_embeds, past_key_values_length):
    # [bsz, seq_len]
    return attention_mask


def replace_llama_attn_with_flash_attn():
    transformers.models.llama.modeling_llama.LlamaModel._prepare_decoder_attention_mask = _prepare_decoder_attention_mask
    transformers.models.llama.modeling_llama.LlamaAttention.forward = forward
--------------------------------------------------------------------------------
/Vary-master/vary/utils/utils.py:
--------------------------------------------------------------------------------
import datetime
import json
import logging
import logging.handlers
import os
import sys
import torch
import requests

from transformers import StoppingCriteria
from vary.utils.constants import LOGDIR

server_error_msg = "**NETWORK ERROR DUE TO HIGH TRAFFIC. PLEASE REGENERATE OR REFRESH THIS PAGE.**"
moderation_msg = "YOUR INPUT VIOLATES OUR CONTENT MODERATION GUIDELINES. PLEASE TRY AGAIN."

handler = None


def build_logger(logger_name, logger_filename):
    global handler

    formatter = logging.Formatter(
        fmt="%(asctime)s | %(levelname)s | %(name)s | %(message)s",
        datefmt="%Y-%m-%d %H:%M:%S",
    )

    # Set the format of root handlers
    if not logging.getLogger().handlers:
        logging.basicConfig(level=logging.INFO)
    logging.getLogger().handlers[0].setFormatter(formatter)

    # Redirect stdout and stderr to loggers
    stdout_logger = logging.getLogger("stdout")
    stdout_logger.setLevel(logging.INFO)
    sl = StreamToLogger(stdout_logger, logging.INFO)
    sys.stdout = sl

    stderr_logger = logging.getLogger("stderr")
    stderr_logger.setLevel(logging.ERROR)
    sl = StreamToLogger(stderr_logger, logging.ERROR)
    sys.stderr = sl

    # Get logger
    logger = logging.getLogger(logger_name)
    logger.setLevel(logging.INFO)

    # Add a file handler for all loggers
    if handler is None:
        os.makedirs(LOGDIR, exist_ok=True)
        filename = os.path.join(LOGDIR, logger_filename)
        handler = logging.handlers.TimedRotatingFileHandler(
            filename, when='D', utc=True)
        handler.setFormatter(formatter)

        for name, item in logging.root.manager.loggerDict.items():
            if isinstance(item, logging.Logger):
                item.addHandler(handler)

    return logger


class StreamToLogger(object):
    """
    Fake file-like stream object that redirects writes to a logger instance.
    """
    def __init__(self, logger, log_level=logging.INFO):
        self.terminal = sys.stdout
        self.logger = logger
        self.log_level = log_level
        self.linebuf = ''

    def __getattr__(self, attr):
        return getattr(self.terminal, attr)

    def write(self, buf):
        temp_linebuf = self.linebuf + buf
        self.linebuf = ''
        for line in temp_linebuf.splitlines(True):
            # From the io.TextIOWrapper docs:
            #   On output, if newline is None, any '\n' characters written
            #   are translated to the system default line separator.
            # By default sys.stdout.write() expects '\n' newlines and then
            # translates them so this is still cross platform.
            if line[-1] == '\n':
                self.logger.log(self.log_level, line.rstrip())
            else:
                self.linebuf += line

    def flush(self):
        if self.linebuf != '':
            self.logger.log(self.log_level, self.linebuf.rstrip())
        self.linebuf = ''


def disable_torch_init():
    """
    Disable the redundant torch default initialization to accelerate model creation.
    """
    import torch
    setattr(torch.nn.Linear, "reset_parameters", lambda self: None)
    setattr(torch.nn.LayerNorm, "reset_parameters", lambda self: None)


def violates_moderation(text):
    """
    Check whether the text violates OpenAI moderation API.
    """
    url = "https://api.openai.com/v1/moderations"
    headers = {"Content-Type": "application/json",
               "Authorization": "Bearer " + os.environ["OPENAI_API_KEY"]}
    text = text.replace("\n", "")
    # Serialize with json.dumps so that quotes and backslashes in `text` are escaped correctly.
    data = json.dumps({"input": text}).encode("utf-8")
    try:
        ret = requests.post(url, headers=headers, data=data, timeout=5)
        flagged = ret.json()["results"][0]["flagged"]
    except requests.exceptions.RequestException as e:
        flagged = False
    except KeyError as e:
        flagged = False

    return flagged


def pretty_print_semaphore(semaphore):
    if semaphore is None:
        return "None"
    return f"Semaphore(value={semaphore._value}, locked={semaphore.locked()})"


class KeywordsStoppingCriteria(StoppingCriteria):
    def __init__(self, keywords, tokenizer, input_ids):
        self.keywords = keywords
        self.keyword_ids = [tokenizer(keyword).input_ids for keyword in keywords]
        self.keyword_ids = [keyword_id[0] for keyword_id in self.keyword_ids if type(keyword_id) is list and len(keyword_id) == 1]
        self.tokenizer = tokenizer
        self.start_len = None
        self.input_ids = input_ids

    def __call__(self, output_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
        if self.start_len is None:
            self.start_len = self.input_ids.shape[1]
        else:
            for keyword_id in self.keyword_ids:
                if output_ids[0, -1] == keyword_id:
                    return True
            outputs = self.tokenizer.batch_decode(output_ids[:, self.start_len:], skip_special_tokens=True)[0]
            for keyword in self.keywords:
                if keyword in outputs:
                    return True
        return False


def smart_tokenizer_and_embedding_resize(special_tokens_dict, tokenizer, model):
    """Resize tokenizer and embedding.

    Note: This is the unoptimized version that may make your embedding size not be divisible by 64.
    """
    # num_new_tokens = tokenizer.add_special_tokens(special_tokens_dict)
    # # num_new_tokens = 1
    # # tokenizer.add_tokens(special_tokens_dict, special_tokens=True)
    # model.resize_token_embeddings(len(tokenizer))

    num_new_tokens = tokenizer.add_special_tokens(special_tokens_dict)
    model.resize_token_embeddings(len(tokenizer))

    if num_new_tokens > 0:
        input_embeddings = model.get_input_embeddings().weight.data
        output_embeddings = model.get_output_embeddings().weight.data

        input_embeddings_avg = input_embeddings[:-num_new_tokens].mean(
            dim=0, keepdim=True)
        output_embeddings_avg = output_embeddings[:-num_new_tokens].mean(
            dim=0, keepdim=True)

        input_embeddings[-num_new_tokens:] = input_embeddings_avg
        output_embeddings[-num_new_tokens:] = output_embeddings_avg


def maybe_zero_3(param, ignore_status=False, name=None):
    from deepspeed import zero
    from deepspeed.runtime.zero.partition_parameters import ZeroParamStatus
    if hasattr(param, "ds_id"):
        if param.ds_status == ZeroParamStatus.NOT_AVAILABLE:
            if not ignore_status:
                logging.warning(f"{name}: param.ds_status != ZeroParamStatus.NOT_AVAILABLE: {param.ds_status}")
        with zero.GatheredParameters([param]):
            param = param.data.detach().cpu().clone()
    else:
        param = param.detach().cpu().clone()
    return param


# Borrowed from peft.utils.get_peft_model_state_dict
def get_peft_state_maybe_zero_3(named_params, bias):
    if bias == "none":
        to_return = {k: t for k, t in named_params if "lora_" in k}
    elif bias == "all":
        to_return = {k: t for k, t in named_params if "lora_" in k or "bias" in k}
    elif bias == "lora_only":
        to_return = {}
        maybe_lora_bias = {}
        lora_bias_names = set()
        for k, t in named_params:
            if "lora_" in k:
                to_return[k] = t
                bias_name = k.split("lora_")[0] + "bias"
                lora_bias_names.add(bias_name)
            elif "bias" in k:
                maybe_lora_bias[k] = t
        # Keep only the biases that belong to LoRA-adapted modules.
        for k, t in maybe_lora_bias.items():
            if k in lora_bias_names:
                to_return[k] = t
    else:
        raise NotImplementedError
    to_return = {k: maybe_zero_3(v, name=k) for k, v in to_return.items()}
    return to_return


def get_peft_state_non_lora_maybe_zero_3(named_params, require_grad_only=True):
    to_return = {k: t for k, t in named_params if "lora_" not in k}
    if require_grad_only:
        to_return = {k: t for k, t in to_return.items() if t.requires_grad}
    to_return = {k: maybe_zero_3(v, ignore_status=True).cpu() for k, v in to_return.items()}
    return to_return


def find_all_linear_names(model):
    cls = torch.nn.Linear
    # Linear layers in these sub-modules are never wrapped with LoRA.
    excluded = ('vision_model', 'mm_projector', 'vision_encoder', 'conv_final', 'lm_head')
    lora_module_names = set()
    for name, module in model.named_modules():
        if isinstance(module, cls) and not any(key in name for key in excluded):
            lora_module_names.add(name)

    print(lora_module_names)
    return list(lora_module_names)
--------------------------------------------------------------------------------
/Vary-master/zero_config/zero2.json:
--------------------------------------------------------------------------------
{
    "bf16": {
        "enabled": true
    },
    "train_micro_batch_size_per_gpu": "auto",
    "zero_optimization": {
        "stage": 2,
        "overlap_comm": true,
        "contiguous_gradients": true,
        "sub_group_size": 1e9,
        "reduce_bucket_size": "auto"
    }
}
--------------------------------------------------------------------------------
/assets/vary-toy-logo.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Ucas-HaoranWei/Vary-toy/c6e405977aa6e88d4807a6b79f394299623fc6f3/assets/vary-toy-logo.jpg
--------------------------------------------------------------------------------
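`varyTrainer.create_optimizer` in `trainer.py` splits the trainable parameters into four groups: parameters whose name contains `vision_encoder` versus everything else, crossed with decay versus no-decay (names outside `get_parameter_names(..., ALL_LAYERNORM_LAYERS)` and names containing `bias` get `weight_decay=0.0`). Frozen parameters (`requires_grad=False`) land in no group. The sketch below reproduces only that bucketing on plain `(name, requires_grad)` pairs, without PyTorch. The helper `group_parameter_names`, the example names, and the substring rule standing in for the LayerNorm check are all hypothetical.

```python
from typing import Dict, List, Sequence, Tuple


def group_parameter_names(
    named_params: Sequence[Tuple[str, bool]],
    decay_names: Sequence[str],
    weight_decay: float,
    lr: float,
) -> List[Dict]:
    """Bucket trainable parameter names the way create_optimizer does (hypothetical helper)."""
    decay = set(decay_names)
    groups = []
    # Same order as the original: (vision, decay), (vision, no decay), (other, decay), (other, no decay).
    for in_vision in (True, False):
        for use_decay in (True, False):
            names = [
                n for n, requires_grad in named_params
                if requires_grad
                and (('vision_encoder' in n) == in_vision)
                and ((n in decay) == use_decay)
            ]
            groups.append({
                "params": names,
                "weight_decay": weight_decay if use_decay else 0.0,
                "lr": lr,
            })
    return groups


params = [
    ("vision_encoder.blocks.0.attn.qkv.weight", True),
    ("vision_encoder.blocks.0.attn.qkv.bias", True),
    ("model.layers.0.mlp.up_proj.weight", True),
    ("model.norm.weight", True),
    ("model.embed_tokens.weight", False),  # frozen: excluded from every group
]
# Hypothetical stand-in for get_parameter_names(..., ALL_LAYERNORM_LAYERS) minus "bias" names.
decay_names = [n for n, _ in params if "bias" not in n and "norm" not in n]

for idx, group in enumerate(group_parameter_names(params, decay_names, 0.01, 2e-5)):
    print(idx, len(group["params"]), group["weight_decay"])
```

Keeping the vision encoder in separate groups presumably makes it easy to give it its own learning rate, although in `trainer.py` all four groups use `self.args.learning_rate`.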
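For the `mpt` template (also registered as `qwen` in `conv_templates`), the `SeparatorStyle.MPT` branch of `Conversation.get_prompt` in `conversation.py` emits the system block followed by `<|im_end|>`, then each role prefix directly followed by its message and `<|im_end|>`, with no newline between turns. A message of `None` leaves the bare role prefix open for the model to complete. The standalone function below is a hypothetical re-implementation of that branch only, handy for checking what the model actually sees; unlike the original, it does not unpack `(text, image, mode)` tuple messages.

```python
from typing import List, Optional, Tuple

IM_END = "<|im_end|>"
USER, ASSISTANT = "<|im_start|>user\n", "<|im_start|>assistant\n"


def mpt_prompt(system: str, messages: List[Tuple[str, Optional[str]]], sep: str = IM_END) -> str:
    """Mirror of the SeparatorStyle.MPT branch of Conversation.get_prompt (hypothetical helper)."""
    ret = system + sep if system else ""
    for role, message in messages:
        if message:
            ret += role + message + sep
        else:
            # An empty slot leaves the role prefix open for generation.
            ret += role
    return ret


system = "<|im_start|>system\nYou should follow the instructions carefully and explain your answers in detail."
prompt = mpt_prompt(system, [(USER, "Describe the image."), (ASSISTANT, None)])
print(repr(prompt))
```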
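The flash-attention path in `llama_flash_attn_monkey_patch.py` packs the whole batch into one token stream and marks sequence boundaries with cumulative sequence lengths (`cu_q_lens`). Without a padding mask every row contributes `q_len` tokens, hence `torch.arange(0, (bsz + 1) * q_len, step=q_len)`. With a mask, `unpad_input` drops padded positions and returns the cumulative offsets together with the longest remaining sequence (`max_s`). The pure-Python sketch below computes those two quantities from a 0/1 mask so the layout is easy to inspect; the function names are hypothetical and this is not the `flash_attn.bert_padding` implementation.

```python
from itertools import accumulate
from typing import List, Tuple


def cu_seqlens_from_mask(mask: List[List[int]]) -> Tuple[List[int], int]:
    """Return (cumulative sequence lengths, max sequence length) for a [bsz, seq_len] 0/1 mask."""
    seqlens = [sum(row) for row in mask]      # number of real tokens per sequence
    cu = [0] + list(accumulate(seqlens))      # start offset of each sequence in the packed stream
    return cu, max(seqlens)


def cu_seqlens_no_mask(bsz: int, q_len: int) -> List[int]:
    """Same values as torch.arange(0, (bsz + 1) * q_len, step=q_len)."""
    return list(range(0, (bsz + 1) * q_len, q_len))


mask = [
    [1, 1, 1, 1],  # 4 real tokens
    [1, 1, 0, 0],  # 2 real tokens, 2 padding
    [1, 1, 1, 0],  # 3 real tokens, 1 padding
]
print(cu_seqlens_from_mask(mask))  # ([0, 4, 6, 9], 4)
print(cu_seqlens_no_mask(3, 4))    # [0, 4, 8, 12]
```

Sequence `i` occupies positions `cu[i]` to `cu[i + 1] - 1` of the packed stream, which is what lets the kernel apply causal masking per sequence without materializing padding.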
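`smart_tokenizer_and_embedding_resize` in `utils.py` sets the rows of newly added special tokens to the mean of all pre-existing rows, in both the input and the output embedding matrices, instead of leaving them randomly initialized. A framework-free sketch of that step on a list-of-lists matrix follows. The function name `init_new_rows_with_mean` is hypothetical; the real code performs the same update in place on `model.get_input_embeddings().weight.data` using `.mean(dim=0, keepdim=True)`.

```python
from typing import List


def init_new_rows_with_mean(embeddings: List[List[float]], num_new_tokens: int) -> None:
    """In place: set the last num_new_tokens rows to the column-wise mean of the earlier rows."""
    if num_new_tokens <= 0:
        return  # mirrors the `if num_new_tokens > 0` guard; [:-0] would select nothing
    old_rows = embeddings[:-num_new_tokens]
    dim = len(embeddings[0])
    mean = [sum(row[j] for row in old_rows) / len(old_rows) for j in range(dim)]
    for i in range(len(embeddings) - num_new_tokens, len(embeddings)):
        embeddings[i] = list(mean)


# Three existing token embeddings plus two rows appended by resize_token_embeddings.
emb = [[1.0, 2.0], [3.0, 4.0], [5.0, 0.0], [9.9, 9.9], [-7.0, 0.5]]
init_new_rows_with_mean(emb, num_new_tokens=2)
print(emb)  # [[1.0, 2.0], [3.0, 4.0], [5.0, 0.0], [3.0, 2.0], [3.0, 2.0]]
```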
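With `lora_bias="lora_only"`, `get_peft_state_maybe_zero_3` in `utils.py` is meant to keep every `lora_` tensor plus only the bias tensors of modules that carry a LoRA adapter; the bias name is derived by cutting the parameter name at `lora_` and appending `bias`. The dictionary-only sketch below shows that selection on parameter names and leaves out the DeepSpeed ZeRO-3 gathering done by `maybe_zero_3`. The helper name and example names are hypothetical, and the sketch assumes the older PEFT naming in which the adapted layer's bias is exposed as `<module>.bias`.

```python
from typing import Dict, Iterable, Tuple


def select_lora_only(named_params: Iterable[Tuple[str, object]]) -> Dict[str, object]:
    """Keep LoRA tensors and the biases of the modules they adapt (hypothetical helper)."""
    to_return: Dict[str, object] = {}
    maybe_lora_bias: Dict[str, object] = {}
    lora_bias_names = set()
    for k, t in named_params:
        if "lora_" in k:
            to_return[k] = t
            # "layers.0.q_proj.lora_A.weight" -> "layers.0.q_proj.bias"
            lora_bias_names.add(k.split("lora_")[0] + "bias")
        elif "bias" in k:
            maybe_lora_bias[k] = t
    for k, t in maybe_lora_bias.items():
        if k in lora_bias_names:
            to_return[k] = t
    return to_return


params = [
    ("layers.0.q_proj.weight", "W_q"),
    ("layers.0.q_proj.bias", "b_q"),
    ("layers.0.q_proj.lora_A.weight", "A_q"),
    ("layers.0.q_proj.lora_B.weight", "B_q"),
    ("layers.0.mlp.fc1.bias", "b_fc1"),  # bias of a module without LoRA: dropped
]
print(sorted(select_lora_only(params)))
```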
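In `Conversation.get_images` (the `"Crop"` mode) and in `to_gradio_chatbot`, an uploaded image is rescaled so that its short side is at most 400 px and its long side at most 800 px, preserving the aspect ratio and never enlarging the short side beyond its original size. The function below isolates that arithmetic on a `(width, height)` tuple, the order in which PIL's `Image.size` reports it; the name `crop_mode_size` is hypothetical.

```python
from typing import Tuple


def crop_mode_size(size: Tuple[int, int]) -> Tuple[int, int]:
    """Return the (W, H) that the "Crop" branch passes to image.resize (hypothetical helper)."""
    W, H = size
    max_hw, min_hw = max(size), min(size)
    aspect_ratio = max_hw / min_hw
    max_len, min_len = 800, 400
    # Short side: capped by 400 px, by the 800 px long-side budget, and by its current size.
    shortest_edge = int(min(max_len / aspect_ratio, min_len, min_hw))
    longest_edge = int(shortest_edge * aspect_ratio)
    if H > W:
        H, W = longest_edge, shortest_edge
    else:
        H, W = shortest_edge, longest_edge
    return W, H


print(crop_mode_size((1000, 500)))  # (800, 400)
print(crop_mode_size((600, 1800)))  # (266, 798)
print(crop_mode_size((300, 200)))   # (300, 200): small images are not upscaled
```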