├── README.md
├── Small Language Model Meets with Reinforced Vision Vocabulary.pdf
├── Vary-master
│   ├── pyproject.toml
│   ├── pyvenv.cfg
│   ├── vary
│   │   ├── __init__.py
│   │   ├── data
│   │   │   ├── __init__.py
│   │   │   ├── base_dataset.py
│   │   │   ├── caption_opt.py
│   │   │   └── conversation_dataset_qwen.py
│   │   ├── demo
│   │   │   ├── run_opt.py
│   │   │   └── run_qwen_vary.py
│   │   ├── model
│   │   │   ├── __init__.py
│   │   │   ├── llm
│   │   │   │   ├── opt
│   │   │   │   │   └── modeling_opt.py
│   │   │   │   └── qwen
│   │   │   │       ├── configuration_qwen.py
│   │   │   │       ├── modeling_qwen.py
│   │   │   │       ├── qwen_generation_utils.py
│   │   │   │       └── tokenization_qwen.py
│   │   │   ├── plug
│   │   │   │   ├── blip_process.py
│   │   │   │   └── transforms.py
│   │   │   ├── vary_opt.py
│   │   │   ├── vary_qwen_vary.py
│   │   │   ├── vary_toy_qwen1_8.py
│   │   │   └── vision_encoder
│   │   │       ├── __init__.py
│   │   │       └── sam.py
│   │   ├── train
│   │   │   ├── train_flash_attn.py
│   │   │   ├── train_lora.py
│   │   │   ├── train_lora_flash_attn.py
│   │   │   ├── train_opt.py
│   │   │   ├── train_qwen_vary.py
│   │   │   ├── trainer.py
│   │   │   └── trainer_vit_fixlr.py
│   │   └── utils
│   │       ├── arguments.py
│   │       ├── constants.py
│   │       ├── conversation.py
│   │       ├── llama_flash_attn_monkey_patch.py
│   │       └── utils.py
│   └── zero_config
│       └── zero2.json
└── assets
    └── vary-toy-logo.jpg
/README.md:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
6 |
7 |
8 |
9 | [Haoran Wei*](https://scholar.google.com/citations?user=J4naK0MAAAAJ&hl=en), Lingyu Kong*, Jinyue Chen, Liang Zhao, [Zheng Ge](https://joker316701882.github.io/), [En Yu](https://scholar.google.com.hk/citations?user=rWCQMNgAAAAJ&hl=zh-CN&oi=sra), [Jianjian Sun](https://scholar.google.com/citations?user=MVZrGkYAAAAJ&hl=en), Chunrui Han, [Xiangyu Zhang](https://scholar.google.com/citations?user=yuB-cfoAAAAJ&hl=en)
10 |
11 |
12 |
13 |
14 |
15 |
16 | Two-Stream Hypothesis for LVLMs
17 |
18 |
19 |
20 | ## Release
21 | - [2024/9/03] 🔥🔥🔥 We release a very strong and comprehensive OCR model [GOT-OCR2.0](https://github.com/Ucas-HaoranWei/GOT-OCR2.0).
22 | - [2024/7/21] 🎉🎉🎉 OneChart is accepted by ACM'MM 2024 **Oral**! (3.97%)
23 | - [2024/7/2] 🔥🔥🔥 Vary is accepted by ECCV 2024. To thank everyone for their attention, I will soon release a model that performs on par with Vary-document.
24 | - [2024/5/27] 🔥🔥🔥 We present a document understanding benchmark in [Fox](https://github.com/ucaslcl/Fox) .
25 | - [2024/5/24] 🔥🔥🔥 We propose a multi-page document understanding work -- [Fox](https://arxiv.org/abs/2405.14295), which supports 8-page pdf-image input !!!
26 | - [2024/4/21] 🔥🔥🔥 For OneChart, we have released the web demo in [Project Page](https://onechartt.github.io/). Have fun!!
27 | - [2024/4/21] 🔥🔥🔥 We present a Vary-tiny LAVIS codebase (for training from scratch) and the Vary-600k dataset (300K English and 300K Chinese pages) [here](https://github.com/Ucas-HaoranWei/Vary-tiny-600k) !!!
28 | - [2024/4/15] 🔥🔥🔥 We release a chart parsing model, OneChart, [here](https://github.com/LingyvKong/OneChart).
29 | - [2024/4/12] 🔥🔥🔥 We will release a chart parsing model based on Vary-tiny next week. The model supports both English and Chinese charts.
30 | - [2024/3/16] 🔥🔥🔥 Many friends are very interested in Vary-tiny (OPT-125M), so I have open-sourced it [here](https://huggingface.co/HaoranWei/Vary-tiny-opt125M/tree/main) as a PDF-dense OCR and object detection version.
31 | - [2024/1/23] 🔥Eval codes will be available soon.
32 | - [2024/1/23] 🔥🔥🔥You only need a single 1080Ti to experience all features of current LVLMs.
33 |
34 |
35 |
36 |
37 | [Code License](https://github.com/tatsu-lab/stanford_alpaca/blob/main/LICENSE)
38 | [Data License](https://github.com/tatsu-lab/stanford_alpaca/blob/main/DATA_LICENSE)
39 | **Usage and License Notices**: The data, code, and checkpoints are intended and licensed for research use only. They are also restricted to uses that follow the license agreements of LLaMA, Vicuna, GPT-4, Qwen, and LLaVA.
40 |
41 |
42 | ## Contents
43 | - [Install](#install)
44 | - [Vary-toy Weights](#vary-weights)
45 | - [Demo](#demo)
46 | - [Train](#train)
47 |
48 | ## Note
49 | If you have previously built the original [Vary](https://github.com/Ucas-HaoranWei/Vary), please rebuild and reinstall from this repo!!!
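A clean way to switch from the original package to this one is to uninstall and reinstall in editable mode (a minimal sketch; adjust the path to your checkout):
```Shell
pip uninstall vary -y      # remove the previously installed Vary package
cd /path/to/vary-toy
pip install -e .           # reinstall from this repository
```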
50 |
51 | ## Install
52 |
53 | 1. Clone this repository and navigate to the Vary-toy folder
54 | ```bash
55 | git clone https://github.com/Ucas-HaoranWei/Vary-toy.git
56 | cd /path/to/vary-toy
57 | ```
58 | 2. Install Package
59 | ```Shell
60 | conda create -n vary python=3.10 -y
61 | conda activate vary
62 | pip install -e .
63 | ```
64 |
65 | 3. Install Flash-Attention
66 | ```
67 | pip install ninja
68 | pip install flash-attn --no-build-isolation
69 | ```
70 |
71 | ## Vary Weights
72 | - Download the Vary-toy weights [here](https://huggingface.co/Haoran-megvii/Vary-toy).
73 | - Download the CLIP-VIT-L [here](https://huggingface.co/openai/clip-vit-large-patch14/).
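If you prefer to script the download of the two checkpoints above, a minimal sketch using `huggingface_hub` (the local directories are placeholders):
```python
from huggingface_hub import snapshot_download

# fetch the Vary-toy checkpoint and the CLIP-ViT-L weights (target folders are placeholders)
snapshot_download("Haoran-megvii/Vary-toy", local_dir="./weights/Vary-toy")
snapshot_download("openai/clip-vit-large-patch14", local_dir="./weights/vit-large-patch14")
```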
74 |
75 |
76 |
77 | ## Demo
78 | 1. Update the CLIP-ViT path in the code (e.g. `/cache/vit-large-patch14/`) to your local path; see the snippet after these steps.
79 |
80 | 2. Run the demo:
81 | ```Shell
82 | python vary/demo/run_qwen_vary.py --model-name /vary/model/path/ --image-file /an/image/file.png
83 | ```
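For reference, the line to change in `vary/demo/run_qwen_vary.py` looks like the following (the directory below is only a placeholder for your local CLIP-ViT-L folder):
```python
import torch
from transformers import CLIPImageProcessor

# point this at your local clip-vit-large-patch14 directory (placeholder path)
image_processor = CLIPImageProcessor.from_pretrained(
    "/path/to/vit-large-patch14/", torch_dtype=torch.float16
)
```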
84 | ## Train
85 | ```Shell
86 | deepspeed Vary/train/train_qwen_vary.py --deepspeed /Vary/zero_config/zero2.json \
87 | --model_name_or_path /Vary-toy/path/ \
88 | --vision_tower /vit-large-patch14/path/ \
89 | --freeze_vision_tower True \
90 | --freeze_lm_model False \
91 | --vision_select_layer -2 \
92 | --use_im_start_end True \
93 | --bf16 True \
94 | --per_device_eval_batch_size 4 \
95 | --gradient_accumulation_steps 1 \
96 | --evaluation_strategy "no" \
97 | --save_strategy "steps" \
98 | --save_steps 5000 \
99 | --save_total_limit 1 \
100 | --weight_decay 0. \
101 | --warmup_ratio 0.03 \
102 | --lr_scheduler_type "cosine" \
103 | --logging_steps 1 --tf32 True \
104 | --model_max_length 4096 \
105 | --gradient_checkpointing True \
106 | --dataloader_num_workers 4 \
107 | --report_to none \
108 | --per_device_train_batch_size 4 \
109 | --num_train_epochs 1 \
110 | --learning_rate 5e-5 \
111 | --datasets data_name1+data_name2+data_name3 \
112 | --output_dir /path/to/output/
113 | ```
114 | We encourage you to extract the new vision vocabulary weights and reuse them with your own base language model!!!
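One possible way to do this is to load a trained Vary checkpoint and keep only the parameters of the high-resolution vision branch. A minimal sketch, assuming the checkpoint is a single `pytorch_model.bin` and the parameters use a `vision_tower_high` prefix (both are assumptions; inspect your own `state_dict` for the actual names):
```python
import torch

# load a trained checkpoint (file name is an assumption; yours may be sharded or use safetensors)
state_dict = torch.load("/path/to/checkpoint/pytorch_model.bin", map_location="cpu")

# keep only the new vision vocabulary weights (hypothetical "vision_tower_high" prefix)
vision_vocab = {k: v for k, v in state_dict.items() if "vision_tower_high" in k}

torch.save(vision_vocab, "/path/to/output/vision_vocab.pth")
```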
115 |
116 | ## Contact
117 | If you have any questions about the code or the paper, please email `weihaoran18@mails.ucas.ac.cn`.
118 |
119 | ## Discussion
120 | Vary-toy is not a toy. We have built two strong models on top of it: Vary-document (specialized for document/PDF processing) and Vary-plot (for chart analysis). You can see their impressive performance at [Vary-family](https://github.com/Ucas-HaoranWei/Vary-family).
121 |
122 | ## Citation
123 | If you find our work useful in your research, please consider citing Vary:
124 | ```bibtex
125 | @article{wei2023vary,
126 | title={Vary: Scaling up the Vision Vocabulary for Large Vision-Language Models},
127 | author={Wei, Haoran and Kong, Lingyu and Chen, Jinyue and Zhao, Liang and Ge, Zheng and Yang, Jinrong and Sun, Jianjian and Han, Chunrui and Zhang, Xiangyu},
128 | journal={arXiv preprint arXiv:2312.06109},
129 | year={2023}
130 | }
131 |
132 | @article{wei2024small,
133 | title={Small Language Model Meets with Reinforced Vision Vocabulary},
134 | author={Wei, Haoran and Kong, Lingyu and Chen, Jinyue and Zhao, Liang and Ge, Zheng and Yu, En and Sun, Jianjian and Han, Chunrui and Zhang, Xiangyu},
135 | journal={arXiv preprint arXiv:2401.12503},
136 | year={2024}
137 | }
138 | ```
139 |
140 |
--------------------------------------------------------------------------------
/Small Language Model Meets with Reinforced Vision Vocabulary.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Ucas-HaoranWei/Vary-toy/c6e405977aa6e88d4807a6b79f394299623fc6f3/Small Language Model Meets with Reinforced Vision Vocabulary.pdf
--------------------------------------------------------------------------------
/Vary-master/pyproject.toml:
--------------------------------------------------------------------------------
1 | [build-system]
2 | requires = ["setuptools>=61.0"]
3 | build-backend = "setuptools.build_meta"
4 |
5 | [project]
6 | name = "vary"
7 | version = "0.1.0"
8 | description = "Towards GPT-4 like large language and visual assistant."
9 | readme = "README.md"
10 | requires-python = ">=3.8"
11 | classifiers = [
12 | "Programming Language :: Python :: 3",
13 | "License :: OSI Approved :: Apache Software License",
14 | ]
15 | dependencies = [
16 | "einops", "markdown2[all]", "numpy",
17 | "requests", "sentencepiece", "tokenizers>=0.12.1",
18 | "torch", "torchvision", "wandb",
19 | "shortuuid", "httpx==0.24.0",
20 | "deepspeed==0.12.3",
21 | "peft==0.4.0",
22 | "albumentations ",
23 | "opencv-python",
24 | "tiktoken",
25 | "accelerate==0.24.1",
26 | "transformers==4.32.1",
27 | "bitsandbytes==0.41.0",
28 | "scikit-learn==1.2.2",
29 | "sentencepiece==0.1.99",
30 | "einops==0.6.1", "einops-exts==0.0.4", "timm==0.6.13",
31 | "gradio_client==0.2.9"
32 | ]
33 |
34 | [tool.setuptools.packages.find]
35 | exclude = ["assets*", "benchmark*", "docs", "dist*", "playground*", "scripts*", "tests*"]
36 |
37 | [tool.wheel]
38 | exclude = ["assets*", "benchmark*", "docs", "dist*", "playground*", "scripts*", "tests*"]
39 |
--------------------------------------------------------------------------------
/Vary-master/pyvenv.cfg:
--------------------------------------------------------------------------------
1 | home = /usr/bin
2 | implementation = CPython
3 | version_info = 3.8.10.final.0
4 | virtualenv = 20.16.7
5 | include-system-site-packages = true
6 | base-prefix = /usr
7 | base-exec-prefix = /usr
8 | base-executable = /usr/bin/python3
9 |
--------------------------------------------------------------------------------
/Vary-master/vary/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Ucas-HaoranWei/Vary-toy/c6e405977aa6e88d4807a6b79f394299623fc6f3/Vary-master/vary/__init__.py
--------------------------------------------------------------------------------
/Vary-master/vary/data/__init__.py:
--------------------------------------------------------------------------------
1 |
2 | import torch
3 | import transformers
4 | from dataclasses import dataclass, field
5 |
6 | from vary.utils.constants import *
7 |
8 |
9 | @dataclass
10 | class DataCollatorForSupervisedDataset(object):
11 | tokenizer: transformers.PreTrainedTokenizer
12 |
13 | def __call__(self, instances):
14 |
15 | input_ids, labels = tuple([instance[key] for instance in instances] for key in ("input_ids", "labels"))
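# each instance carries a low-resolution image (for the CLIP branch) and a high-resolution image (for the SAM-style high-res encoder)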
16 | images = [torch.stack(instance['image']) for instance in instances]
17 |
18 |
19 | images_high = [torch.stack(instance['image_high']) for instance in instances]
20 |
21 | images = list(zip(images, images_high))
22 |
23 |
24 | input_ids = torch.nn.utils.rnn.pad_sequence(
25 | input_ids,
26 | batch_first=True,
27 | padding_value=self.tokenizer.pad_token_id)
28 |
29 | labels = torch.nn.utils.rnn.pad_sequence(
30 | labels,
31 | batch_first=True,
32 | padding_value=IGNORE_INDEX)
33 |
34 | batch = dict(
35 | input_ids=input_ids,
36 | labels=labels,
37 | attention_mask=input_ids.ne(self.tokenizer.pad_token_id),
38 | images=images,
39 | )
40 | return batch
41 |
42 |
43 | def make_supervised_data_module(interleave, with_box, tokenizer, data_args):
44 |
45 | if data_args.conversation_version == 'mpt':
46 | from vary.data.conversation_dataset_qwen import ConversationDataset
47 | dataset_cls = ConversationDataset
48 | elif data_args.conversation_version == 'opt':
49 | from vary.data.caption_opt import CaptionDataset
50 | dataset_cls = CaptionDataset
51 |
52 | train_dataset = dataset_cls(
53 | tokenizer=tokenizer,
54 | datasets=data_args.datasets,
55 | multimodal_cfg=dict(
56 | sep_image_conv_front=data_args.sep_image_conv_front,
57 | image_token_len=data_args.image_token_len,
58 | image_aspect_ratio=data_args.image_aspect_ratio,
59 | use_im_start_end=data_args.use_im_start_end,
60 | image_processor=data_args.image_processor,
61 | image_processor_high = data_args.image_processor_high,
62 | box_limit=data_args.box_limit,
63 | )
64 | )
65 | data_collator = DataCollatorForSupervisedDataset(tokenizer=tokenizer)
66 | return dict(train_dataset=train_dataset,
67 | eval_dataset=None,
68 | data_collator=data_collator)
--------------------------------------------------------------------------------
/Vary-master/vary/data/base_dataset.py:
--------------------------------------------------------------------------------
1 | import io
2 | import os
3 | import copy
4 | import json
5 | import logging
6 | import torch
7 | import transformers
8 | from typing import List, Optional, Tuple, Union, Dict, Sequence
9 | from torch.utils.data import Dataset
10 | from PIL import Image, ImageFile
11 | ImageFile.LOAD_TRUNCATED_IMAGES = True
12 | from vary.utils.constants import *
13 |
14 |
15 |
16 | class BaseDataset(Dataset):
17 | def __init__(
18 | self,
19 | datasets: str,
20 | tokenizer: transformers.PreTrainedTokenizer,
21 | multimodal_cfg: dict
22 | ):
23 | super(BaseDataset, self).__init__()
24 | self.tokenizer = tokenizer
25 | self.multimodal_cfg = multimodal_cfg
26 |
27 | logging.warning(f"Using {multimodal_cfg['image_token_len']} tokens for representing image")
28 |
29 | def image_processor(self, image):
30 | processor = self.multimodal_cfg['image_processor'] # the first processor, usually is the clip pretrained model (vit)
31 | processor_high = self.multimodal_cfg['image_processor_high'] # the second processor, usually is the designed image encoder (sam/swin/cnn)
32 | image_high = image.copy()
33 |         # TODO: the 'keep' and 'pad' aspect-ratio options are only used for the first processor
34 | if self.multimodal_cfg['image_aspect_ratio'] == 'keep':
35 | max_hw, min_hw = max(image.size), min(image.size)
36 | aspect_ratio = max_hw / min_hw
37 | max_len, min_len = 448, 224
38 | shortest_edge = int(min(max_len / aspect_ratio, min_len))
39 | image = processor.preprocess(image, return_tensors='pt', do_center_crop=False, size={"shortest_edge": shortest_edge})['pixel_values'][0]
40 | elif self.multimodal_cfg['image_aspect_ratio'] == 'pad':
41 | def expand2square(pil_img, background_color):
42 | width, height = pil_img.size
43 | if width == height:
44 | return pil_img
45 | elif width > height:
46 | result = Image.new(pil_img.mode, (width, width), background_color)
47 | result.paste(pil_img) # for simpler box processing
48 | return result
49 | else:
50 | result = Image.new(pil_img.mode, (height, height), background_color)
51 | result.paste(pil_img) # for simpler box processing
52 | return result
53 | image = expand2square(image, tuple(int(x*255) for x in processor.image_mean))
54 | image = processor.preprocess(image, return_tensors='pt', do_center_crop=False, size={"shortest_edge": 224})['pixel_values'][0]
55 | else:
56 | image = processor.preprocess(image, return_tensors='pt')['pixel_values'][0]
57 |
58 | image_high = processor_high(image_high)
59 |
60 | return image, image_high
61 |
62 |
63 |
64 | def __len__(self):
65 | return len(self.list_data_dict)
66 |
67 | def __getitem__(self, i) -> Dict[str, torch.Tensor]:
68 | pass
--------------------------------------------------------------------------------
/Vary-master/vary/data/caption_opt.py:
--------------------------------------------------------------------------------
1 |
2 | import io
3 | import os
4 | import copy
5 | import json
6 | import logging
7 | import torch
8 | import random
9 |
10 | from typing import List, Optional, Tuple, Union, Dict, Sequence
11 | from PIL import Image, ImageFile
12 | ImageFile.LOAD_TRUNCATED_IMAGES = True
13 |
14 | from vary.data.base_dataset import BaseDataset
15 | from vary.utils.constants import *
16 | from vary.utils import conversation as conversation_lib
17 |
18 |
19 | class CaptionDataset(BaseDataset):
20 | """Conversation format dataset stage2 fine-tuning."""
21 |
22 | def __init__(self, datasets, tokenizer, multimodal_cfg):
23 | super(CaptionDataset, self).__init__(datasets, tokenizer, multimodal_cfg)
24 | # v0 version format conversation
25 | conversation_lib.default_conversation = conversation_lib.conv_templates["default"]
26 | logging.warning("Formatting inputs into conversation type: v0-fixed")
27 | logging.warning("Loading data...")
28 |
29 | list_data_dict = []
30 | list_image_path = []
31 | for name in datasets.split("+"):
32 | dataset = CONVERSATION_DATA[name] # in vary.utils
33 |
34 | data_path = dataset['annotations']
35 | data = json.load(open(data_path, "r"))
36 |
37 | list_data_dict.extend(data)
38 |
39 | image_path = dataset['images']
40 | list_image_path.extend([image_path] * len(data))
41 |
42 | logging.warning(f"Data from {data_path} provide {len(data)} conversations.")
43 |
44 | assert len(list_data_dict) == len(list_image_path)
45 | logging.warning(f"{len(list_data_dict)} conversations in total.")
46 | a_new_list = list(zip(list_data_dict, list_image_path))
47 | random.shuffle(a_new_list)
48 | list_data_dict_new, list_image_path_new = zip(*a_new_list)
49 | self.list_data_dict = list_data_dict_new
50 | self.list_image_path = list_image_path_new
51 | self.im_patch_token, self.im_start_token, self.im_end_token = tokenizer.convert_tokens_to_ids(
52 | [DEFAULT_IMAGE_PATCH_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN])
53 |
54 | def multimodal_processor(self, sources):
55 | for source in sources:
56 |
57 | source[0]['value'] = DEFAULT_IMAGE_TOKEN
58 | for sentence in source:
59 | replace_token = DEFAULT_IMAGE_PATCH_TOKEN * self.multimodal_cfg['image_token_len']
60 | if self.multimodal_cfg['use_im_start_end']:
61 | replace_token = DEFAULT_IM_START_TOKEN + replace_token + DEFAULT_IM_END_TOKEN
62 |
63 | sentence["value"] = sentence["value"].replace(DEFAULT_IMAGE_TOKEN, replace_token)
64 | return sources
65 |
66 | def _tokenize_fn(self, strings):
67 | """Tokenize a list of strings."""
68 | tokenized_list = [
69 | self.tokenizer(
70 | text,
71 | return_tensors="pt",
72 | padding="longest",
73 | max_length=self.tokenizer.model_max_length,
74 | truncation=True,
75 | ) for text in strings
76 | ]
77 | input_ids = labels = [
78 | tokenized.input_ids[0] for tokenized in tokenized_list
79 | ]
80 |
81 |
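# force the last token to OPT's eos id (2) when the tokenized sequence does not already end with it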
82 | for idx, ii in enumerate(input_ids):
83 | if ii[-1] != 2:
84 | input_ids[idx][-1] = 2
85 | labels[idx][-1] = 2
86 |
87 | input_ids_lens = labels_lens = [
88 | tokenized.input_ids.ne(self.tokenizer.pad_token_id).sum().item()
89 | for tokenized in tokenized_list
90 | ]
91 | return dict(
92 | input_ids=input_ids,
93 | labels=labels,
94 | input_ids_lens=input_ids_lens,
95 | labels_lens=labels_lens,
96 | )
97 |
98 | def _mask_targets(self, target, tokenized_lens, speakers):
99 | # cur_idx = 0
100 | cur_idx = tokenized_lens[0]
101 | tokenized_lens = tokenized_lens[1:]
102 | target[:cur_idx] = IGNORE_INDEX
103 | for tokenized_len, speaker in zip(tokenized_lens, speakers):
104 | if speaker.lower() == "human":
105 | target[cur_idx:tokenized_len] = IGNORE_INDEX
106 | cur_idx += tokenized_len
107 |
108 |
109 | def _add_speaker_and_signal(self, header, source, get_conversation=True):
110 | """Add speaker and start/end signal on each round."""
111 | BEGIN_SIGNAL = ""
112 | END_SIGNAL = "\n"
113 | conversation = header
114 | for sentence in source:
115 | from_str = sentence["from"]
116 | if from_str.lower() == "human":
117 | from_str = conversation_lib.default_conversation.roles[0]
118 | else:
119 | from_str = conversation_lib.default_conversation.roles[1]
120 |
121 | sentence["value"] = sentence["value"] + END_SIGNAL
122 | if get_conversation:
123 | conversation += sentence["value"]
124 | conversation += BEGIN_SIGNAL
125 | return conversation
126 |
127 | def token_processor(self, sources):
128 | """
129 | Given a list of sources, each is a conversation list. This transform:
130 | 1. Add signal '### ' at the beginning each sentence, with end signal '\n';
131 | 2. Concatenate conversations together;
132 | 3. Tokenize the concatenated conversation;
133 | 4. Make a deepcopy as the target. Mask human words with IGNORE_INDEX.
134 | """
135 | # add end signal and concatenate together
136 | conversations = []
137 | for source in sources:
138 | header = ''
139 | conversation = self._add_speaker_and_signal(header, source)
140 | conversations.append(conversation)
141 |
142 | conversations_tokenized = self._tokenize_fn(conversations)
143 | input_ids = conversations_tokenized["input_ids"]
144 | targets = copy.deepcopy(input_ids)
145 | for target, source in zip(targets, sources):
146 | tokenized_lens = self._tokenize_fn([header] + [s["value"] for s in source])["input_ids_lens"]
147 | speakers = [sentence["from"] for sentence in source]
148 | self._mask_targets(target, tokenized_lens, speakers)
149 |
150 | return dict(input_ids=input_ids, labels=targets)
151 |
152 | def __getitem__(self, i) -> Dict[str, torch.Tensor]:
153 | # data = self.list_data_dict[i]
154 | data = copy.deepcopy(self.list_data_dict[i])
155 |
156 | if isinstance(data, dict):
157 | if 'image' in data:
158 | image_path = self.list_image_path[i]
159 | image_file = data['image']
160 | # TODO this is a bug, because some json has wrong path
161 |
162 | try:
163 | image = Image.open(image_path + image_file).convert('RGB')
164 | except:
165 | print(f'cannot identify image file {image_path+image_file}.')
166 | return self.__getitem__(0)
167 |
168 | try:
169 | image, image_high = self.image_processor(image)
170 | except:
171 |                 print(f'image {image_file} is broken or grayscale! we thus select the 0-th sample instead!')
172 | return self.__getitem__(0)
173 |
174 | conversations = self.multimodal_processor([data["conversations"]])
175 | else:
176 | conversations = [data]
177 |
178 | # align with fastchat & llava here, put the conversation into a list for tokenization
179 | data_dict = self.token_processor(conversations)
180 | data_dict = dict(input_ids=data_dict["input_ids"][0], labels=data_dict["labels"][0])
181 |
182 | if isinstance(data, dict) and 'image' in data:
183 | data_dict['image'] = [image]
184 | data_dict['image_high'] = [image_high]
185 | else:
186 | crop_size = self.multimodal_cfg['image_processor'].crop_size
187 | data_dict['image'] = [torch.zeros(3, crop_size['height'], crop_size['width'])]
188 | # TODO sam is 1024*1024
189 | data_dict['image_high'] = [torch.zeros(3, 1024, 1024)]
190 | return data_dict
191 |
192 |
--------------------------------------------------------------------------------
/Vary-master/vary/data/conversation_dataset_qwen.py:
--------------------------------------------------------------------------------
1 |
2 | import io
3 | import os
4 | import copy
5 | import json
6 | import logging
7 | import torch
8 | import random
9 |
10 | from typing import List, Optional, Tuple, Union, Dict, Sequence
11 | from PIL import Image, ImageFile
12 | ImageFile.LOAD_TRUNCATED_IMAGES = True
13 |
14 | from vary.data.base_dataset import BaseDataset
15 | from vary.utils.constants import *
16 | from vary.utils import conversation as conversation_lib
17 |
18 | class ConversationDataset(BaseDataset):
19 | """Conversation format dataset stage2 fine-tuning."""
20 |
21 | def __init__(self, datasets, tokenizer, multimodal_cfg):
22 | super(ConversationDataset, self).__init__(datasets, tokenizer, multimodal_cfg)
23 | # v0 version format conversation
24 | conversation_lib.default_conversation = conversation_lib.conv_templates["mpt"]
25 | logging.warning("Formatting inputs into conversation type: mpt-fixed")
26 | logging.warning("Loading data...")
27 |
28 | list_data_dict = []
29 | list_image_path = []
30 |
31 |
32 | for name in datasets.split("+"):
33 | # for name in vary_data_dict[name_all]:
34 | dataset = CONVERSATION_DATA[name]
35 |
36 | data_path = dataset['annotations']
37 | data = json.load(open(data_path, "r"))
38 |
39 | list_data_dict.extend(data)
40 |
41 | image_path = dataset['images']
42 |
43 | list_image_path.extend([image_path] * len(data))
44 |
45 | logging.warning(f"Data from {data_path} provide {len(data)} conversations.")
46 |
47 | assert len(list_data_dict) == len(list_image_path)
48 | logging.warning(f"{len(list_data_dict)} conversations in total.")
49 | a_new_list = list(zip(list_data_dict, list_image_path))
50 | random.shuffle(a_new_list)
51 | list_data_dict_new, list_image_path_new = zip(*a_new_list)
52 | self.list_data_dict = list_data_dict_new
53 | self.list_image_path = list_image_path_new
54 |
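# hard-coded ids of the image patch / image start / image end special tokens in the extended Qwen vocabulary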
55 | self.im_patch_token = 151859
56 |
57 | self.im_start_token = 151857
58 |
59 | self.im_end_token = 151858
60 |
61 | def multimodal_processor(self, sources):
62 | for source in sources:
63 | if self.multimodal_cfg['sep_image_conv_front']:
64 | assert DEFAULT_IMAGE_TOKEN in source[0]['value']
65 | source[0]['value'] = source[0]['value'].replace(DEFAULT_IMAGE_TOKEN, '').strip()
66 | source[0]['value'] = DEFAULT_IMAGE_TOKEN + conversation_lib.default_conversation.sep + conversation_lib.default_conversation.roles[0] + ": " + source[0]['value']
67 | for sentence in source:
68 | replace_token = DEFAULT_IMAGE_PATCH_TOKEN * self.multimodal_cfg['image_token_len']
69 | # if self.multimodal_cfg['use_im_start_end']:
70 | replace_token = DEFAULT_IM_START_TOKEN + replace_token + DEFAULT_IM_END_TOKEN
71 | sentence["value"] = str(sentence["value"]).replace(DEFAULT_IMAGE_TOKEN, replace_token)
72 | return sources
73 |
74 | def _tokenize_fn(self, strings):
75 | """Tokenize a list of strings."""
76 | tokenized_list = [
77 | self.tokenizer(
78 | text,
79 | return_tensors="pt",
80 | padding="longest",
81 | max_length=self.tokenizer.model_max_length,
82 | truncation=True,
83 | ) for text in strings
84 | ]
85 | input_ids = labels = [
86 | tokenized.input_ids[0] for tokenized in tokenized_list
87 | ]
88 | input_ids_lens = labels_lens = [
89 | tokenized.input_ids.ne(self.tokenizer.pad_token_id).sum().item()
90 | for tokenized in tokenized_list
91 | ]
92 | return dict(
93 | input_ids=input_ids,
94 | labels=labels,
95 | input_ids_lens=input_ids_lens,
96 | labels_lens=labels_lens,
97 | )
98 |
99 | def _mask_targets(self, target, tokenized_lens, speakers):
100 | # cur_idx = 0
101 | cur_idx = tokenized_lens[0]
102 | tokenized_lens = tokenized_lens[1:]
103 | target[:cur_idx] = IGNORE_INDEX
104 | for tokenized_len, speaker in zip(tokenized_lens, speakers):
105 | if speaker.lower() == "human":
106 | target[cur_idx+2:cur_idx + tokenized_len] = IGNORE_INDEX
107 | cur_idx += tokenized_len
108 |
109 | def token_processor(self, sources):
110 | conv = conversation_lib.default_conversation.copy()
111 | roles = {"human": conv.roles[0], "gpt": conv.roles[1]}
112 |
113 | # Apply prompt templates
114 | conversations = []
115 | for i, source in enumerate(sources):
116 | if roles[source[0]["from"]] != conv.roles[0]:
117 | # Skip the first one if it is not from human
118 | source = source[1:]
119 |
120 | conv.messages = []
121 | for j, sentence in enumerate(source):
122 | role = roles[sentence["from"]]
123 | assert role == conv.roles[j % 2], f"{i}"
124 | conv.append_message(role, sentence["value"])
125 | conversations.append(conv.get_prompt())
126 |
127 | # Tokenize conversations
128 |
129 |
130 | input_ids = self.tokenizer(
131 | conversations,
132 | return_tensors="pt",
133 | padding="longest",
134 | max_length=self.tokenizer.model_max_length,
135 | truncation=True,
136 | ).input_ids
137 |
138 | # input_ids = torch.stack([tokenizer_image_token(prompt, tokenizer, return_tensors='pt') for prompt in conversations], dim=0)
139 | targets = input_ids.clone()
140 | assert conv.sep_style == conversation_lib.SeparatorStyle.MPT
141 |
142 | # Mask targets
143 | sep = conv.sep + conv.roles[1]
144 | for conversation, target in zip(conversations, targets):
145 | total_len = int(target.ne(self.tokenizer.pad_token_id).sum())
146 |
147 | rounds = conversation.split(conv.sep)
148 | re_rounds = [conv.sep.join(rounds[:3])] # system + user + gpt
149 | for conv_idx in range(3, len(rounds), 2):
150 | re_rounds.append(conv.sep.join(rounds[conv_idx:conv_idx+2])) # user + gpt
151 | cur_len = 0
152 | target[:cur_len] = IGNORE_INDEX
153 | for i, rou in enumerate(re_rounds):
154 | if rou == "":
155 | break
156 |
157 | parts = rou.split(sep)
158 | if len(parts) != 2:
159 | break
160 | parts[0] += sep
161 | round_len = len(self.tokenizer(rou).input_ids) + len(self.tokenizer(conv.sep).input_ids)
162 |
163 | instruction_len = len(self.tokenizer(parts[0]).input_ids)
164 | target[cur_len : cur_len + instruction_len] = IGNORE_INDEX
165 |
166 | cur_len += round_len
167 | target[cur_len:] = IGNORE_INDEX
168 |
169 | if cur_len < self.tokenizer.model_max_length:
170 | if cur_len != total_len:
171 | target[:] = IGNORE_INDEX
172 | print(
173 | f"WARNING: tokenization mismatch: {cur_len} vs. {total_len}."
174 | f" (ignored)"
175 | )
176 |
177 | return dict(
178 | input_ids=input_ids,
179 | labels=targets,
180 | )
181 |
182 | def __getitem__(self, i) -> Dict[str, torch.Tensor]:
183 | # data = self.list_data_dict[i]
184 | data = copy.deepcopy(self.list_data_dict[i])
185 |
186 | if isinstance(data, dict):
187 | if 'image' in data:
188 | image_path = self.list_image_path[i]
189 | image_file = data['image']
190 |
191 | try:
192 | image = Image.open(image_path + image_file).convert('RGB')
193 | except:
194 | print(f'cannot identify image file {image_path + image_file}.')
195 | return self.__getitem__(0)
196 |
197 | try:
198 | image, image_1 = self.image_processor(image)
199 | except:
200 |                 print(f'image {image_file} is broken or grayscale! we thus select the 0-th sample instead!')
201 | return self.__getitem__(0)
202 |
203 | conversations = self.multimodal_processor([data["conversations"]])
204 |
205 | else:
206 | conversations = [data]
207 |
208 | # align with fastchat & llava here, put the conversation into a list for tokenization
209 | data_dict = self.token_processor(conversations)
210 | data_dict = dict(input_ids=data_dict["input_ids"][0], labels=data_dict["labels"][0])
211 |
212 | if isinstance(data, dict) and 'image' in data:
213 | data_dict['image'] = [image]
214 | data_dict['image_high'] = [image_1]
215 | else:
216 | crop_size = self.multimodal_cfg['image_processor'].crop_size
217 | data_dict['image'] = [torch.zeros(3, crop_size['height'], crop_size['width'])]
218 | data_dict['image_high'] = [torch.zeros(3, 1024, 1024)]
219 | return data_dict
220 |
221 |
--------------------------------------------------------------------------------
/Vary-master/vary/demo/run_opt.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | from transformers import AutoTokenizer, AutoModelForCausalLM
3 | import torch
4 | import os
5 | from vary.utils.conversation import conv_templates, SeparatorStyle
6 | from vary.utils.utils import disable_torch_init
7 | from transformers import CLIPVisionModel, CLIPImageProcessor, StoppingCriteria
8 | from vary.model import *
9 | from vary.utils.utils import KeywordsStoppingCriteria
10 |
11 | from PIL import Image
12 |
13 | import os
14 | import requests
15 | from PIL import Image
16 | from io import BytesIO
17 |
18 | from transformers import TextStreamer
19 |
20 |
21 | from vary.model.plug.blip_process import BlipImageEvalProcessor
22 |
23 | from vary.model.vision_encoder.sam import build_sam_vit_b
24 | from vary.model.plug.transforms import train_transform, test_transform
25 | DEFAULT_IMAGE_TOKEN = "<image>"
26 | DEFAULT_IMAGE_PATCH_TOKEN = '<imgpad>'
27 | DEFAULT_IM_START_TOKEN = '<img>'
28 | DEFAULT_IM_END_TOKEN = '</img>'
29 |
30 |
31 |
32 | def load_image(image_file):
33 | if image_file.startswith('http') or image_file.startswith('https'):
34 | response = requests.get(image_file)
35 | image = Image.open(BytesIO(response.content)).convert('RGB')
36 | else:
37 | image = Image.open(image_file).convert('RGB')
38 | return image
39 |
40 |
41 | def eval_model(args):
42 | # Model
43 | disable_torch_init()
44 | model_name = os.path.expanduser(args.model_name)
45 |
46 | tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=False)
47 |
48 | model = varyOPTForCausalLM.from_pretrained(model_name)
49 |
50 |
51 |
52 | model.to(device='cuda', dtype=torch.bfloat16)
53 |
54 | image_processor_high = test_transform
55 |
56 |
57 | image_token_len = 256
58 |
59 |
60 | prompt = DEFAULT_IM_START_TOKEN + DEFAULT_IMAGE_PATCH_TOKEN * image_token_len + DEFAULT_IM_END_TOKEN
61 | inputs = tokenizer([prompt])
62 |
63 |
64 | image = load_image(args.image_file)
65 | image_1 = image.copy()
66 |
67 | image_tensor_1 = image_processor_high(image_1).to(torch.bfloat16)
68 |
69 |
70 | input_ids = torch.as_tensor(inputs.input_ids).cuda()
71 |
72 |     stop_str = '</s>'
73 | keywords = [stop_str]
74 | stopping_criteria = KeywordsStoppingCriteria(keywords, tokenizer, input_ids)
75 |
76 |
77 | streamer = TextStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
78 |
79 | with torch.autocast("cuda", dtype=torch.bfloat16):
80 | output_ids = model.generate(
81 | input_ids,
82 | images=[(image_tensor_1.unsqueeze(0).half().cuda(), image_tensor_1.unsqueeze(0).cuda())],
83 | do_sample=True,
84 | num_beams = 1,
85 | streamer=streamer,
86 | max_new_tokens=4096,
87 | stopping_criteria=[stopping_criteria]
88 | )
89 |
90 |
91 |
92 |
93 | # input_token_len = input_ids.shape[1]
94 | # outputs = tokenizer.decode(output_ids[0, input_ids.shape[1]:]).strip()
95 |
96 | # if outputs.endswith(stop_str):
97 | # outputs = outputs[:-len(stop_str)]
98 | # outputs = outputs.strip()
99 |
100 | # print(outputs)
101 |
102 |
103 | if __name__ == "__main__":
104 | parser = argparse.ArgumentParser()
105 | parser.add_argument("--model-name", type=str, default="facebook/opt-350m")
106 | parser.add_argument("--image-file", type=str, required=True)
107 | # parser.add_argument("--query", type=str, required=True)
108 | parser.add_argument("--conv-mode", type=str, default=None)
109 | args = parser.parse_args()
110 |
111 | eval_model(args)
112 |
--------------------------------------------------------------------------------
/Vary-master/vary/demo/run_qwen_vary.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | from transformers import AutoTokenizer, AutoModelForCausalLM
3 | import torch
4 | import os
5 | from vary.utils.conversation import conv_templates, SeparatorStyle
6 | from vary.utils.utils import disable_torch_init
7 | from transformers import CLIPVisionModel, CLIPImageProcessor, StoppingCriteria
8 | from vary.model import *
9 | from vary.utils.utils import KeywordsStoppingCriteria
10 |
11 | from PIL import Image
12 |
13 | import os
14 | import requests
15 | from PIL import Image
16 | from io import BytesIO
17 | from vary.model.plug.blip_process import BlipImageEvalProcessor
18 | from transformers import TextStreamer
19 | from vary.model.plug.transforms import train_transform, test_transform
20 |
21 | DEFAULT_IMAGE_TOKEN = "<image>"
22 | DEFAULT_IMAGE_PATCH_TOKEN = '<imgpad>'
23 | DEFAULT_IM_START_TOKEN = '<img>'
24 | DEFAULT_IM_END_TOKEN = '</img>'
25 |
26 |
27 | def load_image(image_file):
28 | if image_file.startswith('http') or image_file.startswith('https'):
29 | response = requests.get(image_file)
30 | image = Image.open(BytesIO(response.content)).convert('RGB')
31 | else:
32 | image = Image.open(image_file).convert('RGB')
33 | return image
34 |
35 |
36 | def eval_model(args):
37 | # Model
38 | disable_torch_init()
39 | model_name = os.path.expanduser(args.model_name)
40 |
41 | tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
42 |
43 | model = varyQwenForCausalLM.from_pretrained(model_name, low_cpu_mem_usage=True, device_map='cuda', trust_remote_code=True)
44 |
45 |
46 | model.to(device='cuda', dtype=torch.bfloat16)
47 |
48 |
49 | image_processor = CLIPImageProcessor.from_pretrained("/data/hypertext/ucaswei/cache/vit-large-patch14/vit-large-patch14/", torch_dtype=torch.float16)
50 |
51 | image_processor_high = BlipImageEvalProcessor(image_size=1024)
52 |
53 | use_im_start_end = True
54 |
55 | image_token_len = 256
56 |
57 | qs = 'Provide the ocr results of this image.'
58 |
59 | if use_im_start_end:
60 | qs = DEFAULT_IM_START_TOKEN + DEFAULT_IMAGE_PATCH_TOKEN*image_token_len + DEFAULT_IM_END_TOKEN + '\n' + qs
61 | else:
62 | qs = DEFAULT_IMAGE_TOKEN + '\n' + qs
63 |
64 |
65 |
66 |
67 | conv_mode = "mpt"
68 | args.conv_mode = conv_mode
69 |
70 | conv = conv_templates[args.conv_mode].copy()
71 | conv.append_message(conv.roles[0], qs)
72 | conv.append_message(conv.roles[1], None)
73 | prompt = conv.get_prompt()
74 |
75 |
76 | inputs = tokenizer([prompt])
77 |
78 |
79 | image = load_image(args.image_file)
80 | image_1 = image.copy()
81 | image_tensor = image_processor.preprocess(image, return_tensors='pt')['pixel_values'][0]
82 |
83 | image_tensor_1 = image_processor_high(image_1)
84 |
85 | input_ids = torch.as_tensor(inputs.input_ids).cuda()
86 |
87 | # stop_str = conv.sep if conv.sep_style != SeparatorStyle.TWO else conv.sep2
88 | stop_str = conv.sep if conv.sep_style != SeparatorStyle.TWO else conv.sep2
89 | keywords = [stop_str]
90 | stopping_criteria = KeywordsStoppingCriteria(keywords, tokenizer, input_ids)
91 | streamer = TextStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
92 |
93 |
94 | with torch.autocast("cuda", dtype=torch.bfloat16):
95 | output_ids = model.generate(
96 | input_ids,
97 | images=[(image_tensor.unsqueeze(0).half().cuda(), image_tensor_1.unsqueeze(0).half().cuda())],
98 | do_sample=True,
99 | num_beams = 1,
100 | # temperature=0.2,
101 | streamer=streamer,
102 | max_new_tokens=2048,
103 | stopping_criteria=[stopping_criteria]
104 | )
105 |
106 | # print(output_ids)
107 |
108 | # outputs = tokenizer.decode(output_ids[0, input_ids.shape[1]:]).strip()
109 |
110 | # # conv.messages[-1][-1] = outputs
111 | # if outputs.endswith(stop_str):
112 | # outputs = outputs[:-len(stop_str)]
113 | # outputs = outputs.strip()
114 |
115 | # print(outputs)
116 |
117 |
118 |
119 | if __name__ == "__main__":
120 | parser = argparse.ArgumentParser()
121 | parser.add_argument("--model-name", type=str, default="facebook/opt-350m")
122 | parser.add_argument("--image-file", type=str, required=True)
123 | parser.add_argument("--conv-mode", type=str, default=None)
124 | args = parser.parse_args()
125 |
126 | eval_model(args)
127 |
--------------------------------------------------------------------------------
/Vary-master/vary/model/__init__.py:
--------------------------------------------------------------------------------
1 | from .vary_opt import varyOPTModel, varyOPTForCausalLM
2 | # from .vary_qwen_vary import varyQwenModel, varyQwenForCausalLM, varyConfig
3 | from .vary_toy_qwen1_8 import varyQwenModel, varyQwenForCausalLM, varyConfig
4 |
5 |
--------------------------------------------------------------------------------
/Vary-master/vary/model/llm/qwen/configuration_qwen.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Alibaba Cloud.
2 | #
3 | # This source code is licensed under the license found in the
4 | # LICENSE file in the root directory of this source tree.
5 |
6 | from transformers import PretrainedConfig
7 |
8 |
9 | class QWenConfig(PretrainedConfig):
10 | model_type = "qwen"
11 | keys_to_ignore_at_inference = ["past_key_values"]
12 |
13 | def __init__(
14 | self,
15 | vocab_size=151936,
16 | hidden_size=4096,
17 | num_hidden_layers=32,
18 | num_attention_heads=32,
19 | emb_dropout_prob=0.0,
20 | attn_dropout_prob=0.0,
21 | layer_norm_epsilon=1e-6,
22 | initializer_range=0.02,
23 | max_position_embeddings=8192,
24 | scale_attn_weights=True,
25 | use_cache=True,
26 | bf16=False,
27 | fp16=False,
28 | fp32=False,
29 | kv_channels=128,
30 | rotary_pct=1.0,
31 | rotary_emb_base=10000,
32 | use_dynamic_ntk=True,
33 | use_logn_attn=True,
34 | use_flash_attn="auto",
35 | intermediate_size=22016,
36 | no_bias=True,
37 | tie_word_embeddings=False,
38 | **kwargs,
39 | ):
40 | self.vocab_size = vocab_size
41 | self.hidden_size = hidden_size
42 | self.intermediate_size = intermediate_size
43 | self.num_hidden_layers = num_hidden_layers
44 | self.num_attention_heads = num_attention_heads
45 | self.emb_dropout_prob = emb_dropout_prob
46 | self.attn_dropout_prob = attn_dropout_prob
47 | self.layer_norm_epsilon = layer_norm_epsilon
48 | self.initializer_range = initializer_range
49 | self.scale_attn_weights = scale_attn_weights
50 | self.use_cache = use_cache
51 | self.max_position_embeddings = max_position_embeddings
52 | self.bf16 = bf16
53 | self.fp16 = fp16
54 | self.fp32 = fp32
55 | self.kv_channels = kv_channels
56 | self.rotary_pct = rotary_pct
57 | self.rotary_emb_base = rotary_emb_base
58 | self.use_dynamic_ntk = use_dynamic_ntk
59 | self.use_logn_attn = use_logn_attn
60 | self.use_flash_attn = use_flash_attn
61 | self.no_bias = no_bias
62 | super().__init__(
63 | tie_word_embeddings=tie_word_embeddings,
64 | **kwargs
65 | )
66 |
--------------------------------------------------------------------------------
/Vary-master/vary/model/llm/qwen/qwen_generation_utils.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Alibaba Cloud.
2 | #
3 | # This source code is licensed under the license found in the
4 | # LICENSE file in the root directory of this source tree.
5 |
6 | """Generation support."""
7 |
8 | from typing import Tuple, List, Union, Iterable
9 |
10 | import numpy as np
11 | import torch
12 | import torch.nn.functional as F
13 | from transformers import PreTrainedTokenizer
14 | from transformers import logging
15 | from transformers.generation import LogitsProcessor
16 |
17 | logger = logging.get_logger(__name__)
18 |
19 | # Types.
20 | HistoryType = List[Tuple[str, str]]
21 | TokensType = List[int]
22 | BatchTokensType = List[List[int]]
23 |
24 |
25 | def pad_batch(batch: BatchTokensType, pad_id: int, seq_length: int) -> BatchTokensType:
26 | for tokens in batch:
27 | context_length = len(tokens)
28 | if context_length < seq_length:
29 | tokens.extend([pad_id] * (seq_length - context_length))
30 | return batch
31 |
32 |
33 | def get_ltor_masks_and_position_ids(
34 | data,
35 | eod_token,
36 | reset_position_ids,
37 | reset_attention_mask,
38 | eod_mask_loss,
39 | ):
40 | """Build masks and position id for left to right model."""
41 |
42 | # Extract batch size and sequence length.
43 | micro_batch_size, seq_length = data.size()
44 |
45 | # Attention mask (lower triangular).
46 | if reset_attention_mask:
47 | att_mask_batch = micro_batch_size
48 | else:
49 | att_mask_batch = 1
50 | attention_mask = torch.tril(
51 | torch.ones((att_mask_batch, seq_length, seq_length), device=data.device)
52 | ).view(att_mask_batch, 1, seq_length, seq_length)
53 |
54 | # Loss mask.
55 | loss_mask = torch.ones(data.size(), dtype=torch.float, device=data.device)
56 | if eod_mask_loss:
57 | loss_mask[data == eod_token] = 0.0
58 |
59 | # Position ids.
60 | position_ids = torch.arange(seq_length, dtype=torch.long, device=data.device)
61 | position_ids = position_ids.unsqueeze(0).expand_as(data)
62 |     # We need to clone as the ids will be modified based on batch index.
63 | if reset_position_ids:
64 | position_ids = position_ids.clone()
65 |
66 | if reset_position_ids or reset_attention_mask:
67 | # Loop through the batches:
68 | for b in range(micro_batch_size):
69 |
70 |             # Find indices where the EOD token is.
71 | eod_index = position_ids[b, data[b] == eod_token]
72 |             # Detach indices from positions if going to modify positions.
73 | if reset_position_ids:
74 | eod_index = eod_index.clone()
75 |
76 |             # Loop through EOD indices:
77 | prev_index = 0
78 | for j in range(eod_index.size()[0]):
79 | i = eod_index[j]
80 | # Mask attention loss.
81 | if reset_attention_mask:
82 | attention_mask[b, 0, (i + 1) :, : (i + 1)] = 0
83 | # Reset positions.
84 | if reset_position_ids:
85 | position_ids[b, (i + 1) :] -= i + 1 - prev_index
86 | prev_index = i + 1
87 |
88 | # Convert attention mask to binary:
89 | attention_mask = attention_mask < 0.5
90 |
91 | return attention_mask, loss_mask, position_ids
92 |
93 |
94 | def get_batch(context_tokens: torch.LongTensor, eod_id: int):
95 | """Generate batch from context tokens."""
96 | # Move to GPU.
97 | tokens = context_tokens.contiguous().to(context_tokens.device)
98 |     # Get the attention mask and position ids.
99 | attention_mask, _, position_ids = get_ltor_masks_and_position_ids(
100 | tokens,
101 | eod_id,
102 | reset_position_ids=False,
103 | reset_attention_mask=False,
104 | eod_mask_loss=False,
105 | )
106 | return tokens, attention_mask, position_ids
107 |
108 |
109 | def get_stop_words_ids(chat_format, tokenizer):
110 | if chat_format == "raw":
111 | stop_words_ids = [tokenizer.encode("Human:"), [tokenizer.eod_id]]
112 | elif chat_format == "chatml":
113 | stop_words_ids = [[tokenizer.im_end_id], [tokenizer.im_start_id]]
114 | else:
115 | raise NotImplementedError(f"Unknown chat format {chat_format!r}")
116 | return stop_words_ids
117 |
118 |
119 | def make_context(
120 | tokenizer: PreTrainedTokenizer,
121 | query: str,
122 | history: List[Tuple[str, str]] = None,
123 | system: str = "",
124 | max_window_size: int = 6144,
125 | chat_format: str = "chatml",
126 | ):
127 | if history is None:
128 | history = []
129 |
130 | if chat_format == "chatml":
131 | im_start, im_end = "<|im_start|>", "<|im_end|>"
132 | im_start_tokens = [tokenizer.im_start_id]
133 | im_end_tokens = [tokenizer.im_end_id]
134 | nl_tokens = tokenizer.encode("\n")
135 |
136 | def _tokenize_str(role, content):
137 | return f"{role}\n{content}", tokenizer.encode(
138 | role, allowed_special=set(tokenizer.IMAGE_ST)
139 | ) + nl_tokens + tokenizer.encode(content, allowed_special=set(tokenizer.IMAGE_ST))
140 |
141 | system_text, system_tokens_part = _tokenize_str("system", system)
142 | system_tokens = im_start_tokens + system_tokens_part + im_end_tokens
143 |
144 | raw_text = ""
145 | context_tokens = []
146 |
147 | for turn_query, turn_response in reversed(history):
148 | query_text, query_tokens_part = _tokenize_str("user", turn_query)
149 | query_tokens = im_start_tokens + query_tokens_part + im_end_tokens
150 | if turn_response is not None:
151 | response_text, response_tokens_part = _tokenize_str(
152 | "assistant", turn_response
153 | )
154 | response_tokens = im_start_tokens + response_tokens_part + im_end_tokens
155 |
156 | next_context_tokens = nl_tokens + query_tokens + nl_tokens + response_tokens
157 | prev_chat = (
158 | f"\n{im_start}{query_text}{im_end}\n{im_start}{response_text}{im_end}"
159 | )
160 | else:
161 | next_context_tokens = nl_tokens + query_tokens + nl_tokens
162 | prev_chat = f"\n{im_start}{query_text}{im_end}\n"
163 |
164 | current_context_size = (
165 | len(system_tokens) + len(next_context_tokens) + len(context_tokens)
166 | )
167 | if current_context_size < max_window_size:
168 | context_tokens = next_context_tokens + context_tokens
169 | raw_text = prev_chat + raw_text
170 | else:
171 | break
172 |
173 | context_tokens = system_tokens + context_tokens
174 | raw_text = f"{im_start}{system_text}{im_end}" + raw_text
175 | context_tokens += (
176 | nl_tokens
177 | + im_start_tokens
178 | + _tokenize_str("user", query)[1]
179 | + im_end_tokens
180 | + nl_tokens
181 | + im_start_tokens
182 | + tokenizer.encode("assistant")
183 | + nl_tokens
184 | )
185 | raw_text += f"\n{im_start}user\n{query}{im_end}\n{im_start}assistant\n"
186 |
187 | elif chat_format == "raw":
188 | raw_text = query
189 | context_tokens = tokenizer.encode(raw_text)
190 | else:
191 | raise NotImplementedError(f"Unknown chat format {chat_format!r}")
192 |
193 | return raw_text, context_tokens
194 |
195 |
196 | def _decode_default(
197 | tokens: List[int],
198 | *,
199 | stop_words: List[str],
200 | eod_words: List[str],
201 | tokenizer: PreTrainedTokenizer,
202 | raw_text_len: int,
203 | verbose: bool = False,
204 | return_end_reason: bool = False,
205 | errors: str='replace',
206 | ):
207 | trim_decode_tokens = tokenizer.decode(tokens, errors=errors)[raw_text_len:]
208 | if verbose:
209 | print("\nRaw Generate: ", trim_decode_tokens)
210 |
211 | end_reason = f"Gen length {len(tokens)}"
212 | for stop_word in stop_words:
213 | trim_decode_tokens = trim_decode_tokens.replace(stop_word, "").strip()
214 | for eod_word in eod_words:
215 | if eod_word in trim_decode_tokens:
216 | end_reason = f"Gen {eod_word!r}"
217 | trim_decode_tokens = trim_decode_tokens.split(eod_word)[0]
218 | trim_decode_tokens = trim_decode_tokens.strip()
219 | if verbose:
220 | print("\nEnd Reason:", end_reason)
221 | print("\nGenerate: ", trim_decode_tokens)
222 |
223 | if return_end_reason:
224 | return trim_decode_tokens, end_reason
225 | else:
226 | return trim_decode_tokens
227 |
228 |
229 | def _decode_chatml(
230 | tokens: List[int],
231 | *,
232 | stop_words: List[str],
233 | eod_token_ids: List[int],
234 | tokenizer: PreTrainedTokenizer,
235 | raw_text_len: int,
236 | context_length: int,
237 | verbose: bool = False,
238 | return_end_reason: bool = False,
239 | errors: str='replace'
240 | ):
241 | end_reason = f"Gen length {len(tokens)}"
242 | eod_token_idx = context_length
243 | for eod_token_idx in range(context_length, len(tokens)):
244 | if tokens[eod_token_idx] in eod_token_ids:
245 | end_reason = f"Gen {tokenizer.decode([tokens[eod_token_idx]])!r}"
246 | break
247 |
248 | trim_decode_tokens = tokenizer.decode(tokens[:eod_token_idx], errors=errors)[raw_text_len:]
249 | if verbose:
250 | print("\nRaw Generate w/o EOD:", tokenizer.decode(tokens, errors=errors)[raw_text_len:])
251 | print("\nRaw Generate:", trim_decode_tokens)
252 | print("\nEnd Reason:", end_reason)
253 | for stop_word in stop_words:
254 | trim_decode_tokens = trim_decode_tokens.replace(stop_word, "").strip()
255 | trim_decode_tokens = trim_decode_tokens.strip()
256 | if verbose:
257 | print("\nGenerate:", trim_decode_tokens)
258 |
259 | if return_end_reason:
260 | return trim_decode_tokens, end_reason
261 | else:
262 | return trim_decode_tokens
263 |
264 |
265 | def decode_tokens(
266 | tokens: Union[torch.LongTensor, TokensType],
267 | tokenizer: PreTrainedTokenizer,
268 | raw_text_len: int,
269 | context_length: int,
270 | chat_format: str,
271 | verbose: bool = False,
272 | return_end_reason: bool = False,
273 | errors: str="replace",
274 | ) -> str:
275 | if torch.is_tensor(tokens):
276 | tokens = tokens.cpu().numpy().tolist()
277 |
278 | if chat_format == "chatml":
279 | return _decode_chatml(
280 | tokens,
281 | stop_words=[],
282 | eod_token_ids=[tokenizer.im_start_id, tokenizer.im_end_id],
283 | tokenizer=tokenizer,
284 | raw_text_len=raw_text_len,
285 | context_length=context_length,
286 | verbose=verbose,
287 | return_end_reason=return_end_reason,
288 | errors=errors,
289 | )
290 | elif chat_format == "raw":
291 | return _decode_default(
292 | tokens,
293 | stop_words=["<|endoftext|>"],
294 | eod_words=["<|endoftext|>"],
295 | tokenizer=tokenizer,
296 | raw_text_len=raw_text_len,
297 | verbose=verbose,
298 | return_end_reason=return_end_reason,
299 | errors=errors,
300 | )
301 | else:
302 | raise NotImplementedError(f"Unknown chat format {chat_format!r}")
303 |
304 |
305 | class StopWordsLogitsProcessor(LogitsProcessor):
306 | """
306 |     :class:`transformers.LogitsProcessor` that stops generation when the specified sequences appear.
308 |
309 | Args:
310 | stop_words_ids (:obj:`List[List[int]]`):
311 | List of list of token ids of stop ids. In order to get the tokens of the words
312 | that should not appear in the generated text, use :obj:`tokenizer(bad_word,
313 | add_prefix_space=True).input_ids`.
314 | eos_token_id (:obj:`int`):
315 | The id of the `end-of-sequence` token.
316 | """
317 |
318 | def __init__(self, stop_words_ids: Iterable[Iterable[int]], eos_token_id: int):
319 |
320 | if not isinstance(stop_words_ids, List) or len(stop_words_ids) == 0:
321 | raise ValueError(
322 | f"`stop_words_ids` has to be a non-emtpy list, but is {stop_words_ids}."
323 | )
324 | if any(not isinstance(bad_word_ids, list) for bad_word_ids in stop_words_ids):
325 | raise ValueError(
326 | f"`stop_words_ids` has to be a list of lists, but is {stop_words_ids}."
327 | )
328 | if any(
329 | any(
330 | (not isinstance(token_id, (int, np.integer)) or token_id < 0)
331 | for token_id in stop_word_ids
332 | )
333 | for stop_word_ids in stop_words_ids
334 | ):
335 | raise ValueError(
336 | f"Each list in `stop_words_ids` has to be a list of positive integers, but is {stop_words_ids}."
337 | )
338 |
339 | self.stop_words_ids = list(
340 | filter(
341 | lambda bad_token_seq: bad_token_seq != [eos_token_id], stop_words_ids
342 | )
343 | )
344 | self.eos_token_id = eos_token_id
345 | for stop_token_seq in self.stop_words_ids:
346 | assert (
347 | len(stop_token_seq) > 0
348 | ), "Stop words token sequences {} cannot have an empty list".format(
349 | stop_words_ids
350 | )
351 |
352 | def __call__(
353 | self, input_ids: torch.LongTensor, scores: torch.FloatTensor
354 | ) -> torch.FloatTensor:
355 | stopped_samples = self._calc_stopped_samples(input_ids)
356 | for i, should_stop in enumerate(stopped_samples):
357 | if should_stop:
358 | scores[i, self.eos_token_id] = float(2**15)
359 | return scores
360 |
361 | def _tokens_match(self, prev_tokens: torch.LongTensor, tokens: List[int]) -> bool:
362 | if len(tokens) == 0:
363 |             # an empty stop sequence matches trivially
364 | return True
365 | elif len(tokens) > len(prev_tokens):
366 |             # if the stop-word tokens are longer than prev input_ids they can't be equal
367 | return False
368 | elif prev_tokens[-len(tokens) :].tolist() == tokens:
369 | # if tokens match
370 | return True
371 | else:
372 | return False
373 |
374 | def _calc_stopped_samples(self, prev_input_ids: Iterable[int]) -> Iterable[int]:
375 | stopped_samples = []
376 | for prev_input_ids_slice in prev_input_ids:
377 | match = False
378 | for stop_token_seq in self.stop_words_ids:
379 | if self._tokens_match(prev_input_ids_slice, stop_token_seq):
380 |                     # the stop sequence matches the tail of prev_input_ids; mark this sample as stopped
381 | match = True
382 | break
383 | stopped_samples.append(match)
384 |
385 | return stopped_samples
386 |
387 |
388 | def top_k_logits(logits, top_k=0, top_p=0.0, filter_value=-float("Inf")):
389 | """This function has been mostly taken from huggingface conversational
390 | ai code at
391 | https://medium.com/huggingface/how-to-build-a-state-of-the-art-
392 | conversational-ai-with-transfer-learning-2d818ac26313"""
393 |
394 | if top_k > 0:
395 | # Remove all tokens with a probability less than the
396 | # last token of the top-k
397 | indices_to_remove = logits < torch.topk(logits, top_k)[0][..., -1, None]
398 | logits[indices_to_remove] = filter_value
399 |
400 | if top_p > 0.0:
401 |         # Sort logits in descending order along the last dimension
402 | sorted_logits, sorted_indices = torch.sort(logits, descending=True, dim=-1)
403 | cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)
404 |
405 | # Remove tokens with cumulative probability above the threshold
406 | sorted_indices_to_remove = cumulative_probs > top_p
407 | # Shift the indices to the right to keep also the first token
408 | # above the threshold
409 | sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
410 | sorted_indices_to_remove[..., 0] = 0
411 | for i in range(sorted_indices.size(0)):
412 | indices_to_remove = sorted_indices[i][sorted_indices_to_remove[i]]
413 | logits[i][indices_to_remove] = filter_value
414 |
415 | return logits
416 |
417 |
418 | def switch(val1, val2, boolean):
419 | boolean = boolean.type_as(val1)
420 | return (1 - boolean) * val1 + boolean * val2
421 |
--------------------------------------------------------------------------------
/Vary-master/vary/model/llm/qwen/tokenization_qwen.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Alibaba Cloud.
2 | #
3 | # This source code is licensed under the license found in the
4 | # LICENSE file in the root directory of this source tree.
5 |
6 | """Tokenization classes for QWen."""
7 |
8 | import base64
9 | import logging
10 | import os
11 | import unicodedata
12 | from typing import Collection, Dict, List, Set, Tuple, Union
13 |
14 | import tiktoken
15 | from transformers import PreTrainedTokenizer, AddedToken
16 |
17 | logger = logging.getLogger(__name__)
18 |
19 |
20 | VOCAB_FILES_NAMES = {"vocab_file": "qwen.tiktoken"}
21 |
22 | PAT_STR = r"""(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+"""
23 | ENDOFTEXT = "<|endoftext|>"
24 | IMSTART = "<|im_start|>"
25 | IMEND = "<|im_end|>"
26 | # as the default behavior is changed to allow special tokens in
27 | # regular texts, the surface forms of special tokens need to be
28 | # as different as possible to minimize the impact
29 | EXTRAS = tuple((f"<|extra_{i}|>" for i in range(205)))
30 | SPECIAL_TOKENS = (
31 | ENDOFTEXT,
32 | IMSTART,
33 | IMEND,
34 | ) + EXTRAS
35 |
36 |
37 | def _load_tiktoken_bpe(tiktoken_bpe_file: str) -> Dict[bytes, int]:
38 | with open(tiktoken_bpe_file, "rb") as f:
39 | contents = f.read()
40 | return {
41 | base64.b64decode(token): int(rank)
42 | for token, rank in (line.split() for line in contents.splitlines() if line)
43 | }
44 |
45 | class QWenTokenizer(PreTrainedTokenizer):
46 | """QWen tokenizer."""
47 |
48 | vocab_files_names = VOCAB_FILES_NAMES
49 |
50 | def __init__(
51 | self,
52 | vocab_file,
53 | errors="replace",
54 | **kwargs,
55 | ):
56 | super().__init__(**kwargs)
57 |
58 | self.errors = errors # how to handle errors in decoding
59 |
60 | self.mergeable_ranks = _load_tiktoken_bpe(vocab_file) # type: dict[bytes, int]
61 | self.special_tokens = {
62 | token: index
63 | for index, token in enumerate(
64 | SPECIAL_TOKENS, start=len(self.mergeable_ranks)
65 | )
66 | }
67 |
68 | enc = tiktoken.Encoding(
69 | "Qwen",
70 | pat_str=PAT_STR,
71 | mergeable_ranks=self.mergeable_ranks,
72 | special_tokens=self.special_tokens,
73 | )
74 | assert (
75 | len(self.mergeable_ranks) + len(self.special_tokens) == enc.n_vocab
76 | ), f"{len(self.mergeable_ranks) + len(self.special_tokens)} != {enc.n_vocab} in encoding"
77 |
78 | self.decoder = {
79 | v: k for k, v in self.mergeable_ranks.items()
80 | } # type: dict[int, bytes|str]
81 | self.decoder.update({v: k for k, v in self.special_tokens.items()})
82 |
83 | self.tokenizer = enc # type: tiktoken.Encoding
84 |
85 | self.eod_id = self.tokenizer.eot_token
86 | self.im_start_id = self.special_tokens[IMSTART]
87 | self.im_end_id = self.special_tokens[IMEND]
88 |
89 | def __len__(self) -> int:
90 | return self.tokenizer.n_vocab
91 |
92 | def get_vocab(self) -> Dict[bytes, int]:
93 | return self.mergeable_ranks
94 |
95 | def convert_tokens_to_ids(
96 | self, tokens: Union[bytes, str, List[Union[bytes, str]]]
97 | ) -> List[int]:
98 | ids = []
99 | if isinstance(tokens, (str, bytes)):
100 | if tokens in self.special_tokens:
101 | return self.special_tokens[tokens]
102 | else:
103 | return self.mergeable_ranks.get(tokens)
104 | for token in tokens:
105 | if token in self.special_tokens:
106 | ids.append(self.special_tokens[token])
107 | else:
108 | ids.append(self.mergeable_ranks.get(token))
109 | return ids
110 |
111 | def _add_tokens(self, new_tokens: Union[List[str], List[AddedToken]], special_tokens: bool = False) -> int:
112 | if not special_tokens and new_tokens:
113 | raise ValueError('Adding regular tokens is not supported')
114 | for token in new_tokens:
115 | surface_form = token.content if isinstance(token, AddedToken) else token
116 | if surface_form not in SPECIAL_TOKENS:
117 | raise ValueError('Adding unknown special tokens is not supported')
118 | return 0
119 |
120 | def save_vocabulary(self, save_directory: str, **kwargs) -> Tuple[str]:
121 | """
122 |         Save only the vocabulary of the tokenizer (the BPE ranks, without special tokens).
123 |
124 | Returns:
125 | `Tuple(str)`: Paths to the files saved.
126 | """
127 | file_path = os.path.join(save_directory, "qwen.tiktoken")
128 | with open(file_path, "w", encoding="utf8") as w:
129 | for k, v in self.mergeable_ranks.items():
130 | line = base64.b64encode(k).decode("utf8") + " " + str(v) + "\n"
131 | w.write(line)
132 | return (file_path,)
133 |
134 | def tokenize(
135 | self,
136 | text: str,
137 | allowed_special: Union[Set, str] = "all",
138 | disallowed_special: Union[Collection, str] = (),
139 | **kwargs,
140 | ) -> List[Union[bytes, str]]:
141 | """
142 |         Converts a string into a sequence of tokens.
143 |
144 | Args:
145 | text (`str`):
146 | The sequence to be encoded.
147 | allowed_special (`Literal["all"]` or `set`):
148 | The surface forms of the tokens to be encoded as special tokens in regular texts.
149 |                 Defaults to "all".
150 | disallowed_special (`Literal["all"]` or `Collection`):
151 | The surface forms of the tokens that should not be in regular texts and trigger errors.
152 |                 Defaults to an empty tuple.
153 |
154 | kwargs (additional keyword arguments, *optional*):
155 | Will be passed to the underlying model specific encode method.
156 |
157 | Returns:
158 | `List[bytes|str]`: The list of tokens.
159 | """
160 | tokens = []
161 | text = unicodedata.normalize("NFC", text)
162 |
163 | # this implementation takes a detour: text -> token id -> token surface forms
164 | for t in self.tokenizer.encode(
165 | text, allowed_special=allowed_special, disallowed_special=disallowed_special
166 | ):
167 | tokens.append(self.decoder[t])
168 | return tokens
169 |
170 | def convert_tokens_to_string(self, tokens: List[Union[bytes, str]]) -> str:
171 | """
172 |         Converts a sequence of tokens into a single string.
173 | """
174 | text = ""
175 | temp = b""
176 | for t in tokens:
177 | if isinstance(t, str):
178 | if temp:
179 | text += temp.decode("utf-8", errors=self.errors)
180 | temp = b""
181 | text += t
182 | elif isinstance(t, bytes):
183 | temp += t
184 | else:
185 |                 raise TypeError("token should only be of type bytes or str")
186 | if temp:
187 | text += temp.decode("utf-8", errors=self.errors)
188 | return text
189 |
190 | @property
191 | def vocab_size(self):
192 | return self.tokenizer.n_vocab
193 |
194 | def _convert_id_to_token(self, index: int) -> Union[bytes, str]:
195 | """Converts an id to a token, special tokens included"""
196 | if index in self.decoder:
197 | return self.decoder[index]
198 | raise ValueError("unknown ids")
199 |
200 | def _convert_token_to_id(self, token: Union[bytes, str]) -> int:
201 | """Converts a token to an id using the vocab, special tokens included"""
202 | if token in self.special_tokens:
203 | return self.special_tokens[token]
204 | if token in self.mergeable_ranks:
205 | return self.mergeable_ranks[token]
206 | raise ValueError("unknown token")
207 |
208 | def _tokenize(self, text: str, **kwargs):
209 | """
210 |         Converts a string into a sequence of tokens (string), using the tokenizer. Splits into words for word-based
211 |         vocabularies or sub-words for sub-word-based vocabularies (BPE/SentencePiece/WordPiece).
212 |
213 | Do NOT take care of added tokens.
214 | """
215 | raise NotImplementedError
216 |
217 | def _decode(
218 | self,
219 | token_ids: Union[int, List[int]],
220 | skip_special_tokens: bool = False,
221 | errors: str = None,
222 | **kwargs,
223 | ) -> str:
224 | if isinstance(token_ids, int):
225 | token_ids = [token_ids]
226 | if skip_special_tokens:
227 | token_ids = [i for i in token_ids if i < self.eod_id]
228 | return self.tokenizer.decode(token_ids, errors=errors or self.errors)
229 |
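A brief usage sketch for the tokenizer above (the checkpoint path is a placeholder; it assumes a local Qwen checkpoint directory containing qwen.tiktoken and a transformers version compatible with this implementation):

```python
from vary.model.llm.qwen.tokenization_qwen import QWenTokenizer

# Placeholder path: a local Qwen checkpoint directory that contains qwen.tiktoken.
tokenizer = QWenTokenizer.from_pretrained("/path/to/qwen-checkpoint")

text = "<|im_start|>user\nDescribe this page.<|im_end|>"
ids = tokenizer.encode(text)     # special tokens are allowed by default (allowed_special="all")
print(len(tokenizer), ids[:5])   # vocab size (BPE ranks + special tokens) and the first few ids

# decode() keeps special tokens unless skip_special_tokens=True,
# which drops every id >= eod_id (i.e. all special-token ids).
print(tokenizer.decode(ids))
```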
--------------------------------------------------------------------------------
/Vary-master/vary/model/plug/blip_process.py:
--------------------------------------------------------------------------------
1 | """
2 | Copyright (c) 2022, salesforce.com, inc.
3 | All rights reserved.
4 | SPDX-License-Identifier: BSD-3-Clause
5 | For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause
6 | """
7 |
8 | import cv2
9 | import numpy as np
10 |
11 | import torch
12 |
13 | # from omegaconf import OmegaConf
14 | from torchvision import transforms
15 | from torchvision.transforms.functional import InterpolationMode
16 | from PIL import Image
17 |
18 | class BaseProcessor:
19 | def __init__(self):
20 | self.transform = lambda x: x
21 | return
22 |
23 | def __call__(self, item):
24 | return self.transform(item)
25 |
26 | # @classmethod
27 | # def from_config(cls, cfg=None):
28 | # return cls()
29 |
30 | # def build(self, **kwargs):
31 | # cfg = OmegaConf.create(kwargs)
32 |
33 | # return self.from_config(cfg)
34 |
35 | class BlipImageBaseProcessor(BaseProcessor):
36 | def __init__(self, mean=None, std=None):
37 | if mean is None:
38 | mean = (0.48145466, 0.4578275, 0.40821073)
39 | if std is None:
40 | std = (0.26862954, 0.26130258, 0.27577711)
41 |
42 | self.normalize = transforms.Normalize(mean, std)
43 |
44 |
45 | ## aug functions
46 | def identity_func(img):
47 | return img
48 |
49 |
50 | def autocontrast_func(img, cutoff=0):
51 | """
52 | same output as PIL.ImageOps.autocontrast
53 | """
54 | n_bins = 256
55 |
56 | def tune_channel(ch):
57 | n = ch.size
58 | cut = cutoff * n // 100
59 | if cut == 0:
60 | high, low = ch.max(), ch.min()
61 | else:
62 | hist = cv2.calcHist([ch], [0], None, [n_bins], [0, n_bins])
63 | low = np.argwhere(np.cumsum(hist) > cut)
64 | low = 0 if low.shape[0] == 0 else low[0]
65 | high = np.argwhere(np.cumsum(hist[::-1]) > cut)
66 | high = n_bins - 1 if high.shape[0] == 0 else n_bins - 1 - high[0]
67 | if high <= low:
68 | table = np.arange(n_bins)
69 | else:
70 | scale = (n_bins - 1) / (high - low)
71 | offset = -low * scale
72 | table = np.arange(n_bins) * scale + offset
73 | table[table < 0] = 0
74 | table[table > n_bins - 1] = n_bins - 1
75 | table = table.clip(0, 255).astype(np.uint8)
76 | return table[ch]
77 |
78 | channels = [tune_channel(ch) for ch in cv2.split(img)]
79 | out = cv2.merge(channels)
80 | return out
81 |
82 |
83 | def equalize_func(img):
84 | """
85 | same output as PIL.ImageOps.equalize
86 | PIL's implementation is different from cv2.equalize
87 | """
88 | n_bins = 256
89 |
90 | def tune_channel(ch):
91 | hist = cv2.calcHist([ch], [0], None, [n_bins], [0, n_bins])
92 | non_zero_hist = hist[hist != 0].reshape(-1)
93 | step = np.sum(non_zero_hist[:-1]) // (n_bins - 1)
94 | if step == 0:
95 | return ch
96 | n = np.empty_like(hist)
97 | n[0] = step // 2
98 | n[1:] = hist[:-1]
99 | table = (np.cumsum(n) // step).clip(0, 255).astype(np.uint8)
100 | return table[ch]
101 |
102 | channels = [tune_channel(ch) for ch in cv2.split(img)]
103 | out = cv2.merge(channels)
104 | return out
105 |
106 |
107 | def rotate_func(img, degree, fill=(0, 0, 0)):
108 | """
109 |     like PIL, rotate by degrees, not radians
110 | """
111 | H, W = img.shape[0], img.shape[1]
112 | center = W / 2, H / 2
113 | M = cv2.getRotationMatrix2D(center, degree, 1)
114 | out = cv2.warpAffine(img, M, (W, H), borderValue=fill)
115 | return out
116 |
117 |
118 | def solarize_func(img, thresh=128):
119 | """
120 |     same output as PIL.ImageOps.solarize
121 | """
122 | table = np.array([el if el < thresh else 255 - el for el in range(256)])
123 | table = table.clip(0, 255).astype(np.uint8)
124 | out = table[img]
125 | return out
126 |
127 |
128 | def color_func(img, factor):
129 | """
130 | same output as PIL.ImageEnhance.Color
131 | """
132 | ## implementation according to PIL definition, quite slow
133 | # degenerate = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)[:, :, np.newaxis]
134 | # out = blend(degenerate, img, factor)
135 | # M = (
136 | # np.eye(3) * factor
137 | # + np.float32([0.114, 0.587, 0.299]).reshape(3, 1) * (1. - factor)
138 | # )[np.newaxis, np.newaxis, :]
139 | M = np.float32(
140 | [[0.886, -0.114, -0.114], [-0.587, 0.413, -0.587], [-0.299, -0.299, 0.701]]
141 | ) * factor + np.float32([[0.114], [0.587], [0.299]])
142 | out = np.matmul(img, M).clip(0, 255).astype(np.uint8)
143 | return out
144 |
145 |
146 | def contrast_func(img, factor):
147 | """
148 | same output as PIL.ImageEnhance.Contrast
149 | """
150 | mean = np.sum(np.mean(img, axis=(0, 1)) * np.array([0.114, 0.587, 0.299]))
151 | table = (
152 | np.array([(el - mean) * factor + mean for el in range(256)])
153 | .clip(0, 255)
154 | .astype(np.uint8)
155 | )
156 | out = table[img]
157 | return out
158 |
159 |
160 | def brightness_func(img, factor):
161 | """
162 |     same output as PIL.ImageEnhance.Brightness
163 | """
164 | table = (np.arange(256, dtype=np.float32) * factor).clip(0, 255).astype(np.uint8)
165 | out = table[img]
166 | return out
167 |
168 |
169 | def sharpness_func(img, factor):
170 | """
171 |     The differences between this result and PIL are all on the 4 boundaries; the center
172 |     areas are the same
173 | """
174 | kernel = np.ones((3, 3), dtype=np.float32)
175 | kernel[1][1] = 5
176 | kernel /= 13
177 | degenerate = cv2.filter2D(img, -1, kernel)
178 | if factor == 0.0:
179 | out = degenerate
180 | elif factor == 1.0:
181 | out = img
182 | else:
183 | out = img.astype(np.float32)
184 | degenerate = degenerate.astype(np.float32)[1:-1, 1:-1, :]
185 | out[1:-1, 1:-1, :] = degenerate + factor * (out[1:-1, 1:-1, :] - degenerate)
186 | out = out.astype(np.uint8)
187 | return out
188 |
189 |
190 | def shear_x_func(img, factor, fill=(0, 0, 0)):
191 | H, W = img.shape[0], img.shape[1]
192 | M = np.float32([[1, factor, 0], [0, 1, 0]])
193 | out = cv2.warpAffine(
194 | img, M, (W, H), borderValue=fill, flags=cv2.INTER_LINEAR
195 | ).astype(np.uint8)
196 | return out
197 |
198 |
199 | def translate_x_func(img, offset, fill=(0, 0, 0)):
200 | """
201 | same output as PIL.Image.transform
202 | """
203 | H, W = img.shape[0], img.shape[1]
204 | M = np.float32([[1, 0, -offset], [0, 1, 0]])
205 | out = cv2.warpAffine(
206 | img, M, (W, H), borderValue=fill, flags=cv2.INTER_LINEAR
207 | ).astype(np.uint8)
208 | return out
209 |
210 |
211 | def translate_y_func(img, offset, fill=(0, 0, 0)):
212 | """
213 | same output as PIL.Image.transform
214 | """
215 | H, W = img.shape[0], img.shape[1]
216 | M = np.float32([[1, 0, 0], [0, 1, -offset]])
217 | out = cv2.warpAffine(
218 | img, M, (W, H), borderValue=fill, flags=cv2.INTER_LINEAR
219 | ).astype(np.uint8)
220 | return out
221 |
222 |
223 | def posterize_func(img, bits):
224 | """
225 | same output as PIL.ImageOps.posterize
226 | """
227 | out = np.bitwise_and(img, np.uint8(255 << (8 - bits)))
228 | return out
229 |
230 |
231 | def shear_y_func(img, factor, fill=(0, 0, 0)):
232 | H, W = img.shape[0], img.shape[1]
233 | M = np.float32([[1, 0, 0], [factor, 1, 0]])
234 | out = cv2.warpAffine(
235 | img, M, (W, H), borderValue=fill, flags=cv2.INTER_LINEAR
236 | ).astype(np.uint8)
237 | return out
238 |
239 |
240 | def cutout_func(img, pad_size, replace=(0, 0, 0)):
241 | replace = np.array(replace, dtype=np.uint8)
242 | H, W = img.shape[0], img.shape[1]
243 | rh, rw = np.random.random(2)
244 | pad_size = pad_size // 2
245 | ch, cw = int(rh * H), int(rw * W)
246 | x1, x2 = max(ch - pad_size, 0), min(ch + pad_size, H)
247 | y1, y2 = max(cw - pad_size, 0), min(cw + pad_size, W)
248 | out = img.copy()
249 | out[x1:x2, y1:y2, :] = replace
250 | return out
251 |
252 |
253 | ### level to args
254 | def enhance_level_to_args(MAX_LEVEL):
255 | def level_to_args(level):
256 | return ((level / MAX_LEVEL) * 1.8 + 0.1,)
257 |
258 | return level_to_args
259 |
260 |
261 | def shear_level_to_args(MAX_LEVEL, replace_value):
262 | def level_to_args(level):
263 | level = (level / MAX_LEVEL) * 0.3
264 | if np.random.random() > 0.5:
265 | level = -level
266 | return (level, replace_value)
267 |
268 | return level_to_args
269 |
270 |
271 | def translate_level_to_args(translate_const, MAX_LEVEL, replace_value):
272 | def level_to_args(level):
273 | level = (level / MAX_LEVEL) * float(translate_const)
274 | if np.random.random() > 0.5:
275 | level = -level
276 | return (level, replace_value)
277 |
278 | return level_to_args
279 |
280 |
281 | def cutout_level_to_args(cutout_const, MAX_LEVEL, replace_value):
282 | def level_to_args(level):
283 | level = int((level / MAX_LEVEL) * cutout_const)
284 | return (level, replace_value)
285 |
286 | return level_to_args
287 |
288 |
289 | def solarize_level_to_args(MAX_LEVEL):
290 | def level_to_args(level):
291 | level = int((level / MAX_LEVEL) * 256)
292 | return (level,)
293 |
294 | return level_to_args
295 |
296 |
297 | def none_level_to_args(level):
298 | return ()
299 |
300 |
301 | def posterize_level_to_args(MAX_LEVEL):
302 | def level_to_args(level):
303 | level = int((level / MAX_LEVEL) * 4)
304 | return (level,)
305 |
306 | return level_to_args
307 |
308 |
309 | def rotate_level_to_args(MAX_LEVEL, replace_value):
310 | def level_to_args(level):
311 | level = (level / MAX_LEVEL) * 30
312 | if np.random.random() < 0.5:
313 | level = -level
314 | return (level, replace_value)
315 |
316 | return level_to_args
317 |
318 |
319 | func_dict = {
320 | "Identity": identity_func,
321 | "AutoContrast": autocontrast_func,
322 | "Equalize": equalize_func,
323 | "Rotate": rotate_func,
324 | "Solarize": solarize_func,
325 | "Color": color_func,
326 | "Contrast": contrast_func,
327 | "Brightness": brightness_func,
328 | "Sharpness": sharpness_func,
329 | "ShearX": shear_x_func,
330 | "TranslateX": translate_x_func,
331 | "TranslateY": translate_y_func,
332 | "Posterize": posterize_func,
333 | "ShearY": shear_y_func,
334 | }
335 |
336 | translate_const = 10
337 | MAX_LEVEL = 10
338 | replace_value = (128, 128, 128)
339 | arg_dict = {
340 | "Identity": none_level_to_args,
341 | "AutoContrast": none_level_to_args,
342 | "Equalize": none_level_to_args,
343 | "Rotate": rotate_level_to_args(MAX_LEVEL, replace_value),
344 | "Solarize": solarize_level_to_args(MAX_LEVEL),
345 | "Color": enhance_level_to_args(MAX_LEVEL),
346 | "Contrast": enhance_level_to_args(MAX_LEVEL),
347 | "Brightness": enhance_level_to_args(MAX_LEVEL),
348 | "Sharpness": enhance_level_to_args(MAX_LEVEL),
349 | "ShearX": shear_level_to_args(MAX_LEVEL, replace_value),
350 | "TranslateX": translate_level_to_args(translate_const, MAX_LEVEL, replace_value),
351 | "TranslateY": translate_level_to_args(translate_const, MAX_LEVEL, replace_value),
352 | "Posterize": posterize_level_to_args(MAX_LEVEL),
353 | "ShearY": shear_level_to_args(MAX_LEVEL, replace_value),
354 | }
355 |
356 |
357 | class RandomAugment(object):
358 | def __init__(self, N=2, M=10, isPIL=False, augs=[]):
359 | self.N = N
360 | self.M = M
361 | self.isPIL = isPIL
362 | if augs:
363 | self.augs = augs
364 | else:
365 | self.augs = list(arg_dict.keys())
366 |
367 | def get_random_ops(self):
368 | sampled_ops = np.random.choice(self.augs, self.N)
369 | return [(op, 0.5, self.M) for op in sampled_ops]
370 |
371 | def __call__(self, img):
372 | if self.isPIL:
373 | img = np.array(img)
374 | ops = self.get_random_ops()
375 | for name, prob, level in ops:
376 | if np.random.random() > prob:
377 | continue
378 | args = arg_dict[name](level)
379 | img = func_dict[name](img, *args)
380 | return img
381 |
382 |
383 | class VideoRandomAugment(object):
384 | def __init__(self, N=2, M=10, p=0.0, tensor_in_tensor_out=True, augs=[]):
385 | self.N = N
386 | self.M = M
387 | self.p = p
388 | self.tensor_in_tensor_out = tensor_in_tensor_out
389 | if augs:
390 | self.augs = augs
391 | else:
392 | self.augs = list(arg_dict.keys())
393 |
394 | def get_random_ops(self):
395 | sampled_ops = np.random.choice(self.augs, self.N, replace=False)
396 | return [(op, self.M) for op in sampled_ops]
397 |
398 | def __call__(self, frames):
399 | assert (
400 | frames.shape[-1] == 3
401 |         ), "Expecting the last dimension to be 3-channel RGB (b, h, w, c)."
402 |
403 | if self.tensor_in_tensor_out:
404 | frames = frames.numpy().astype(np.uint8)
405 |
406 | num_frames = frames.shape[0]
407 |
408 | ops = num_frames * [self.get_random_ops()]
409 | apply_or_not = num_frames * [np.random.random(size=self.N) > self.p]
410 |
411 | frames = torch.stack(
412 | list(map(self._aug, frames, ops, apply_or_not)), dim=0
413 | ).float()
414 |
415 | return frames
416 |
417 | def _aug(self, img, ops, apply_or_not):
418 | for i, (name, level) in enumerate(ops):
419 | if not apply_or_not[i]:
420 | continue
421 | args = arg_dict[name](level)
422 | img = func_dict[name](img, *args)
423 | return torch.from_numpy(img)
424 |
425 |
426 | # if __name__ == "__main__":
427 | # a = RandomAugment()
428 | # img = np.random.randn(32, 32, 3)
429 | # a(img)
430 |
431 |
432 |
433 |
434 |
435 |
436 | class BlipImageTrainProcessor(BlipImageBaseProcessor):
437 | def __init__(
438 | self, image_size=384, mean=None, std=None, min_scale=0.5, max_scale=1.0
439 | ):
440 | super().__init__(mean=mean, std=std)
441 |
442 | self.transform = transforms.Compose(
443 | [
444 | transforms.RandomResizedCrop(
445 | image_size,
446 | scale=(min_scale, max_scale),
447 | interpolation=InterpolationMode.BICUBIC,
448 | ),
449 | # transforms.RandomHorizontalFlip(),
450 | RandomAugment(
451 | 2,
452 | 5,
453 | isPIL=True,
454 | augs=[
455 | "Identity",
456 | # "AutoContrast",
457 | "Brightness",
458 | "Sharpness",
459 | "Equalize",
460 | # "ShearX",
461 | # "ShearY",
462 | # "TranslateX",
463 | # "TranslateY",
464 | # "Rotate",
465 | ],
466 | ),
467 | transforms.ToTensor(),
468 | self.normalize,
469 | ]
470 | )
471 |
472 | def __call__(self, item):
473 | return self.transform(item)
474 |
475 |
476 | class BlipImageEvalProcessor(BlipImageBaseProcessor):
477 | def __init__(self, image_size=384, mean=None, std=None):
478 | super().__init__(mean=mean, std=std)
479 |
480 | self.transform = transforms.Compose(
481 | [
482 | transforms.Resize(
483 | (image_size, image_size), interpolation=InterpolationMode.BICUBIC
484 | ),
485 | transforms.ToTensor(),
486 | self.normalize,
487 | ]
488 | )
489 |
490 | def __call__(self, item):
491 | return self.transform(item)
492 |
493 |
494 | # if __name__ == "__main__":
495 | # a = BlipImageTrainProcessor(image_size=1024)
496 | # # img = np.random.randn(1024, 1024, 3)
497 | # # x = torch.zeros(1024, 1024, 3)
498 | # x = Image.open("/data/codes/vary-main/log/serve_images/2023-05-23/a2a783d89ede819cdeae943a2199ad3d.jpg").convert("RGB")
499 | # print(x.size)
500 | # y = a(x)
501 |
502 | # print(y.size())
503 |
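A short usage sketch (the image path is a placeholder, and `image_size=1024` mirrors the commented example above) for the two processors defined in this file:

```python
from PIL import Image

from vary.model.plug.blip_process import (
    BlipImageEvalProcessor,
    BlipImageTrainProcessor,
)

# Placeholder path: any RGB image works here.
image = Image.open("example_page.jpg").convert("RGB")

# Deterministic eval-time preprocessing: resize -> ToTensor -> CLIP-style normalization.
eval_processor = BlipImageEvalProcessor(image_size=1024)
print(eval_processor(image).shape)   # torch.Size([3, 1024, 1024])

# Stochastic train-time preprocessing: RandomResizedCrop + RandomAugment on top.
train_processor = BlipImageTrainProcessor(image_size=1024, min_scale=0.8)
print(train_processor(image).shape)  # torch.Size([3, 1024, 1024])
```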
--------------------------------------------------------------------------------
/Vary-master/vary/model/plug/transforms.py:
--------------------------------------------------------------------------------
1 | """
2 | Copyright (c) Meta Platforms, Inc. and affiliates.
3 |
4 | This source code is licensed under the MIT license found in the
5 | LICENSE file in the root directory of this source tree.
6 | """
7 | # Implements image augmentation
8 |
9 | import albumentations as alb
10 | from albumentations.pytorch import ToTensorV2
11 | import cv2
12 | import numpy as np
13 | from timm.data.constants import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
14 |
15 |
16 | def alb_wrapper(transform):
17 | def f(im):
18 | return transform(image=np.asarray(im))["image"]
19 |
20 | return f
21 |
22 |
23 | class Erosion(alb.ImageOnlyTransform):
24 | """
25 | Apply erosion operation to an image.
26 |
27 | Erosion is a morphological operation that shrinks the white regions in a binary image.
28 |
29 | Args:
30 | scale (int or tuple/list of int): The scale or range for the size of the erosion kernel.
31 | If an integer is provided, a square kernel of that size will be used.
32 | If a tuple or list is provided, it should contain two integers representing the minimum
33 | and maximum sizes for the erosion kernel.
34 | always_apply (bool, optional): Whether to always apply this transformation. Default is False.
35 | p (float, optional): The probability of applying this transformation. Default is 0.5.
36 |
37 | Returns:
38 | numpy.ndarray: The transformed image.
39 | """
40 |
41 | def __init__(self, scale, always_apply=False, p=0.5):
42 | super().__init__(always_apply=always_apply, p=p)
43 | if type(scale) is tuple or type(scale) is list:
44 | assert len(scale) == 2
45 | self.scale = scale
46 | else:
47 | self.scale = (scale, scale)
48 |
49 | def apply(self, img, **params):
50 | kernel = cv2.getStructuringElement(
51 | cv2.MORPH_ELLIPSE, tuple(np.random.randint(self.scale[0], self.scale[1], 2))
52 | )
53 | img = cv2.erode(img, kernel, iterations=1)
54 | return img
55 |
56 |
57 | class Dilation(alb.ImageOnlyTransform):
58 | """
59 | Apply dilation operation to an image.
60 |
61 | Dilation is a morphological operation that expands the white regions in a binary image.
62 |
63 | Args:
64 | scale (int or tuple/list of int): The scale or range for the size of the dilation kernel.
65 | If an integer is provided, a square kernel of that size will be used.
66 | If a tuple or list is provided, it should contain two integers representing the minimum
67 | and maximum sizes for the dilation kernel.
68 | always_apply (bool, optional): Whether to always apply this transformation. Default is False.
69 | p (float, optional): The probability of applying this transformation. Default is 0.5.
70 |
71 | Returns:
72 | numpy.ndarray: The transformed image.
73 | """
74 |
75 | def __init__(self, scale, always_apply=False, p=0.5):
76 | super().__init__(always_apply=always_apply, p=p)
77 | if type(scale) is tuple or type(scale) is list:
78 | assert len(scale) == 2
79 | self.scale = scale
80 | else:
81 | self.scale = (scale, scale)
82 |
83 | def apply(self, img, **params):
84 | kernel = cv2.getStructuringElement(
85 | cv2.MORPH_ELLIPSE, tuple(np.random.randint(self.scale[0], self.scale[1], 2))
86 | )
87 | img = cv2.dilate(img, kernel, iterations=1)
88 | return img
89 |
90 |
91 | class Bitmap(alb.ImageOnlyTransform):
92 | """
93 | Apply a bitmap-style transformation to an image.
94 |
95 | This transformation replaces all pixel values below a certain threshold with a specified value.
96 |
97 | Args:
98 | value (int, optional): The value to replace pixels below the threshold with. Default is 0.
99 | lower (int, optional): The threshold value below which pixels will be replaced. Default is 200.
100 | always_apply (bool, optional): Whether to always apply this transformation. Default is False.
101 | p (float, optional): The probability of applying this transformation. Default is 0.5.
102 |
103 | Returns:
104 | numpy.ndarray: The transformed image.
105 | """
106 |
107 | def __init__(self, value=0, lower=200, always_apply=False, p=0.5):
108 | super().__init__(always_apply=always_apply, p=p)
109 | self.lower = lower
110 | self.value = value
111 |
112 | def apply(self, img, **params):
113 | img = img.copy()
114 | img[img < self.lower] = self.value
115 | return img
116 |
117 |
118 | train_transform = alb_wrapper(
119 | alb.Compose(
120 | [
121 | Bitmap(p=0),
122 | alb.OneOf([Erosion((2, 3)), Dilation((2, 3))], p=0.02),
123 | alb.Affine(shear={"x": (0, 3), "y": (-3, 0)}, cval=(255, 255, 255), p=0.03),
124 | alb.ShiftScaleRotate(
125 | shift_limit_x=(0, 0.04),
126 | shift_limit_y=(0, 0.03),
127 | scale_limit=(-0.15, 0.03),
128 | rotate_limit=2,
129 | border_mode=0,
130 | interpolation=2,
131 | value=(255, 255, 255),
132 | p=0.03,
133 | ),
134 | alb.GridDistortion(
135 | distort_limit=0.05,
136 | border_mode=0,
137 | interpolation=2,
138 | value=(255, 255, 255),
139 | p=0.04,
140 | ),
141 | alb.Compose(
142 | [
143 | alb.Affine(
144 | translate_px=(0, 5), always_apply=True, cval=(255, 255, 255)
145 | ),
146 | alb.ElasticTransform(
147 | p=1,
148 | alpha=50,
149 | sigma=120 * 0.1,
150 | alpha_affine=120 * 0.01,
151 | border_mode=0,
152 | value=(255, 255, 255),
153 | ),
154 | ],
155 | p=0.04,
156 | ),
157 | alb.RandomBrightnessContrast(0.1, 0.1, True, p=0.03),
158 | alb.ImageCompression(95, p=0.07),
159 | alb.GaussNoise(20, p=0.08),
160 | alb.GaussianBlur((3, 3), p=0.03),
161 | alb.Resize(1024, 1024),
162 | alb.Normalize(IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD),
163 | ToTensorV2(),
164 | ]
165 | )
166 | )
167 | test_transform = alb_wrapper(
168 | alb.Compose(
169 | [
170 | alb.Resize(1024, 1024),
171 | alb.Normalize(IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD),
172 | ToTensorV2(),
173 | ]
174 | )
175 | )
176 |
177 |
178 |
179 |
180 | # if __name__ == '__main__':
181 | # from PIL import Image
182 | # image = Image.open('/data/hypertext/ucaswei/codes_new/show/49.jpg').convert('RGB')
183 | # # image = np.array(image)
184 | # for i in range(100):
185 | # image1 = train_transform(image)
186 | # image1 = Image.fromarray(np.uint8(image1))
187 | # image1.save('/data/hypertext/ucaswei/codes_new/aug/' + str(i) + '.jpg')
188 | # mm_projector_1 = nn.Linear(1024, 256)
189 |
190 | # x = torch.zeros(2, 3, 1024, 1024)
191 |
192 | # with torch.no_grad():
193 | # y = model(x)
194 | # print(y.shape)
195 | # y = mm_projector_1(y.permute(0,2,1))
196 | # print(y.shape)
197 | # print(y.permute(0,2,1).shape)
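A brief sketch (placeholder image path; it assumes an albumentations version compatible with the pipeline above) of how `train_transform` / `test_transform` map a PIL page image to a normalized tensor:

```python
from PIL import Image

from vary.model.plug.transforms import test_transform, train_transform

# Placeholder path: a document page rendered as an RGB image.
page = Image.open("page.jpg").convert("RGB")

x_train = train_transform(page)  # document-style augmentations, then 1024x1024 resize + normalize
x_test = test_transform(page)    # deterministic 1024x1024 resize + normalize only

print(x_train.shape, x_test.shape)  # torch.Size([3, 1024, 1024]) for both
```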
--------------------------------------------------------------------------------
/Vary-master/vary/model/vary_opt.py:
--------------------------------------------------------------------------------
1 | # Copyright 2023 Haotian Liu
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 |
16 | from typing import List, Optional, Tuple, Union
17 |
18 | import torch
19 | import torch.nn as nn
20 | import torch.nn.functional as F
21 | from torch.nn import CrossEntropyLoss
22 |
23 | from transformers import AutoConfig, AutoModelForCausalLM, \
24 | LlamaConfig, LlamaModel, LlamaForCausalLM, \
25 | CLIPVisionModel, CLIPImageProcessor
26 | from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
27 |
28 | from vary.utils.constants import *
29 |
30 | from vary.model.plug.blip_process import BlipImageEvalProcessor
31 |
32 | from vary.model.vision_encoder.sam import build_sam_vit_b
33 |
34 |
35 | from transformers import OPTConfig, OPTModel, OPTForCausalLM
36 |
37 | from vary.model.plug.transforms import train_transform, test_transform
38 |
39 |
40 |
41 | class varyConfig(OPTConfig):
42 | model_type = "vary"
43 |
44 |
45 | class varyOPTModel(OPTModel):
46 | config_class = varyConfig
47 |
48 | def __init__(self, config: OPTConfig):
49 | super(varyOPTModel, self).__init__(config)
50 |
51 |
52 | self.vision_tower = build_sam_vit_b()
53 |
54 | self.mm_projector = nn.Linear(1024, 768)
55 |
56 | def initialize_vision_modules(
57 | self,
58 | vision_tower,
59 | pretrained_stage1_model=None,
60 | freeze_vision_tower=False,
61 | use_im_start_end=False,
62 | vision_select_layer=-1,
63 | dtype=torch.float16,
64 | device="cuda"
65 | ):
66 |
67 | # 224*224
68 |         # image_processor is not used in the OPT variant
69 | image_processor = CLIPImageProcessor.from_pretrained('/cache/vit-large-patch14')
70 | # 1024*1024
71 |
72 | image_processor_high = train_transform
73 |
74 | image_token_len = 256
75 |
76 | self.config.vision_tower = vision_tower
77 | self.config.image_token_len = image_token_len
78 | self.config.use_im_start_end = True
79 |
80 | self.config.vision_select_layer = vision_select_layer
81 | self.config.freeze_vision_tower = freeze_vision_tower
82 |
83 | return dict(
84 | image_processor=image_processor,
85 | image_processor_high=image_processor_high,
86 | image_token_len=image_token_len,
87 | )
88 |
89 | def embed_tokens(self, x):
90 | return self.get_input_embeddings()(x)
91 |
92 | def forward(
93 | self,
94 | input_ids: torch.LongTensor = None,
95 | attention_mask: Optional[torch.Tensor] = None,
96 | past_key_values: Optional[List[torch.FloatTensor]] = None,
97 | inputs_embeds: Optional[torch.FloatTensor] = None,
98 | use_cache: Optional[bool] = None,
99 | output_attentions: Optional[bool] = None,
100 | output_hidden_states: Optional[bool] = None,
101 | images: Optional[torch.FloatTensor] = None,
102 | return_dict: Optional[bool] = None,
103 | ) -> Union[Tuple, BaseModelOutputWithPast]:
104 |
105 | # HACK: replace back original embeddings for LLaVA pretraining
106 | # orig_embeds_params = getattr(self, 'orig_embeds_params', None)
107 | # if orig_embeds_params is not None:
108 | # with torch.no_grad():
109 | # self.get_input_embeddings().weight[:-self.num_new_tokens] = orig_embeds_params[:-self.num_new_tokens].data
110 |
111 | if inputs_embeds is None:
112 | inputs_embeds = self.embed_tokens(input_ids)
113 | # inputs_embeds = self.wte(input_ids)
114 |
115 |
116 | vision_tower = getattr(self, 'vision_tower', None)
117 |
118 |
119 | if vision_tower is not None and (input_ids.shape[1] != 1 or self.training) and images is not None:
120 |
121 | use_im_start_end = getattr(self.config, "use_im_start_end", -1)
122 |
123 | vision_select_layer = getattr(self.config, "vision_select_layer", -1)
124 | im_patch_token = getattr(self.config, "im_patch_token", -1)
125 | im_start_token = getattr(self.config, "im_start_token", -1)
126 | im_end_token = getattr(self.config, "im_end_token", -1)
127 | freeze_vision_tower = getattr(self.config, "freeze_vision_tower", False)
128 |
129 |
130 | image_features = []
131 | for image in images:
132 |
133 | with torch.set_grad_enabled(True):
134 | cnn_feature = vision_tower(image[1])
135 | cnn_feature = cnn_feature.flatten(2).permute(0, 2, 1)
136 | image_feature_final = cnn_feature
137 |
138 | image_features.append(image_feature_final)
139 |
140 | if type(images) is list:
141 | image_features = [self.mm_projector(image_feature) for image_feature in image_features]
142 | else:
143 | # image_features = self.mm_projector(image_features)
144 | raise NotImplementedError
145 |
146 | # dummy_image_features = torch.zeros(1024, 1664, device=inputs_embeds.device, dtype=inputs_embeds.dtype).permute(0, 2, 1).reshape(dummy_image_features.shape[0], -1, 32, 32)
147 | # VIT 1024; CNN:1024
148 | dummy_image_features = torch.zeros(256, 1024, device=inputs_embeds.device, dtype=inputs_embeds.dtype)
149 | dummy_image_features = self.mm_projector(dummy_image_features)
150 |
151 | use_im_start_end = True
152 | new_input_embeds = []
153 | for cur_input_ids, cur_input_embeds, cur_image_features in zip(input_ids, inputs_embeds, image_features):
154 | if (cur_input_ids == im_patch_token).sum() == 0:
155 | # multimodal LLM, but the current sample is not multimodal
156 | cur_input_embeds = cur_input_embeds + (0. * dummy_image_features).sum()
157 | new_input_embeds.append(cur_input_embeds)
158 | continue
159 |
160 | if use_im_start_end:
161 | if (cur_input_ids == im_start_token).sum() != (cur_input_ids == im_end_token).sum():
162 | raise ValueError("The number of image start tokens and image end tokens should be the same.")
163 |
164 | image_start_tokens = torch.where(cur_input_ids == im_start_token)[0]
165 | for image_start_token_pos, per_cur_image_features in zip(image_start_tokens, cur_image_features):
166 | per_cur_image_features = per_cur_image_features.to(device=cur_input_embeds.device)
167 | num_patches = per_cur_image_features.shape[0]
168 | # print(cur_input_ids)
169 | if cur_input_ids[image_start_token_pos + num_patches + 1] != im_end_token:
170 | raise ValueError("The image end token should follow the image start token.")
171 |
172 | cur_input_embeds = torch.cat(
173 | (
174 | cur_input_embeds[:image_start_token_pos+1],
175 | per_cur_image_features,
176 | cur_input_embeds[image_start_token_pos + num_patches + 1:]
177 | ),
178 | dim=0
179 | )
180 |
181 |
182 | new_input_embeds.append(cur_input_embeds)
183 | else:
184 | raise NotImplementedError
185 |
186 | inputs_embeds = torch.stack(new_input_embeds, dim=0)
187 |
188 |
189 | return super(varyOPTModel, self).forward(
190 | input_ids=None, attention_mask=attention_mask, past_key_values=past_key_values,
191 | inputs_embeds=inputs_embeds, use_cache=use_cache,
192 | output_attentions=output_attentions, output_hidden_states=output_hidden_states,
193 | return_dict=return_dict
194 | )
195 |
196 |
197 | class varyOPTForCausalLM(OPTForCausalLM):
198 | config_class = varyConfig
199 | # supports_gradient_checkpointing = True
200 |
201 | def __init__(self, config):
202 | super(OPTForCausalLM, self).__init__(config)
203 | self.model = varyOPTModel(config)
204 |
205 | self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
206 |
207 | # Initialize weights and apply final processing
208 | self.post_init()
209 |
210 | def get_model(self):
211 | return self.model
212 |
213 | # def _set_gradient_checkpointing(self, module, value=False):
214 | # if isinstance(module, varyQwenModel):
215 | # module.gradient_checkpointing = value
216 |
217 |
218 |
219 | def forward(
220 | self,
221 | input_ids: Optional[torch.LongTensor] = None,
222 | past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None,
223 | attention_mask: Optional[torch.FloatTensor] = None,
224 | token_type_ids: Optional[torch.LongTensor] = None,
225 | position_ids: Optional[torch.LongTensor] = None,
226 | head_mask: Optional[torch.FloatTensor] = None,
227 | inputs_embeds: Optional[torch.FloatTensor] = None,
228 | encoder_hidden_states: Optional[torch.Tensor] = None,
229 | encoder_attention_mask: Optional[torch.FloatTensor] = None,
230 | labels: Optional[torch.LongTensor] = None,
231 | use_cache: Optional[bool] = None,
232 | output_attentions: Optional[bool] = None,
233 | output_hidden_states: Optional[bool] = None,
234 | images: Optional[torch.FloatTensor] = None,
235 | return_dict: Optional[bool] = None,
236 |
237 | ) -> Union[Tuple, CausalLMOutputWithPast]:
238 | output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
239 | output_hidden_states = (
240 | output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
241 | )
242 | return_dict = return_dict if return_dict is not None else self.config.use_return_dict
243 |
244 | # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
245 | # print(input_ids)
246 | # print(len(images))
247 |
248 | outputs = self.model(
249 | input_ids=input_ids,
250 | past_key_values=past_key_values,
251 | attention_mask=attention_mask,
252 | inputs_embeds=inputs_embeds,
253 | use_cache=use_cache,
254 | output_attentions=output_attentions,
255 | output_hidden_states=output_hidden_states,
256 | images=images,
257 | return_dict=return_dict
258 |
259 | )
260 |
261 |
262 |
263 |
264 | hidden_states = outputs[0]
265 | logits = self.lm_head(hidden_states).contiguous()
266 |
267 | # logits
268 |
269 | loss = None
270 | if labels is not None:
271 | # move labels to correct device to enable model parallelism
272 | labels = labels.to(logits.device)
273 | # Shift so that tokens < n predict n
274 | shift_logits = logits[..., :-1, :].contiguous()
275 | shift_labels = labels[..., 1:].contiguous()
276 | # Flatten the tokens
277 | loss_fct = CrossEntropyLoss()
278 | loss = loss_fct(shift_logits.view(-1, self.config.vocab_size), shift_labels.view(-1))
279 |
280 | if not return_dict:
281 | output = (logits,) + outputs[1:]
282 | return (loss,) + output if loss is not None else output
283 |
284 | return CausalLMOutputWithPast(
285 | loss=loss,
286 | logits=logits,
287 | past_key_values=outputs.past_key_values,
288 | hidden_states=outputs.hidden_states,
289 | attentions=outputs.attentions,
290 | )
291 |
292 |
293 | def prepare_inputs_for_generation(
294 | self, input_ids, past_key_values=None, inputs_embeds=None, **kwargs
295 | ):
296 | token_type_ids = kwargs.get("token_type_ids", None)
297 | if past_key_values:
298 | input_ids = input_ids[:, -1].unsqueeze(-1)
299 | if token_type_ids is not None:
300 | token_type_ids = token_type_ids[:, -1].unsqueeze(-1)
301 |
302 | attention_mask = kwargs.get("attention_mask", None)
303 | position_ids = kwargs.get("position_ids", None)
304 |
305 | if attention_mask is not None and position_ids is None:
306 | position_ids = attention_mask.long().cumsum(-1) - 1
307 | position_ids.masked_fill_(attention_mask == 0, 1)
308 | if past_key_values:
309 | position_ids = position_ids[:, -1].unsqueeze(-1)
310 | else:
311 | position_ids = None
312 |
313 | if inputs_embeds is not None and past_key_values is None:
314 | model_inputs = {"inputs_embeds": inputs_embeds}
315 | else:
316 | model_inputs = {"input_ids": input_ids}
317 |
318 | model_inputs.update(
319 | {
320 | "past_key_values": past_key_values,
321 | "use_cache": kwargs.get("use_cache"),
322 | "position_ids": position_ids,
323 | "attention_mask": attention_mask,
324 | "token_type_ids": token_type_ids,
325 | "images": kwargs.get("images", None),
326 | }
327 | )
328 | return model_inputs
329 |
330 | def initialize_vision_tokenizer(
331 | self,
332 | tokenizer,
333 | freeze_lm_model=False,
334 | pretrained_stage1_model=None,
335 | device="cuda"
336 | ):
337 | config = self.get_model().config
338 |
339 | # add image patch token
340 | tokenizer.add_tokens("", special_tokens=True)
341 | self.resize_token_embeddings(len(tokenizer))
342 |
343 | tokenizer.add_tokens(DEFAULT_IMAGE_PATCH_TOKEN, special_tokens=True)
344 | self.resize_token_embeddings(len(tokenizer))
345 | config.im_patch_token = tokenizer.convert_tokens_to_ids(DEFAULT_IMAGE_PATCH_TOKEN)
346 |
347 |
348 | config.use_im_start_end = True
349 |
350 | # add image start token and end token
351 | if config.use_im_start_end:
352 | num_new_tokens = 2
353 | tokenizer.add_tokens(DEFAULT_IM_START_TOKEN , special_tokens=True)
354 | tokenizer.add_tokens(DEFAULT_IM_END_TOKEN , special_tokens=True)
355 | self.resize_token_embeddings(len(tokenizer))
356 | config.im_start_token = tokenizer.convert_tokens_to_ids(DEFAULT_IM_START_TOKEN)
357 | config.im_end_token = tokenizer.convert_tokens_to_ids(DEFAULT_IM_END_TOKEN)
358 |
359 | # config.im_start_token, config.im_end_token = 151857, 151858
360 |
361 | if num_new_tokens > 0:
362 | input_embeddings = self.get_input_embeddings().weight.data
363 | output_embeddings = self.get_output_embeddings().weight.data
364 |
365 | input_embeddings_avg = input_embeddings[:-num_new_tokens].mean(dim=0, keepdim=True)
366 | output_embeddings_avg = output_embeddings[:-num_new_tokens].mean(dim=0, keepdim=True)
367 |
368 | input_embeddings[-num_new_tokens:] = input_embeddings_avg
369 | output_embeddings[-num_new_tokens:] = output_embeddings_avg
370 |
371 |
372 |
373 | AutoConfig.register("vary", varyConfig)
374 | AutoModelForCausalLM.register(varyConfig, varyOPTForCausalLM)
375 |
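A toy sketch (random tensors; the sequence layout, start-token position, and hidden size are illustrative assumptions) of the embedding splice performed in `varyOPTModel.forward`, where the 256 placeholder patch embeddings between the image start and end tokens are replaced by the projected SAM features:

```python
import torch

hidden, num_patches = 768, 256   # OPT-125M hidden size; image_token_len above
im_start_pos = 3                 # assumed position of the image start token

seq_len = im_start_pos + 1 + num_patches + 1 + 5           # prompt + image span + tail
cur_input_embeds = torch.randn(seq_len, hidden)            # text embeddings with placeholder patches
per_cur_image_features = torch.randn(num_patches, hidden)  # output of mm_projector

# Keep everything up to and including the start token, swap in the vision
# features, then keep the end token and the remaining text embeddings.
spliced = torch.cat(
    (
        cur_input_embeds[: im_start_pos + 1],
        per_cur_image_features,
        cur_input_embeds[im_start_pos + num_patches + 1 :],
    ),
    dim=0,
)
assert spliced.shape == cur_input_embeds.shape  # sequence length is preserved
```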
--------------------------------------------------------------------------------
/Vary-master/vary/model/vary_qwen_vary.py:
--------------------------------------------------------------------------------
1 | # Copyright 2023 Haotian Liu
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 |
16 | from typing import List, Optional, Tuple, Union
17 |
18 | import torch
19 | import torch.nn as nn
20 | import torch.nn.functional as F
21 | from torch.nn import CrossEntropyLoss
22 |
23 | from transformers import AutoConfig, AutoModelForCausalLM, \
24 | CLIPVisionModel, CLIPImageProcessor
25 | from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
26 |
27 | from vary.utils.constants import *
28 |
29 | from vary.model.plug.blip_process import BlipImageEvalProcessor
30 |
31 | from vary.model.llm.qwen.modeling_qwen import QWenLMHeadModel, QWenModel
32 |
33 | from vary.model.llm.qwen.configuration_qwen import QWenConfig
34 | from vary.model.vision_encoder.sam import build_sam_vit_b
35 | from vary.model.plug.transforms import train_transform, test_transform
36 |
37 |
38 | class varyConfig(QWenConfig):
39 | model_type = "vary"
40 |
41 |
42 | class varyQwenModel(QWenModel):
43 | config_class = varyConfig
44 |
45 | def __init__(self, config: QWenConfig):
46 | super(varyQwenModel, self).__init__(config)
47 |         # TODO: download the CLIP ViT weights from Hugging Face
48 | self.vision_tower = CLIPVisionModel.from_pretrained('/cache/vit-large-patch14/')
49 |
50 | self.vision_tower_high = build_sam_vit_b() # build_sam_vit_b(checkpoint = 'xxxx') for train
51 |
52 | self.mm_projector = nn.Linear(1024, 2048)
53 | self.mm_projector_vary = nn.Linear(1024, 2048)
54 |
55 | def initialize_vision_modules(
56 | self,
57 | vision_tower,
58 | pretrained_stage1_model=None,
59 | freeze_vision_tower=False,
60 | use_im_start_end=False,
61 | vision_select_layer=-1,
62 | dtype=torch.float16,
63 | device="cuda"
64 | ):
65 |
66 | # 224*224
67 |         # TODO: download the CLIP ViT weights from Hugging Face
68 | image_processor = CLIPImageProcessor.from_pretrained('/cache/vit-large-patch14/')
69 | # 1024*1024
70 | image_processor_high = train_transform
71 |
72 | self.vision_tower = self.vision_tower.to(dtype=dtype, device=device)
73 |
74 | self.vision_tower_high = self.vision_tower_high.to(dtype=dtype, device=device)
75 |
76 | self.mm_projector = self.mm_projector.to(dtype=dtype, device=device)
77 | self.mm_projector_vary = self.mm_projector_vary.to(dtype=dtype, device=device)
78 |
79 | image_token_len = 256
80 |
81 | self.config.vision_tower = vision_tower
82 | self.config.image_token_len = image_token_len
83 |
84 | self.config.use_im_start_end = True
85 |
86 | self.config.vision_select_layer = vision_select_layer
87 | self.config.freeze_vision_tower = freeze_vision_tower
88 |
89 | return dict(
90 | image_processor=image_processor,
91 | image_processor_high=image_processor_high,
92 | image_token_len=image_token_len,
93 |
94 | )
95 |
96 | def embed_tokens(self, x):
97 | return self.wte(x)
98 |
99 | def forward(
100 | self,
101 | input_ids: torch.LongTensor = None,
102 | attention_mask: Optional[torch.Tensor] = None,
103 | past_key_values: Optional[List[torch.FloatTensor]] = None,
104 | inputs_embeds: Optional[torch.FloatTensor] = None,
105 | use_cache: Optional[bool] = None,
106 | output_attentions: Optional[bool] = None,
107 | output_hidden_states: Optional[bool] = None,
108 | images: Optional[torch.FloatTensor] = None,
109 | return_dict: Optional[bool] = None,
110 | ) -> Union[Tuple, BaseModelOutputWithPast]:
111 |
112 | # HACK: replace back original embeddings for LLaVA pretraining
113 | # orig_embeds_params = getattr(self, 'orig_embeds_params', None)
114 | # if orig_embeds_params is not None:
115 | # with torch.no_grad():
116 | # self.get_input_embeddings().weight[:-self.num_new_tokens] = orig_embeds_params[:-self.num_new_tokens].data
117 |
118 | if inputs_embeds is None:
119 | inputs_embeds = self.embed_tokens(input_ids)
120 | # inputs_embeds = self.wte(input_ids)
121 |
122 |
123 | vision_tower = getattr(self, 'vision_tower', None)
124 | vision_tower_high = getattr(self, 'vision_tower_high', None)
125 |
126 |
127 | if vision_tower is not None and (input_ids.shape[1] != 1 or self.training) and images is not None:
128 |
129 | use_im_start_end = getattr(self.config, "use_im_start_end", -1)
130 |
131 | vision_select_layer = getattr(self.config, "vision_select_layer", -1)
132 | # im_patch_token = getattr(self.config, "im_patch_token", -1)
133 | # im_start_token = getattr(self.config, "im_start_token", -1)
134 | # im_end_token = getattr(self.config, "im_end_token", -1)
135 | # freeze_vision_tower = getattr(self.config, "freeze_vision_tower", False)
136 |
137 | im_patch_token = 151859
138 |
139 | im_start_token = 151857
140 |
141 | im_end_token = 151858
142 |
143 |
144 | image_features_1 = []
145 | image_features_2 = []
146 | for image in images:
147 |
148 | with torch.set_grad_enabled(False):
149 | image_forward_out = vision_tower(image[0], output_hidden_states=True)
150 | select_hidden_state = image_forward_out.hidden_states[vision_select_layer]
151 | image_feature = select_hidden_state[:, 1:] # 256*1024
152 | with torch.set_grad_enabled(False):
153 | cnn_feature = vision_tower_high(image[1])
154 | cnn_feature = cnn_feature.flatten(2).permute(0, 2, 1) # 256*1024
155 |
156 | image_features_1.append(image_feature)
157 | image_features_2.append(cnn_feature)
158 |
159 |
160 | if type(images) is list:
161 | image_features_1 = [self.mm_projector(image_feature) for image_feature in image_features_1]
162 | image_features_2 = [self.mm_projector_vary(image_feature) for image_feature in image_features_2]
163 | image_features = [torch.cat((image_feature[0], image_feature[1]), dim=-1) for image_feature in zip(image_features_1, image_features_2)]
164 | else:
165 |
166 | raise NotImplementedError
167 |
168 |
169 | # dummy_image_features = torch.zeros(256, 4096, device=inputs_embeds.device, dtype=inputs_embeds.dtype)
170 | dummy_image_features_1 = torch.zeros(256, 1024, device=inputs_embeds.device, dtype=inputs_embeds.dtype)
171 | dummy_image_features_2 = torch.zeros(256, 1024, device=inputs_embeds.device, dtype=inputs_embeds.dtype)
172 | dummy_image_features_1 = self.mm_projector(dummy_image_features_1)
173 | dummy_image_features_2 = self.mm_projector_vary(dummy_image_features_2)
174 | dummy_image_features = torch.cat((dummy_image_features_1, dummy_image_features_2), dim=-1)
175 | use_im_start_end = True
176 | new_input_embeds = []
177 | for cur_input_ids, cur_input_embeds, cur_image_features in zip(input_ids, inputs_embeds, image_features):
178 | if (cur_input_ids == im_patch_token).sum() == 0:
179 | # multimodal LLM, but the current sample is not multimodal
180 | cur_input_embeds = cur_input_embeds + (0. * dummy_image_features).sum()
181 | new_input_embeds.append(cur_input_embeds)
182 | continue
183 |
184 | if use_im_start_end:
185 | if (cur_input_ids == im_start_token).sum() != (cur_input_ids == im_end_token).sum():
186 | raise ValueError("The number of image start tokens and image end tokens should be the same.")
187 |
188 | image_start_tokens = torch.where(cur_input_ids == im_start_token)[0]
189 | for image_start_token_pos, per_cur_image_features in zip(image_start_tokens, cur_image_features):
190 | per_cur_image_features = per_cur_image_features.to(device=cur_input_embeds.device)
191 | num_patches = per_cur_image_features.shape[0]
192 |
193 | if cur_input_ids[image_start_token_pos + num_patches + 1] != im_end_token:
194 | raise ValueError("The image end token should follow the image start token.")
195 |
196 | # if orig_embeds_params is not None:
197 | # cur_new_input_embeds = torch.cat(
198 | # (
199 | # cur_input_embeds[:image_start_token_pos].detach(),
200 | # cur_input_embeds[image_start_token_pos:image_start_token_pos+1],
201 | # per_cur_image_features,
202 | # cur_input_embeds[image_start_token_pos + num_patches + 1:image_start_token_pos + num_patches + 2],
203 | # cur_input_embeds[image_start_token_pos + num_patches + 2:].detach()
204 | # ),
205 | # dim=0
206 | # )
207 | # else:
208 | cur_input_embeds = torch.cat(
209 | (
210 | cur_input_embeds[:image_start_token_pos+1],
211 | per_cur_image_features,
212 | cur_input_embeds[image_start_token_pos + num_patches + 1:]
213 | ),
214 | dim=0
215 | )
216 |
217 |
218 | new_input_embeds.append(cur_input_embeds)
219 | else:
220 | raise NotImplementedError
221 |
222 | inputs_embeds = torch.stack(new_input_embeds, dim=0)
223 |
224 | return super(varyQwenModel, self).forward(
225 | input_ids=None, attention_mask=attention_mask, past_key_values=past_key_values,
226 | inputs_embeds=inputs_embeds, use_cache=use_cache,
227 | output_attentions=output_attentions, output_hidden_states=output_hidden_states,
228 | return_dict=return_dict
229 | )
230 |
231 |
232 | class varyQwenForCausalLM(QWenLMHeadModel):
233 | config_class = varyConfig
234 | # supports_gradient_checkpointing = True
235 |
236 | def __init__(self, config):
237 | super(QWenLMHeadModel, self).__init__(config)
238 | self.transformer = varyQwenModel(config)
239 |
240 | self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
241 |
242 | # Initialize weights and apply final processing
243 | self.post_init()
244 |
245 | def get_model(self):
246 | return self.transformer
247 |
248 | # def _set_gradient_checkpointing(self, module, value=False):
249 | # if isinstance(module, varyQwenModel):
250 | # module.gradient_checkpointing = value
251 |
252 | def forward(
253 | self,
254 | input_ids: Optional[torch.LongTensor] = None,
255 | past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None,
256 | attention_mask: Optional[torch.FloatTensor] = None,
257 | token_type_ids: Optional[torch.LongTensor] = None,
258 | position_ids: Optional[torch.LongTensor] = None,
259 | head_mask: Optional[torch.FloatTensor] = None,
260 | inputs_embeds: Optional[torch.FloatTensor] = None,
261 | encoder_hidden_states: Optional[torch.Tensor] = None,
262 | encoder_attention_mask: Optional[torch.FloatTensor] = None,
263 | labels: Optional[torch.LongTensor] = None,
264 | use_cache: Optional[bool] = None,
265 | output_attentions: Optional[bool] = None,
266 | output_hidden_states: Optional[bool] = None,
267 | images: Optional[torch.FloatTensor] = None,
268 | return_dict: Optional[bool] = None,
269 |
270 | ) -> Union[Tuple, CausalLMOutputWithPast]:
271 | output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
272 | output_hidden_states = (
273 | output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
274 | )
275 | return_dict = return_dict if return_dict is not None else self.config.use_return_dict
276 |
277 | # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
278 |
279 |
280 | transformer_outputs = self.transformer(
281 | input_ids=input_ids,
282 | past_key_values=past_key_values,
283 | attention_mask=attention_mask,
284 | inputs_embeds=inputs_embeds,
285 | use_cache=use_cache,
286 | output_attentions=output_attentions,
287 | output_hidden_states=output_hidden_states,
288 | images=images,
289 | return_dict=return_dict
290 |
291 | )
292 |
293 |
294 |
295 | hidden_states = transformer_outputs[0]
296 | lm_logits = self.lm_head(hidden_states)
297 |
298 | # logits
299 |
300 | loss = None
301 | if labels is not None:
302 | labels = labels.to(lm_logits.device)
303 | shift_logits = lm_logits[..., :-1, :].contiguous()
304 | shift_labels = labels[..., 1:].contiguous()
305 | loss_fct = CrossEntropyLoss()
306 | loss = loss_fct(
307 | shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1)
308 | )
309 |
310 | if not return_dict:
311 | output = (lm_logits,) + transformer_outputs[1:]
312 | return ((loss,) + output) if loss is not None else output
313 |
319 |
320 | return CausalLMOutputWithPast(
321 | loss=loss,
322 | logits=lm_logits,
323 | past_key_values=transformer_outputs.past_key_values,
324 | hidden_states=transformer_outputs.hidden_states,
325 | attentions=transformer_outputs.attentions,
326 | )
327 |
328 | def prepare_inputs_for_generation(
329 | self, input_ids, past_key_values=None, inputs_embeds=None, **kwargs
330 | ):
331 | token_type_ids = kwargs.get("token_type_ids", None)
332 | if past_key_values:
333 | input_ids = input_ids[:, -1].unsqueeze(-1)
334 | if token_type_ids is not None:
335 | token_type_ids = token_type_ids[:, -1].unsqueeze(-1)
336 |
337 | attention_mask = kwargs.get("attention_mask", None)
338 | position_ids = kwargs.get("position_ids", None)
339 |
340 | if attention_mask is not None and position_ids is None:
341 | position_ids = attention_mask.long().cumsum(-1) - 1
342 | position_ids.masked_fill_(attention_mask == 0, 1)
343 | if past_key_values:
344 | position_ids = position_ids[:, -1].unsqueeze(-1)
345 | else:
346 | position_ids = None
347 |
348 | if inputs_embeds is not None and past_key_values is None:
349 | model_inputs = {"inputs_embeds": inputs_embeds}
350 | else:
351 | model_inputs = {"input_ids": input_ids}
352 |
353 | model_inputs.update(
354 | {
355 | "past_key_values": past_key_values,
356 | "use_cache": kwargs.get("use_cache"),
357 | "position_ids": position_ids,
358 | "attention_mask": attention_mask,
359 | "token_type_ids": token_type_ids,
360 | "images": kwargs.get("images", None),
361 | }
362 | )
363 | return model_inputs
364 |
365 | def initialize_vision_tokenizer(
366 | self,
367 | tokenizer,
368 | freeze_lm_model=False,
369 | pretrained_stage1_model=None,
370 | device="cuda"
371 | ):
372 | config = self.get_model().config
373 |
374 |
375 | self.resize_token_embeddings(len(tokenizer))
376 |
377 |
378 | config.im_patch_token = 151859
379 |
380 | config.use_im_start_end = True
381 |
382 | # add image start token and end token
383 | if config.use_im_start_end:
384 | # num_new_tokens = tokenizer.add_tokens([DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN], special_tokens=True)
385 | self.resize_token_embeddings(len(tokenizer))
386 | # config.im_start_token, config.im_end_token = tokenizer.convert_tokens_to_ids([DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN])
387 |
388 | config.im_start_token, config.im_end_token = 151857, 151858
389 |
390 |
391 | AutoConfig.register("vary", varyConfig)
392 | AutoModelForCausalLM.register(varyConfig, varyQwenForCausalLM)
393 |
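A toy sketch (random tensors, stand-in projector modules) of the dual-branch fusion in `varyQwenModel.forward`: CLIP ViT features and SAM ("vary") features, 256 tokens of 1024 dims each, are projected to 2048 and concatenated along the channel dimension:

```python
import torch
import torch.nn as nn

# Stand-ins for the two projectors created in varyQwenModel.__init__.
mm_projector = nn.Linear(1024, 2048)       # CLIP branch
mm_projector_vary = nn.Linear(1024, 2048)  # SAM ("vary") branch

clip_feature = torch.randn(1, 256, 1024)   # CLIP ViT-L/14 patch tokens (CLS removed)
sam_feature = torch.randn(1, 256, 1024)    # flattened SAM features

fused = torch.cat(
    (mm_projector(clip_feature), mm_projector_vary(sam_feature)), dim=-1
)
print(fused.shape)  # torch.Size([1, 256, 4096]) -- the width fed to the Qwen decoder
```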
--------------------------------------------------------------------------------
/Vary-master/vary/model/vary_toy_qwen1_8.py:
--------------------------------------------------------------------------------
1 | from typing import List, Optional, Tuple, Union
2 | import torch
3 | import torch.nn as nn
4 | import torch.nn.functional as F
5 | from torch.nn import CrossEntropyLoss
6 | from transformers import AutoConfig, AutoModelForCausalLM, \
7 | CLIPVisionModel, CLIPImageProcessor
8 | from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
9 | from vary.utils.constants import *
10 | from vary.model.plug.blip_process import BlipImageEvalProcessor
11 | from vary.model.llm.qwen.modeling_qwen import QWenLMHeadModel, QWenModel
12 | from vary.model.llm.qwen.configuration_qwen import QWenConfig
13 | from vary.model.vision_encoder.sam import build_sam_vit_b
14 |
15 |
16 | class varyConfig(QWenConfig):
17 | model_type = "vary"
18 |
19 |
20 | class varyQwenModel(QWenModel):
21 | config_class = varyConfig
22 |
23 | def __init__(self, config: QWenConfig):
24 | super(varyQwenModel, self).__init__(config)
25 |
26 | self.vision_tower = CLIPVisionModel.from_pretrained('/data/hypertext/ucaswei/cache/vit-large-patch14/vit-large-patch14/')  # hard-coded local copy of openai/clip-vit-large-patch14; adjust to your environment
27 | self.vision_tower_high = build_sam_vit_b()
28 |
29 | self.mm_projector = nn.Linear(1024, 1024)
30 | self.mm_projector_vary = nn.Linear(1024, 1024)
31 |
32 | def initialize_vision_modules(
33 | self,
34 | vision_tower,
35 | pretrained_stage1_model=None,
36 | freeze_vision_tower=False,
37 | use_im_start_end=False,
38 | vision_select_layer=-1,
39 | dtype=torch.float16,
40 | device="cuda"
41 | ):
42 |
43 | # 224*224
44 | image_processor = CLIPImageProcessor.from_pretrained('/data/hypertext/ucaswei/cache/vit-large-patch14/vit-large-patch14/')  # hard-coded local copy of openai/clip-vit-large-patch14; adjust to your environment
45 | # 1024*1024
46 | image_processor_high = BlipImageEvalProcessor(image_size=1024)
47 |
48 | self.vision_tower = self.vision_tower.to(dtype=dtype, device=device)
49 | self.vision_tower_high = self.vision_tower_high.to(dtype=dtype, device=device)
50 |
51 | self.mm_projector = self.mm_projector.to(dtype=dtype, device=device)
52 | self.mm_projector_vary = self.mm_projector_vary.to(dtype=dtype, device=device)
53 |
54 | image_token_len = 256
55 |
56 | self.config.vision_tower = vision_tower
57 | self.config.image_token_len = image_token_len
58 | self.config.use_im_start_end = True
59 | self.config.vision_select_layer = vision_select_layer
60 | self.config.freeze_vision_tower = freeze_vision_tower
61 |
62 | return dict(
63 | image_processor=image_processor,
64 | image_processor_high=image_processor_high,
65 | image_token_len=image_token_len,
66 | # vision_config=vision_config
67 | )
68 |
69 | def embed_tokens(self, x):
70 | return self.wte(x)
71 |
72 | def forward(
73 | self,
74 | input_ids: torch.LongTensor = None,
75 | attention_mask: Optional[torch.Tensor] = None,
76 | past_key_values: Optional[List[torch.FloatTensor]] = None,
77 | inputs_embeds: Optional[torch.FloatTensor] = None,
78 | use_cache: Optional[bool] = None,
79 | output_attentions: Optional[bool] = None,
80 | output_hidden_states: Optional[bool] = None,
81 | images: Optional[torch.FloatTensor] = None,
82 | return_dict: Optional[bool] = None,
83 | ) -> Union[Tuple, BaseModelOutputWithPast]:
84 |
85 |
86 |
87 | if inputs_embeds is None:
88 | inputs_embeds = self.embed_tokens(input_ids)
89 |
90 |
91 | vision_tower = getattr(self, 'vision_tower', None)
92 | vision_tower_high = getattr(self, 'vision_tower_high', None)
93 |
94 |
95 | if vision_tower is not None and (input_ids.shape[1] != 1 or self.training) and images is not None:
96 |
97 | use_im_start_end = getattr(self.config, "use_im_start_end", -1)
98 |
99 | vision_select_layer = getattr(self.config, "vision_select_layer", -1)
100 | im_patch_token = getattr(self.config, "im_patch_token", -1)
101 | im_start_token = getattr(self.config, "im_start_token", -1)
102 | im_end_token = getattr(self.config, "im_end_token", -1)
103 | freeze_vision_tower = getattr(self.config, "freeze_vision_tower", False)
104 |
105 | im_patch_token = 151859  # id of <imgpad>; overrides the config values fetched above
106 | im_start_token = 151857  # id of <img>
107 | im_end_token = 151858  # id of </img>
108 |
109 | image_features = []
110 | image_features_1 = []
111 | image_features_2 = []
112 | for image in images:
113 | with torch.set_grad_enabled(False):
114 | image_forward_out = vision_tower(image[0], output_hidden_states=True)
115 | select_hidden_state = image_forward_out.hidden_states[vision_select_layer]
116 | image_feature = select_hidden_state[:, 1:] # 256*1024
117 | with torch.set_grad_enabled(False):
118 | cnn_feature = vision_tower_high(image[1])
119 | cnn_feature = cnn_feature.flatten(2).permute(0, 2, 1) # 256*1024
120 | image_features_1.append(image_feature)
121 | image_features_2.append(cnn_feature)
122 |
123 |
124 | if type(images) is list:
125 | image_features_1 = [self.mm_projector(image_feature) for image_feature in image_features_1]
126 | image_features_2 = [self.mm_projector_vary(image_feature) for image_feature in image_features_2]
127 | image_features = [torch.cat((image_feature[0], image_feature[1]), dim=-1) for image_feature in zip(image_features_1, image_features_2)]
128 | else:
129 | raise NotImplementedError
130 |
131 | dummy_image_features_1 = torch.zeros(256, 1024, device=inputs_embeds.device, dtype=inputs_embeds.dtype)
132 | dummy_image_features_2 = torch.zeros(256, 1024, device=inputs_embeds.device, dtype=inputs_embeds.dtype)
133 | dummy_image_features_1 = self.mm_projector(dummy_image_features_1)
134 | dummy_image_features_2 = self.mm_projector_vary(dummy_image_features_2)
135 | dummy_image_features = torch.cat((dummy_image_features_1, dummy_image_features_2), dim=-1)
136 | use_im_start_end = True
137 | new_input_embeds = []
138 | for cur_input_ids, cur_input_embeds, cur_image_features in zip(input_ids, inputs_embeds, image_features):
139 | if (cur_input_ids == im_patch_token).sum() == 0:
140 | # multimodal LLM, but the current sample is not multimodal
141 | cur_input_embeds = cur_input_embeds + (0. * dummy_image_features).sum()
142 | new_input_embeds.append(cur_input_embeds)
143 | continue
144 |
145 | if use_im_start_end:
146 | if (cur_input_ids == im_start_token).sum() != (cur_input_ids == im_end_token).sum():
147 | raise ValueError("The number of image start tokens and image end tokens should be the same.")
148 |
149 | image_start_tokens = torch.where(cur_input_ids == im_start_token)[0]
150 | for image_start_token_pos, per_cur_image_features in zip(image_start_tokens, cur_image_features):
151 | per_cur_image_features = per_cur_image_features.to(device=cur_input_embeds.device)
152 | num_patches = per_cur_image_features.shape[0]
153 |
154 | if cur_input_ids[image_start_token_pos + num_patches + 1] != im_end_token:
155 | raise ValueError("The image end token should follow the image start token.")
156 |
157 | cur_input_embeds = torch.cat(
158 | (
159 | cur_input_embeds[:image_start_token_pos+1],
160 | per_cur_image_features,
161 | cur_input_embeds[image_start_token_pos + num_patches + 1:]
162 | ),
163 | dim=0
164 | )
165 |
166 |
167 | new_input_embeds.append(cur_input_embeds)
168 | else:
169 | raise NotImplementedError
170 |
171 | inputs_embeds = torch.stack(new_input_embeds, dim=0)
172 |
173 | return super(varyQwenModel, self).forward(
174 | input_ids=None, attention_mask=attention_mask, past_key_values=past_key_values,
175 | inputs_embeds=inputs_embeds, use_cache=use_cache,
176 | output_attentions=output_attentions, output_hidden_states=output_hidden_states,
177 | return_dict=return_dict
178 | )
179 |
180 |
181 | class varyQwenForCausalLM(QWenLMHeadModel):
182 | config_class = varyConfig
183 | # supports_gradient_checkpointing = True
184 |
185 | def __init__(self, config):
186 | super(QWenLMHeadModel, self).__init__(config)
187 | self.transformer = varyQwenModel(config)
188 |
189 | self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
190 |
191 | # Initialize weights and apply final processing
192 | self.post_init()
193 |
194 | def get_model(self):
195 | return self.transformer
196 |
197 |
198 | def forward(
199 | self,
200 | input_ids: Optional[torch.LongTensor] = None,
201 | past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None,
202 | attention_mask: Optional[torch.FloatTensor] = None,
203 | token_type_ids: Optional[torch.LongTensor] = None,
204 | position_ids: Optional[torch.LongTensor] = None,
205 | head_mask: Optional[torch.FloatTensor] = None,
206 | inputs_embeds: Optional[torch.FloatTensor] = None,
207 | encoder_hidden_states: Optional[torch.Tensor] = None,
208 | encoder_attention_mask: Optional[torch.FloatTensor] = None,
209 | labels: Optional[torch.LongTensor] = None,
210 | use_cache: Optional[bool] = None,
211 | output_attentions: Optional[bool] = None,
212 | output_hidden_states: Optional[bool] = None,
213 | images: Optional[torch.FloatTensor] = None,
214 | return_dict: Optional[bool] = None,
215 |
216 | ) -> Union[Tuple, CausalLMOutputWithPast]:
217 | output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
218 | output_hidden_states = (
219 | output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
220 | )
221 | return_dict = return_dict if return_dict is not None else self.config.use_return_dict
222 |
223 |
224 | transformer_outputs = self.transformer(
225 | input_ids=input_ids,
226 | past_key_values=past_key_values,
227 | attention_mask=attention_mask,
228 | inputs_embeds=inputs_embeds,
229 | use_cache=use_cache,
230 | output_attentions=output_attentions,
231 | output_hidden_states=output_hidden_states,
232 | images=images,
233 | return_dict=return_dict
234 |
235 | )
236 |
237 | hidden_states = transformer_outputs[0]
238 | lm_logits = self.lm_head(hidden_states)
239 |
240 | # logits
241 |
242 | loss = None
243 | if labels is not None:
244 | labels = labels.to(lm_logits.device)
245 | shift_logits = lm_logits[..., :-1, :].contiguous()
246 | shift_labels = labels[..., 1:].contiguous()
247 | loss_fct = CrossEntropyLoss()
248 | loss = loss_fct(
249 | shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1)
250 | )
251 |
256 | if not return_dict:
257 | output = (lm_logits,) + transformer_outputs[1:]
258 | return ((loss,) + output) if loss is not None else output
259 |
260 | return CausalLMOutputWithPast(
261 | loss=loss,
262 | logits=lm_logits,
263 | past_key_values=transformer_outputs.past_key_values,
264 | hidden_states=transformer_outputs.hidden_states,
265 | attentions=transformer_outputs.attentions,
266 | )
267 |
268 | def prepare_inputs_for_generation(
269 | self, input_ids, past_key_values=None, inputs_embeds=None, **kwargs
270 | ):
271 | token_type_ids = kwargs.get("token_type_ids", None)
272 | if past_key_values:
273 | input_ids = input_ids[:, -1].unsqueeze(-1)
274 | if token_type_ids is not None:
275 | token_type_ids = token_type_ids[:, -1].unsqueeze(-1)
276 |
277 | attention_mask = kwargs.get("attention_mask", None)
278 | position_ids = kwargs.get("position_ids", None)
279 |
280 | if attention_mask is not None and position_ids is None:
281 | position_ids = attention_mask.long().cumsum(-1) - 1
282 | position_ids.masked_fill_(attention_mask == 0, 1)
283 | if past_key_values:
284 | position_ids = position_ids[:, -1].unsqueeze(-1)
285 | else:
286 | position_ids = None
287 |
288 | if inputs_embeds is not None and past_key_values is None:
289 | model_inputs = {"inputs_embeds": inputs_embeds}
290 | else:
291 | model_inputs = {"input_ids": input_ids}
292 |
293 | model_inputs.update(
294 | {
295 | "past_key_values": past_key_values,
296 | "use_cache": kwargs.get("use_cache"),
297 | "position_ids": position_ids,
298 | "attention_mask": attention_mask,
299 | "token_type_ids": token_type_ids,
300 | "images": kwargs.get("images", None),
301 | }
302 | )
303 | return model_inputs
304 |
305 | def initialize_vision_tokenizer(
306 | self,
307 | tokenizer,
308 | freeze_lm_model=False,
309 | pretrained_stage1_model=None,
310 | device="cuda"
311 | ):
312 | config = self.get_model().config
313 |
314 | self.resize_token_embeddings(len(tokenizer))
315 |
316 | config.im_patch_token = 151859
317 |
318 | config.use_im_start_end = True
319 |
320 | if config.use_im_start_end:
321 | self.resize_token_embeddings(len(tokenizer))
322 |
323 | config.im_start_token, config.im_end_token = 151857, 151858
324 |
325 |
326 |
327 | AutoConfig.register("vary", varyConfig)
328 | AutoModelForCausalLM.register(varyConfig, varyQwenForCausalLM)
329 |
--------------------------------------------------------------------------------
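
The model file above wires the fused CLIP + SAM features into the Qwen-1.8B decoder but does not show how it is driven at inference time. The following is a minimal, hedged usage sketch (modeled on the pattern of `vary/demo/run_qwen_vary.py`, not a copy of it): the checkpoint path and question text are placeholders, it assumes the released tokenizer already maps `<img>`/`</img>`/`<imgpad>` to ids 151857/151858/151859, and `varyQwenModel.__init__` must be able to resolve its hard-coded local CLIP path before this will run.

```python
# Minimal inference sketch for vary_toy_qwen1_8.py. Assumptions: placeholder paths,
# a tokenizer that already contains <img>/<imgpad>/</img>, and a reachable CLIP copy
# for the hard-coded path inside varyQwenModel.__init__.
import torch
from PIL import Image
from transformers import AutoTokenizer, CLIPImageProcessor

from vary.model.plug.blip_process import BlipImageEvalProcessor
from vary.model.vary_toy_qwen1_8 import varyQwenForCausalLM

model_path = "/path/to/Vary-toy"  # placeholder checkpoint directory
tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
model = varyQwenForCausalLM.from_pretrained(model_path, torch_dtype=torch.float16).cuda().eval()

image = Image.open("page.png").convert("RGB")
# low-res 224x224 branch (CLIP) and high-res 1024x1024 branch (SAM-b), matching initialize_vision_modules
clip_processor = CLIPImageProcessor.from_pretrained("openai/clip-vit-large-patch14")
image_low = clip_processor(images=image, return_tensors="pt")["pixel_values"].half().cuda()
image_high = BlipImageEvalProcessor(image_size=1024)(image).unsqueeze(0).half().cuda()

# forward() replaces the 256 <imgpad> placeholders between <img> (151857) and </img> (151858)
# with the 256 fused CLIP+SAM embeddings (1024 + 1024 = 2048 dims, Qwen-1.8B's hidden size).
prompt = "<img>" + "<imgpad>" * 256 + "</img>" + "Provide the OCR results of this image:"
input_ids = tokenizer(prompt, return_tensors="pt").input_ids.cuda()

with torch.inference_mode():
    output_ids = model.generate(input_ids, images=[(image_low, image_high)], max_new_tokens=512)
print(tokenizer.decode(output_ids[0, input_ids.shape[1]:], skip_special_tokens=True))
```

The real demo additionally wraps the question in the Qwen chat template from `vary/utils/conversation.py`; the sketch only illustrates the tensor and token plumbing that `forward()` expects.
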
/Vary-master/vary/model/vision_encoder/__init__.py:
--------------------------------------------------------------------------------
1 |
2 |
--------------------------------------------------------------------------------
/Vary-master/vary/model/vision_encoder/sam.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 |
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 | from functools import partial
8 | from typing import Optional, Tuple, Type
9 |
10 | import torch
11 | import torch.nn as nn
12 | import torch.nn.functional as F
13 |
14 |
25 | class MLPBlock(nn.Module):
26 | def __init__(
27 | self,
28 | embedding_dim: int,
29 | mlp_dim: int,
30 | act: Type[nn.Module] = nn.GELU,
31 | ) -> None:
32 | super().__init__()
33 | self.lin1 = nn.Linear(embedding_dim, mlp_dim)
34 | self.lin2 = nn.Linear(mlp_dim, embedding_dim)
35 | self.act = act()
36 |
37 | def forward(self, x: torch.Tensor) -> torch.Tensor:
38 | return self.lin2(self.act(self.lin1(x)))
39 |
40 |
41 | # From https://github.com/facebookresearch/detectron2/blob/main/detectron2/layers/batch_norm.py # noqa
42 | # Itself from https://github.com/facebookresearch/ConvNeXt/blob/d1fa8f6fef0a165b27399986cc2bdacc92777e40/models/convnext.py#L119 # noqa
43 | class LayerNorm2d(nn.Module):
44 | def __init__(self, num_channels: int, eps: float = 1e-6) -> None:
45 | super().__init__()
46 | self.weight = nn.Parameter(torch.ones(num_channels))
47 | self.bias = nn.Parameter(torch.zeros(num_channels))
48 | self.eps = eps
49 |
50 | def forward(self, x: torch.Tensor) -> torch.Tensor:
51 | u = x.mean(1, keepdim=True)
52 | s = (x - u).pow(2).mean(1, keepdim=True)
53 | x = (x - u) / torch.sqrt(s + self.eps)
54 | x = self.weight[:, None, None] * x + self.bias[:, None, None]
55 | return x
56 |
57 |
58 | # This class and its supporting functions below lightly adapted from the ViTDet backbone available at: https://github.com/facebookresearch/detectron2/blob/main/detectron2/modeling/backbone/vit.py # noqa
59 | class ImageEncoderViT(nn.Module):
60 | def __init__(
61 | self,
62 | img_size: int = 1024,
63 | patch_size: int = 16,
64 | in_chans: int = 3,
65 | embed_dim: int = 768,
66 | depth: int = 12,
67 | num_heads: int = 12,
68 | mlp_ratio: float = 4.0,
69 | out_chans: int = 256,
70 | qkv_bias: bool = True,
71 | norm_layer: Type[nn.Module] = nn.LayerNorm,
72 | act_layer: Type[nn.Module] = nn.GELU,
73 | use_abs_pos: bool = True,
74 | use_rel_pos: bool = False,
75 | rel_pos_zero_init: bool = True,
76 | window_size: int = 0,
77 | global_attn_indexes: Tuple[int, ...] = (),
78 | ) -> None:
79 | """
80 | Args:
81 | img_size (int): Input image size.
82 | patch_size (int): Patch size.
83 | in_chans (int): Number of input image channels.
84 | embed_dim (int): Patch embedding dimension.
85 | depth (int): Depth of ViT.
86 | num_heads (int): Number of attention heads in each ViT block.
87 | mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
88 | qkv_bias (bool): If True, add a learnable bias to query, key, value.
89 | norm_layer (nn.Module): Normalization layer.
90 | act_layer (nn.Module): Activation layer.
91 | use_abs_pos (bool): If True, use absolute positional embeddings.
92 | use_rel_pos (bool): If True, add relative positional embeddings to the attention map.
93 | rel_pos_zero_init (bool): If True, zero initialize relative positional parameters.
94 | window_size (int): Window size for window attention blocks.
95 | global_attn_indexes (list): Indexes for blocks using global attention.
96 | """
97 | super().__init__()
98 | self.img_size = img_size
99 |
100 | self.patch_embed = PatchEmbed(
101 | kernel_size=(patch_size, patch_size),
102 | stride=(patch_size, patch_size),
103 | in_chans=in_chans,
104 | embed_dim=embed_dim,
105 | )
106 |
107 | self.pos_embed: Optional[nn.Parameter] = None
108 | if use_abs_pos:
109 | # Initialize absolute positional embedding with pretrain image size.
110 | self.pos_embed = nn.Parameter(
111 | torch.zeros(1, img_size // patch_size, img_size // patch_size, embed_dim)
112 | )
113 |
114 | self.blocks = nn.ModuleList()
115 | for i in range(depth):
116 | block = Block(
117 | dim=embed_dim,
118 | num_heads=num_heads,
119 | mlp_ratio=mlp_ratio,
120 | qkv_bias=qkv_bias,
121 | norm_layer=norm_layer,
122 | act_layer=act_layer,
123 | use_rel_pos=use_rel_pos,
124 | rel_pos_zero_init=rel_pos_zero_init,
125 | window_size=window_size if i not in global_attn_indexes else 0,
126 | input_size=(img_size // patch_size, img_size // patch_size),
127 | )
128 | self.blocks.append(block)
129 |
130 | self.neck = nn.Sequential(
131 | nn.Conv2d(
132 | embed_dim,
133 | out_chans,
134 | kernel_size=1,
135 | bias=False,
136 | ),
137 | LayerNorm2d(out_chans),
138 | nn.Conv2d(
139 | out_chans,
140 | out_chans,
141 | kernel_size=3,
142 | padding=1,
143 | bias=False,
144 | ),
145 | LayerNorm2d(out_chans),
146 | )
147 |
148 |
149 | self.net_2 = nn.Conv2d(256, 512, kernel_size=3, stride=2, padding=1, bias=False)
150 | self.net_3 = nn.Conv2d(512, 1024, kernel_size=3, stride=2, padding=1, bias=False)
151 |
152 | def forward(self, x: torch.Tensor) -> torch.Tensor:
153 | x = self.patch_embed(x)
154 | if self.pos_embed is not None:
155 | x = x + self.pos_embed
156 |
157 | for blk in self.blocks:
158 | x = blk(x)
159 |
160 | x = self.neck(x.permute(0, 3, 1, 2))
161 | x = self.net_2(x)
162 | x = self.net_3(x)
163 |
164 |
165 | return x
166 |
167 |
168 | class Block(nn.Module):
169 | """Transformer blocks with support of window attention and residual propagation blocks"""
170 |
171 | def __init__(
172 | self,
173 | dim: int,
174 | num_heads: int,
175 | mlp_ratio: float = 4.0,
176 | qkv_bias: bool = True,
177 | norm_layer: Type[nn.Module] = nn.LayerNorm,
178 | act_layer: Type[nn.Module] = nn.GELU,
179 | use_rel_pos: bool = False,
180 | rel_pos_zero_init: bool = True,
181 | window_size: int = 0,
182 | input_size: Optional[Tuple[int, int]] = None,
183 | ) -> None:
184 | """
185 | Args:
186 | dim (int): Number of input channels.
187 | num_heads (int): Number of attention heads in each ViT block.
188 | mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
189 | qkv_bias (bool): If True, add a learnable bias to query, key, value.
190 | norm_layer (nn.Module): Normalization layer.
191 | act_layer (nn.Module): Activation layer.
192 | use_rel_pos (bool): If True, add relative positional embeddings to the attention map.
193 | rel_pos_zero_init (bool): If True, zero initialize relative positional parameters.
194 | window_size (int): Window size for window attention blocks. If it equals 0, then
195 | use global attention.
196 | input_size (tuple(int, int) or None): Input resolution for calculating the relative
197 | positional parameter size.
198 | """
199 | super().__init__()
200 | self.norm1 = norm_layer(dim)
201 | self.attn = Attention(
202 | dim,
203 | num_heads=num_heads,
204 | qkv_bias=qkv_bias,
205 | use_rel_pos=use_rel_pos,
206 | rel_pos_zero_init=rel_pos_zero_init,
207 | input_size=input_size if window_size == 0 else (window_size, window_size),
208 | )
209 |
210 | self.norm2 = norm_layer(dim)
211 | self.mlp = MLPBlock(embedding_dim=dim, mlp_dim=int(dim * mlp_ratio), act=act_layer)
212 |
213 | self.window_size = window_size
214 |
215 | def forward(self, x: torch.Tensor) -> torch.Tensor:
216 | shortcut = x
217 | x = self.norm1(x)
218 | # Window partition
219 | if self.window_size > 0:
220 | H, W = x.shape[1], x.shape[2]
221 | x, pad_hw = window_partition(x, self.window_size)
222 |
223 | x = self.attn(x)
224 | # Reverse window partition
225 | if self.window_size > 0:
226 | x = window_unpartition(x, self.window_size, pad_hw, (H, W))
227 |
228 | x = shortcut + x
229 | x = x + self.mlp(self.norm2(x))
230 |
231 | return x
232 |
233 |
234 | class Attention(nn.Module):
235 | """Multi-head Attention block with relative position embeddings."""
236 |
237 | def __init__(
238 | self,
239 | dim: int,
240 | num_heads: int = 8,
241 | qkv_bias: bool = True,
242 | use_rel_pos: bool = False,
243 | rel_pos_zero_init: bool = True,
244 | input_size: Optional[Tuple[int, int]] = None,
245 | ) -> None:
246 | """
247 | Args:
248 | dim (int): Number of input channels.
249 | num_heads (int): Number of attention heads.
250 | qkv_bias (bool): If True, add a learnable bias to query, key, value.
251 | rel_pos (bool): If True, add relative positional embeddings to the attention map.
252 | rel_pos_zero_init (bool): If True, zero initialize relative positional parameters.
253 | input_size (tuple(int, int) or None): Input resolution for calculating the relative
254 | positional parameter size.
255 | """
256 | super().__init__()
257 | self.num_heads = num_heads
258 | head_dim = dim // num_heads
259 | self.scale = head_dim**-0.5
260 |
261 | self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
262 | self.proj = nn.Linear(dim, dim)
263 |
264 | self.use_rel_pos = use_rel_pos
265 | if self.use_rel_pos:
266 | assert (
267 | input_size is not None
268 | ), "Input size must be provided if using relative positional encoding."
269 | # initialize relative positional embeddings
270 | self.rel_pos_h = nn.Parameter(torch.zeros(2 * input_size[0] - 1, head_dim))
271 | self.rel_pos_w = nn.Parameter(torch.zeros(2 * input_size[1] - 1, head_dim))
272 |
273 | def forward(self, x: torch.Tensor) -> torch.Tensor:
274 | B, H, W, _ = x.shape
275 | # qkv with shape (3, B, nHead, H * W, C)
276 | qkv = self.qkv(x).reshape(B, H * W, 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)
277 | # q, k, v with shape (B * nHead, H * W, C)
278 | q, k, v = qkv.reshape(3, B * self.num_heads, H * W, -1).unbind(0)
279 |
280 | attn = (q * self.scale) @ k.transpose(-2, -1)
281 |
282 | if self.use_rel_pos:
283 | attn = add_decomposed_rel_pos(attn, q, self.rel_pos_h, self.rel_pos_w, (H, W), (H, W))
284 |
285 | attn = attn.softmax(dim=-1)
286 | x = (attn @ v).view(B, self.num_heads, H, W, -1).permute(0, 2, 3, 1, 4).reshape(B, H, W, -1)
287 | x = self.proj(x)
288 |
289 | return x
290 |
291 |
292 | def window_partition(x: torch.Tensor, window_size: int) -> Tuple[torch.Tensor, Tuple[int, int]]:
293 | """
294 | Partition into non-overlapping windows with padding if needed.
295 | Args:
296 | x (tensor): input tokens with [B, H, W, C].
297 | window_size (int): window size.
298 |
299 | Returns:
300 | windows: windows after partition with [B * num_windows, window_size, window_size, C].
301 | (Hp, Wp): padded height and width before partition
302 | """
303 | B, H, W, C = x.shape
304 |
305 | pad_h = (window_size - H % window_size) % window_size
306 | pad_w = (window_size - W % window_size) % window_size
307 | if pad_h > 0 or pad_w > 0:
308 | x = F.pad(x, (0, 0, 0, pad_w, 0, pad_h))
309 | Hp, Wp = H + pad_h, W + pad_w
310 |
311 | x = x.view(B, Hp // window_size, window_size, Wp // window_size, window_size, C)
312 | windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C)
313 | return windows, (Hp, Wp)
314 |
315 |
316 | def window_unpartition(
317 | windows: torch.Tensor, window_size: int, pad_hw: Tuple[int, int], hw: Tuple[int, int]
318 | ) -> torch.Tensor:
319 | """
320 | Window unpartition into original sequences and removing padding.
321 | Args:
322 | windows (tensor): input tokens with [B * num_windows, window_size, window_size, C].
323 | window_size (int): window size.
324 | pad_hw (Tuple): padded height and width (Hp, Wp).
325 | hw (Tuple): original height and width (H, W) before padding.
326 |
327 | Returns:
328 | x: unpartitioned sequences with [B, H, W, C].
329 | """
330 | Hp, Wp = pad_hw
331 | H, W = hw
332 | B = windows.shape[0] // (Hp * Wp // window_size // window_size)
333 | x = windows.view(B, Hp // window_size, Wp // window_size, window_size, window_size, -1)
334 | x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, Hp, Wp, -1)
335 |
336 | if Hp > H or Wp > W:
337 | x = x[:, :H, :W, :].contiguous()
338 | return x
339 |
340 |
341 | def get_rel_pos(q_size: int, k_size: int, rel_pos: torch.Tensor) -> torch.Tensor:
342 | """
343 | Get relative positional embeddings according to the relative positions of
344 | query and key sizes.
345 | Args:
346 | q_size (int): size of query q.
347 | k_size (int): size of key k.
348 | rel_pos (Tensor): relative position embeddings (L, C).
349 |
350 | Returns:
351 | Extracted positional embeddings according to relative positions.
352 | """
353 | max_rel_dist = int(2 * max(q_size, k_size) - 1)
354 | # Interpolate rel pos if needed.
355 | if rel_pos.shape[0] != max_rel_dist:
356 | # Interpolate rel pos.
357 | rel_pos_resized = F.interpolate(
358 | rel_pos.reshape(1, rel_pos.shape[0], -1).permute(0, 2, 1),
359 | size=max_rel_dist,
360 | mode="linear",
361 | )
362 | rel_pos_resized = rel_pos_resized.reshape(-1, max_rel_dist).permute(1, 0)
363 | else:
364 | rel_pos_resized = rel_pos
365 |
366 | # Scale the coords with short length if shapes for q and k are different.
367 | q_coords = torch.arange(q_size)[:, None] * max(k_size / q_size, 1.0)
368 | k_coords = torch.arange(k_size)[None, :] * max(q_size / k_size, 1.0)
369 | relative_coords = (q_coords - k_coords) + (k_size - 1) * max(q_size / k_size, 1.0)
370 |
371 | return rel_pos_resized[relative_coords.long()]
372 |
373 |
374 | def add_decomposed_rel_pos(
375 | attn: torch.Tensor,
376 | q: torch.Tensor,
377 | rel_pos_h: torch.Tensor,
378 | rel_pos_w: torch.Tensor,
379 | q_size: Tuple[int, int],
380 | k_size: Tuple[int, int],
381 | ) -> torch.Tensor:
382 | """
383 | Calculate decomposed Relative Positional Embeddings from :paper:`mvitv2`.
384 | https://github.com/facebookresearch/mvit/blob/19786631e330df9f3622e5402b4a419a263a2c80/mvit/models/attention.py # noqa B950
385 | Args:
386 | attn (Tensor): attention map.
387 | q (Tensor): query q in the attention layer with shape (B, q_h * q_w, C).
388 | rel_pos_h (Tensor): relative position embeddings (Lh, C) for height axis.
389 | rel_pos_w (Tensor): relative position embeddings (Lw, C) for width axis.
390 | q_size (Tuple): spatial sequence size of query q with (q_h, q_w).
391 | k_size (Tuple): spatial sequence size of key k with (k_h, k_w).
392 |
393 | Returns:
394 | attn (Tensor): attention map with added relative positional embeddings.
395 | """
396 | q_h, q_w = q_size
397 | k_h, k_w = k_size
398 | Rh = get_rel_pos(q_h, k_h, rel_pos_h)
399 | Rw = get_rel_pos(q_w, k_w, rel_pos_w)
400 |
401 | B, _, dim = q.shape
402 | r_q = q.reshape(B, q_h, q_w, dim)
403 | rel_h = torch.einsum("bhwc,hkc->bhwk", r_q, Rh)
404 | rel_w = torch.einsum("bhwc,wkc->bhwk", r_q, Rw)
405 |
406 | attn = (
407 | attn.view(B, q_h, q_w, k_h, k_w) + rel_h[:, :, :, :, None] + rel_w[:, :, :, None, :]
408 | ).view(B, q_h * q_w, k_h * k_w)
409 |
410 | return attn
411 |
412 |
413 | class PatchEmbed(nn.Module):
414 | """
415 | Image to Patch Embedding.
416 | """
417 |
418 | def __init__(
419 | self,
420 | kernel_size: Tuple[int, int] = (16, 16),
421 | stride: Tuple[int, int] = (16, 16),
422 | padding: Tuple[int, int] = (0, 0),
423 | in_chans: int = 3,
424 | embed_dim: int = 768,
425 | ) -> None:
426 | """
427 | Args:
428 | kernel_size (Tuple): kernel size of the projection layer.
429 | stride (Tuple): stride of the projection layer.
430 | padding (Tuple): padding size of the projection layer.
431 | in_chans (int): Number of input image channels.
432 | embed_dim (int): Patch embedding dimension.
433 | """
434 | super().__init__()
435 |
436 | self.proj = nn.Conv2d(
437 | in_chans, embed_dim, kernel_size=kernel_size, stride=stride, padding=padding
438 | )
439 |
440 | def forward(self, x: torch.Tensor) -> torch.Tensor:
441 | x = self.proj(x)
442 | # B C H W -> B H W C
443 | x = x.permute(0, 2, 3, 1)
444 | return x
445 |
446 |
447 |
448 | def build_sam_vit_b(checkpoint=None):
449 | return _build_sam(
450 | encoder_embed_dim=768,
451 | encoder_depth=12,
452 | encoder_num_heads=12,
453 | encoder_global_attn_indexes=[2, 5, 8, 11],
454 | checkpoint=checkpoint,
455 | )
456 |
457 |
458 | def _build_sam(
459 | encoder_embed_dim,
460 | encoder_depth,
461 | encoder_num_heads,
462 | encoder_global_attn_indexes,
463 | checkpoint=None,
464 | ):
465 | prompt_embed_dim = 256
466 | image_size = 1024
467 | vit_patch_size = 16
468 | image_embedding_size = image_size // vit_patch_size
469 | image_encoder=ImageEncoderViT(
470 | depth=encoder_depth,
471 | embed_dim=encoder_embed_dim,
472 | img_size=image_size,
473 | mlp_ratio=4,
474 | norm_layer=partial(torch.nn.LayerNorm, eps=1e-6),
475 | num_heads=encoder_num_heads,
476 | patch_size=vit_patch_size,
477 | qkv_bias=True,
478 | use_rel_pos=True,
479 | global_attn_indexes=encoder_global_attn_indexes,
480 | window_size=14,
481 | out_chans=prompt_embed_dim,
482 | )
483 |
484 | if checkpoint is not None:
485 | # with open(checkpoint, "rb") as f:
486 | state_dict = torch.load(checkpoint)
487 |
488 |
489 | image_encoder.load_state_dict(state_dict, strict=True)
490 | # image_encoder.load_state_dict({k[19:]: v for k, v in state_dict.items() if 'vision_tower' in k}, strict=True)
491 | print(checkpoint)
492 | return image_encoder
493 |
494 |
495 |
496 |
497 | if __name__ == '__main__':
498 | # quick smoke test with a randomly initialized encoder (pass a checkpoint path to load weights)
499 | x = torch.zeros(2, 3, 1024, 1024)
500 | net = build_sam_vit_b(checkpoint=None)
501 | with torch.no_grad():
502 | out = net(x)
503 | print(out.shape)  # (2, 1024, 16, 16): flattened to 256 tokens of width 1024 in varyQwenModel.forward
504 |
505 |
--------------------------------------------------------------------------------
/Vary-master/vary/train/train_flash_attn.py:
--------------------------------------------------------------------------------
1 | # Adopted from https://github.com/lm-sys/FastChat. Below is the original copyright:
2 | # Adopted from tatsu-lab@stanford_alpaca. Below is the original copyright:
3 | # Make it more memory efficient by monkey patching the LLaMA model with FlashAttn.
4 |
5 | # Need to call this before importing transformers.
6 | from vary.utils.llama_flash_attn_monkey_patch import replace_llama_attn_with_flash_attn
7 |
8 | replace_llama_attn_with_flash_attn()
9 |
10 | from vary.train.train import train
11 |
12 | if __name__ == "__main__":
13 | train()
14 |
--------------------------------------------------------------------------------
/Vary-master/vary/train/train_lora.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Ucas-HaoranWei/Vary-toy/c6e405977aa6e88d4807a6b79f394299623fc6f3/Vary-master/vary/train/train_lora.py
--------------------------------------------------------------------------------
/Vary-master/vary/train/train_lora_flash_attn.py:
--------------------------------------------------------------------------------
1 | # Adopted from https://github.com/lm-sys/FastChat. Below is the original copyright:
2 | # Adopted from tatsu-lab@stanford_alpaca. Below is the original copyright:
3 | # Make it more memory efficient by monkey patching the LLaMA model with FlashAttn.
4 |
5 | # Need to call this before importing transformers.
6 | from vary.utils.llama_flash_attn_monkey_patch import replace_llama_attn_with_flash_attn
7 |
8 | replace_llama_attn_with_flash_attn()
9 |
10 | # from vary.train.train import train
11 | from vary.train.train_lora import train
12 |
13 | if __name__ == "__main__":
14 | train()
15 |
--------------------------------------------------------------------------------
/Vary-master/vary/train/train_opt.py:
--------------------------------------------------------------------------------
1 | # Adopted from https://github.com/lm-sys/FastChat. Below is the original copyright:
2 | # Adopted from tatsu-lab@stanford_alpaca. Below is the original copyright:
3 | # Copyright 2023 Rohan Taori, Ishaan Gulrajani, Tianyi Zhang, Yann Dubois, Xuechen Li
4 | #
5 | # Licensed under the Apache License, Version 2.0 (the "License");
6 | # you may not use this file except in compliance with the License.
7 | # You may obtain a copy of the License at
8 | #
9 | # http://www.apache.org/licenses/LICENSE-2.0
10 | #
11 | # Unless required by applicable law or agreed to in writing, software
12 | # distributed under the License is distributed on an "AS IS" BASIS,
13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 | # See the License for the specific language governing permissions and
15 | # limitations under the License.
16 |
17 | import logging
18 | import pathlib
19 | import torch
20 | import transformers
21 |
22 | from vary.train.trainer_vit_fixlr import varyTrainer
23 | from vary.model import *
24 | from vary.data import make_supervised_data_module
25 | from vary.utils.arguments import *
26 | from vary.utils.constants import *
27 | from vary.model.vision_encoder.sam import build_sam_vit_b
28 |
29 | def train():
30 | parser = transformers.HfArgumentParser((ModelArguments, DataArguments, TrainingArguments))
31 | model_args, data_args, training_args = parser.parse_args_into_dataclasses()
32 |
33 |
34 | tokenizer = transformers.AutoTokenizer.from_pretrained(model_args.model_name_or_path, use_fast=False, padding_side="right", model_max_length=training_args.model_max_length)
35 |
36 |
37 | model = varyOPTForCausalLM.from_pretrained(model_args.model_name_or_path)
38 |
39 |
40 |
41 | dtype = torch.float32
42 | if training_args.fp16:
43 | dtype = torch.float16
44 | if training_args.bf16:
45 | dtype = torch.bfloat16
46 |
47 | vision_tower_dict = model.get_model().initialize_vision_modules(
48 | vision_tower=model_args.vision_tower,
49 | pretrained_stage1_model=model_args.pretrained_stage1_model,
50 | freeze_vision_tower=model_args.freeze_vision_tower,
51 | use_im_start_end=model_args.use_im_start_end,
52 | vision_select_layer=model_args.vision_select_layer,
53 | dtype=dtype,
54 | device=training_args.device
55 | )
56 |
57 | model.initialize_vision_tokenizer(
58 | tokenizer=tokenizer,
59 | freeze_lm_model=model_args.freeze_lm_model,
60 | pretrained_stage1_model=model_args.pretrained_stage1_model,
61 | device=training_args.device,
62 | )
63 |
64 |
65 |
66 | model.to(dtype=dtype, device=training_args.device)
67 |
68 | data_args.image_token_len = 256
69 | data_args.image_processor = vision_tower_dict['image_processor']
70 | data_args.image_processor_high = vision_tower_dict['image_processor_high']
71 | data_args.use_im_start_end = model_args.use_im_start_end
72 |
73 | # mixed relation, to be fixed
74 | if model_args.freeze_lm_model:
75 | model.requires_grad_(False)
76 | for p in model.get_model().mm_projector.parameters():
77 | p.requires_grad = True
78 |
79 | for p in model.get_input_embeddings().parameters():
80 | p.requires_grad = True
81 |
82 |
83 | if not model_args.freeze_vision_tower:
84 |
85 | model.get_model().vision_tower.requires_grad_(True)
86 |
87 |
88 | params_grad = [p.numel() for n, p in model.named_parameters() if p.requires_grad]
89 | print(f"Number of Mapping Trainable Parameters: {sum(params_grad) / (1 << 20):.2f} M")
90 |
91 | # params_no_grad = [n for n, p in model.named_parameters() if not p.requires_grad]
92 | # if len(params_no_grad) > 0:
93 | # if training_args.fsdp is not None and len(training_args.fsdp) > 0:
94 | # if len(params_no_grad) < 10:
95 | # print('[WARNING] Attempting to use FSDP while {} parameters do not require gradients: {}'. format(len(params_no_grad), params_no_grad))
96 | # else:
97 | # print('[WARNING] Attempting to use FSDP while {} parameters do not require gradients: {}...(omitted)'. format(len(params_no_grad), ', '.join(params_no_grad[:10])))
98 | # print("[WARNING] Attempting to use FSDP with partially frozen paramters, this is experimental.")
99 | # print("[WARNING] As of 4/30/23, this feature requires PyTorch-nightly build. See here for details: https://github.com/haotian-liu/LLaVA#experimental-use-fsdp-to-save-memory-in-pretraining")
100 |
101 | # from torch.distributed.fsdp.fully_sharded_data_parallel import FullyShardedDataParallel as FSDP
102 | # def patch_FSDP_use_orig_params(func):
103 | # def wrap_func(*args, **kwargs):
104 | # use_orig_params = kwargs.pop('use_orig_params', True)
105 | # return func(*args, **kwargs, use_orig_params=use_orig_params)
106 | # return wrap_func
107 |
108 | # FSDP.__init__ = patch_FSDP_use_orig_params(FSDP.__init__)
109 |
110 | # interleave = True
111 | data_module = make_supervised_data_module(
112 | interleave=training_args.interleave,
113 | with_box=training_args.with_box,
114 | tokenizer=tokenizer,
115 | data_args=data_args
116 | )
117 |
118 | trainer = varyTrainer(
119 | model=model,
120 | tokenizer=tokenizer,
121 | args=training_args,
122 | **data_module)
123 |
124 | if list(pathlib.Path(training_args.output_dir).glob("checkpoint-*")):
125 | trainer.train(resume_from_checkpoint=True)
126 | else:
127 | trainer.train()
128 | trainer.save_state()
129 | trainer._safe_save(output_dir=training_args.output_dir)
130 |
131 |
132 | if __name__ == "__main__":
133 | train()
134 |
135 |
--------------------------------------------------------------------------------
/Vary-master/vary/train/train_qwen_vary.py:
--------------------------------------------------------------------------------
1 | # Adopted from https://github.com/lm-sys/FastChat. Below is the original copyright:
2 | # Adopted from tatsu-lab@stanford_alpaca. Below is the original copyright:
3 | # Copyright 2023 Rohan Taori, Ishaan Gulrajani, Tianyi Zhang, Yann Dubois, Xuechen Li
4 | #
5 | # Licensed under the Apache License, Version 2.0 (the "License");
6 | # you may not use this file except in compliance with the License.
7 | # You may obtain a copy of the License at
8 | #
9 | # http://www.apache.org/licenses/LICENSE-2.0
10 | #
11 | # Unless required by applicable law or agreed to in writing, software
12 | # distributed under the License is distributed on an "AS IS" BASIS,
13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 | # See the License for the specific language governing permissions and
15 | # limitations under the License.
16 |
17 | import logging
18 | import pathlib
19 | import torch
20 | import transformers
21 |
22 |
23 | from vary.train.trainer_vit_fixlr import varyTrainer
24 | from vary.model import *
25 | from vary.data import make_supervised_data_module
26 | from vary.utils.arguments import *
27 | from vary.utils.utils import smart_tokenizer_and_embedding_resize
28 | from vary.model.vision_encoder.sam import build_sam_vit_b
29 |
30 |
31 | def train():
32 | parser = transformers.HfArgumentParser((ModelArguments, DataArguments, TrainingArguments))
33 | model_args, data_args, training_args = parser.parse_args_into_dataclasses()
34 |
35 | tokenizer = transformers.AutoTokenizer.from_pretrained(model_args.model_name_or_path, trust_remote_code=True, padding_side="right", model_max_length=training_args.model_max_length,)
36 |
37 |
38 | model = varyQwenForCausalLM.from_pretrained(model_args.model_name_or_path, low_cpu_mem_usage=True, device_map='cuda')
39 |
40 |
41 |
42 | smart_tokenizer_and_embedding_resize(
43 | special_tokens_dict=dict(pad_token='<|endoftext|>'),
44 | tokenizer=tokenizer,
45 | model=model,
46 | )
47 |
48 |
49 | dtype = torch.float32
50 | if training_args.fp16:
51 | dtype = torch.float16
52 | if training_args.bf16:
53 | dtype = torch.bfloat16
54 |
55 | vision_tower_dict = model.get_model().initialize_vision_modules(
56 | vision_tower=model_args.vision_tower,
57 | pretrained_stage1_model=model_args.pretrained_stage1_model,
58 | freeze_vision_tower=model_args.freeze_vision_tower,
59 | use_im_start_end=model_args.use_im_start_end,
60 | vision_select_layer=model_args.vision_select_layer,
61 | dtype=dtype,
62 | device=training_args.device
63 | )
64 |
65 | model.initialize_vision_tokenizer(
66 | tokenizer=tokenizer,
67 | freeze_lm_model=model_args.freeze_lm_model,
68 | pretrained_stage1_model=model_args.pretrained_stage1_model,
69 | device=training_args.device,
70 | )
71 |
72 |
73 |
74 |
75 | model.to(dtype=dtype, device=training_args.device)
76 |
77 | data_args.image_token_len = 256
78 | data_args.image_processor = vision_tower_dict['image_processor']
79 | data_args.image_processor_high = vision_tower_dict['image_processor_high']
80 | data_args.use_im_start_end = model_args.use_im_start_end
81 |
82 | # mixed relation, to be fixed
83 | if model_args.freeze_lm_model:
84 | model.requires_grad_(False)
85 | for p in model.get_model().mm_projector.parameters():
86 | p.requires_grad = True
87 | for p in model.get_model().mm_projector_vary.parameters():
88 | p.requires_grad = True
89 | for p in model.get_input_embeddings().parameters():
90 | p.requires_grad = True
91 |
92 |
93 | if not model_args.freeze_vision_tower:
94 | model.get_model().vision_tower.requires_grad_(True)
95 | model.get_model().vision_tower_high.requires_grad_(True)
96 |
97 |
98 | params_grad = [p.numel() for n, p in model.named_parameters() if p.requires_grad]
99 | print(f"Number of Mapping Trainable Parameters: {sum(params_grad) / (1 << 20):.2f} M")
100 |
101 | # params_no_grad = [n for n, p in model.named_parameters() if not p.requires_grad]
102 | # if len(params_no_grad) > 0:
103 | # if training_args.fsdp is not None and len(training_args.fsdp) > 0:
104 | # if len(params_no_grad) < 10:
105 | # print('[WARNING] Attempting to use FSDP while {} parameters do not require gradients: {}'. format(len(params_no_grad), params_no_grad))
106 | # else:
107 | # print('[WARNING] Attempting to use FSDP while {} parameters do not require gradients: {}...(omitted)'. format(len(params_no_grad), ', '.join(params_no_grad[:10])))
108 | # print("[WARNING] Attempting to use FSDP with partially frozen paramters, this is experimental.")
109 | # print("[WARNING] As of 4/30/23, this feature requires PyTorch-nightly build. See here for details: https://github.com/haotian-liu/LLaVA#experimental-use-fsdp-to-save-memory-in-pretraining")
110 |
111 | # from torch.distributed.fsdp.fully_sharded_data_parallel import FullyShardedDataParallel as FSDP
112 | # def patch_FSDP_use_orig_params(func):
113 | # def wrap_func(*args, **kwargs):
114 | # use_orig_params = kwargs.pop('use_orig_params', True)
115 | # return func(*args, **kwargs, use_orig_params=use_orig_params)
116 | # return wrap_func
117 |
118 | # FSDP.__init__ = patch_FSDP_use_orig_params(FSDP.__init__)
119 |
120 |
121 |
122 | data_module = make_supervised_data_module(
123 | interleave=training_args.interleave,
124 | with_box=training_args.with_box,
125 | tokenizer=tokenizer,
126 | data_args=data_args
127 | )
128 |
129 | trainer = varyTrainer(
130 | model=model,
131 | tokenizer=tokenizer,
132 | args=training_args,
133 | **data_module)
134 |
135 | if list(pathlib.Path(training_args.output_dir).glob("checkpoint-*")):
136 | trainer.train(resume_from_checkpoint=True)
137 | else:
138 | trainer.train()
139 | trainer.save_state()
140 | trainer._safe_save(output_dir=training_args.output_dir)
141 |
142 |
143 | if __name__ == "__main__":
144 | train()
145 |
--------------------------------------------------------------------------------
/Vary-master/vary/train/trainer.py:
--------------------------------------------------------------------------------
1 | import os
2 | import torch
3 | import torch.nn as nn
4 |
5 | from transformers import Trainer
6 | from typing import Dict, Optional, Sequence
7 |
8 |
9 | def unwrap_model(model: nn.Module) -> nn.Module:
10 | """
11 | Recursively unwraps a model from potential containers (as used in distributed training).
12 |
13 | Args:
14 | model (`torch.nn.Module`): The model to unwrap.
15 | """
16 | # since there could be multiple levels of wrapping, unwrap recursively
17 | if hasattr(model, "module"):
18 | return unwrap_model(model.module)
19 | else:
20 | return model
21 |
22 |
23 | class varyTrainer(Trainer):
24 |
25 | def _safe_save(self, output_dir: str):
26 | """Collects the state dict and dump to disk."""
27 | if self.deepspeed:
28 | torch.cuda.synchronize()
29 | self.save_model(output_dir)
30 | return
31 |
32 | state_dict = self.model.state_dict()
33 | if self.args.should_save:
34 | cpu_state_dict = {
35 | key: value.cpu()
36 | for key, value in state_dict.items()
37 | }
38 | del state_dict
39 | self._save(output_dir, state_dict=cpu_state_dict) # noqa
40 |
41 |
42 | def _save(self, output_dir: Optional[str] = None, state_dict=None):
43 | if getattr(self.args, 'tune_mm_mlp_adapter', False):
44 | # Save the model
45 | _state_dict = state_dict
46 | if _state_dict is None:
47 | # Only save the model itself if we are using distributed training
48 | model_to_save = unwrap_model(self.model)
49 | _state_dict = model_to_save.state_dict()
50 |
51 | weight_to_save = {}
52 | keys_to_match = ['mm_projector', 'embed_tokens', 'embed_in']
53 | for k, v in _state_dict.items():
54 | if any(key_match in k for key_match in keys_to_match):
55 | weight_to_save[k] = v
56 |
57 | current_folder = output_dir.split('/')[-1]
58 | parent_folder = os.path.dirname(output_dir)
59 | if current_folder.startswith('checkpoint-'):
60 | mm_projector_folder = os.path.join(parent_folder, "mm_projector")
61 | os.makedirs(mm_projector_folder, exist_ok=True)
62 | torch.save(weight_to_save, os.path.join(mm_projector_folder, f'{current_folder}.bin'))
63 | else:
64 | torch.save(weight_to_save, os.path.join(output_dir, f'mm_projector.bin'))
65 |
66 | super(varyTrainer, self)._save(output_dir, state_dict)
67 |
--------------------------------------------------------------------------------
/Vary-master/vary/train/trainer_vit_fixlr.py:
--------------------------------------------------------------------------------
1 | import os
2 | import torch
3 | import torch.nn as nn
4 |
5 | from transformers import Trainer
6 | from transformers.trainer_pt_utils import get_parameter_names
7 | from transformers.pytorch_utils import ALL_LAYERNORM_LAYERS
8 | from typing import Dict, Optional, Sequence
9 |
10 |
11 | def unwrap_model(model: nn.Module) -> nn.Module:
12 | """
13 | Recursively unwraps a model from potential containers (as used in distributed training).
14 |
15 | Args:
16 | model (`torch.nn.Module`): The model to unwrap.
17 | """
18 | # since there could be multiple levels of wrapping, unwrap recursively
19 | if hasattr(model, "module"):
20 | return unwrap_model(model.module)
21 | else:
22 | return model
23 |
24 |
25 | class varyTrainer(Trainer):
26 |
27 | def _safe_save(self, output_dir: str):
28 | """Collects the state dict and dump to disk."""
29 | state_dict = self.model.state_dict()
30 | if self.args.should_save:
31 | cpu_state_dict = {
32 | key: value.cpu()
33 | for key, value in state_dict.items()
34 | }
35 | del state_dict
36 | self._save(output_dir, state_dict=cpu_state_dict) # noqa
37 |
38 |
39 | def _save(self, output_dir: Optional[str] = None, state_dict=None):
40 | if getattr(self.args, 'tune_mm_mlp_adapter', False):
41 | # Save the model
42 | _state_dict = state_dict
43 | if _state_dict is None:
44 | # Only save the model itself if we are using distributed training
45 | model_to_save = unwrap_model(self.model)
46 | _state_dict = model_to_save.state_dict()
47 |
48 | weight_to_save = {}
49 | keys_to_match = ['mm_projector', 'embed_tokens', 'embed_in']
50 | for k, v in _state_dict.items():
51 | if any(key_match in k for key_match in keys_to_match):
52 | weight_to_save[k] = v
53 |
54 | current_folder = output_dir.split('/')[-1]
55 | parent_folder = os.path.dirname(output_dir)
56 | if current_folder.startswith('checkpoint-'):
57 | mm_projector_folder = os.path.join(parent_folder, "mm_projector")
58 | os.makedirs(mm_projector_folder, exist_ok=True)
59 | torch.save(weight_to_save, os.path.join(mm_projector_folder, f'{current_folder}.bin'))
60 | else:
61 | torch.save(weight_to_save, os.path.join(output_dir, f'mm_projector.bin'))
62 |
63 | super(varyTrainer, self)._save(output_dir, state_dict)
64 |
65 | def create_optimizer(self):
66 | """
67 | Setup the optimizer.
68 |
69 | We provide a reasonable default that works well. If you want to use something else, you can pass a tuple in the
70 | Trainer's init through `optimizers`, or subclass and override this method in a subclass.
71 | """
72 | opt_model = self.model
73 |
74 | if self.optimizer is None:
75 | decay_parameters = get_parameter_names(opt_model, ALL_LAYERNORM_LAYERS)
76 | decay_parameters = [name for name in decay_parameters if "bias" not in name]
77 | optimizer_grouped_parameters = [
78 | {
79 | "params": [
80 | p for n, p in opt_model.named_parameters() if 'vision_encoder' in n and n in decay_parameters and p.requires_grad
81 | ],
82 | "weight_decay": self.args.weight_decay,
83 | "lr": self.args.learning_rate,
84 | },
85 | {
86 | "params": [
87 | p for n, p in opt_model.named_parameters() if 'vision_encoder' in n and n not in decay_parameters and p.requires_grad],
88 | "weight_decay": 0.0,
89 | "lr": self.args.learning_rate,
90 | },
91 | {
92 | "params": [
93 | p for n, p in opt_model.named_parameters() if 'vision_encoder' not in n and n in decay_parameters and p.requires_grad],
94 | "weight_decay": self.args.weight_decay,
95 | "lr": self.args.learning_rate,
96 | },
97 | {
98 | "params": [
99 | p for n, p in opt_model.named_parameters() if 'vision_encoder' not in n and n not in decay_parameters and p.requires_grad
100 | ],
101 | "weight_decay": 0.0,
102 | "lr": self.args.learning_rate,
103 | },
104 | ]
105 | for idx, group in enumerate(optimizer_grouped_parameters):
106 | print(idx, len(group['params']), group['lr'])
107 | optimizer_cls, optimizer_kwargs = Trainer.get_optimizer_cls_and_kwargs(self.args)
108 | self.optimizer = optimizer_cls(optimizer_grouped_parameters, **optimizer_kwargs)
109 |
110 | return self.optimizer
--------------------------------------------------------------------------------
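
`create_optimizer` above builds four parameter groups (vision-encoder/decay, vision-encoder/no-decay, rest/decay, rest/no-decay), but all four use the same `self.args.learning_rate`, so in this version the split effectively only controls which tensors receive weight decay. The standalone sketch below (a toy module, not the Vary model) shows how the grouping rule classifies parameters; since the models in this repo register their encoders as `vision_tower` / `vision_tower_high`, it is worth checking the group lengths printed by `create_optimizer` to confirm whether the `'vision_encoder'` substring ever matches in practice.

```python
# Standalone sketch of the grouping rule used in create_optimizer (toy module, not the Vary model).
import torch.nn as nn
from transformers.trainer_pt_utils import get_parameter_names
from transformers.pytorch_utils import ALL_LAYERNORM_LAYERS

toy = nn.ModuleDict({
    "vision_encoder": nn.Sequential(nn.Linear(8, 8), nn.LayerNorm(8)),
    "lm_head": nn.Linear(8, 8),
})

# parameters eligible for weight decay: everything except LayerNorm parameters and biases
decay_parameters = [n for n in get_parameter_names(toy, ALL_LAYERNORM_LAYERS) if "bias" not in n]

for name, _ in toy.named_parameters():
    is_vision = "vision_encoder" in name
    has_decay = name in decay_parameters
    print(f"{name:30s} vision={is_vision} weight_decay={has_decay}")
```
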
/Vary-master/vary/utils/arguments.py:
--------------------------------------------------------------------------------
1 | from dataclasses import dataclass, field
2 | from typing import Dict, Optional, Sequence
3 | import transformers
4 |
5 |
6 | @dataclass
7 | class ModelArguments:
8 | model_name_or_path: Optional[str] = field(default="facebook/opt-125m")
9 | use_cache: bool = field(default=False)
10 | vision_tower: Optional[str] = field(default="~/.cache/huggingface/hub/models--openai--clip-vit-large-patch14/snapshots/8d052a0f05efbaefbc9e8786ba291cfdf93e5bff/")
11 | freeze_vision_tower: bool = field(default=False)
12 | freeze_lm_model: bool = field(default=False)
13 | pretrained_stage1_model: Optional[str] = field(default=None) # mlp &/ vision tower
14 | vision_select_layer: Optional[int] = field(default=-1) # default to the last layer
15 | use_im_start_end: bool = field(default=False)
16 |
17 |
18 | @dataclass
19 | class DataArguments:
20 | datasets: str = field(default=None, metadata={"help": "combinations of the training data."})
21 | sep_image_conv_front: bool = False
22 | image_token_len: int = 256
23 | image_aspect_ratio: str = 'square'
24 | conversation_version: str = 'mpt'
25 | # conversation_version: str = 'v0'
26 | # conversation_version: str = 'v1'
27 | # conversation_version: str = 'opt'
28 | box_limit: int = 0
29 |
30 |
31 | @dataclass
32 | class TrainingArguments(transformers.TrainingArguments):
33 | cache_dir: Optional[str] = field(default=None)
34 | optim: str = field(default="adamw_torch")
35 | remove_unused_columns: bool = field(default=False)
36 | force_fsdp: bool = field(default=False)
37 | interleave: bool = field(default=False)
38 | with_box: bool = field(default=False)
39 | model_max_length: int = field(
40 | default=512,
41 | metadata={
42 | "help":
43 | "Maximum sequence length. Sequences will be right padded (and possibly truncated)."
44 | },
45 | )
46 | lora_enable: bool = False
47 | lora_r: int = 8
48 | lora_alpha: int = 16
49 | lora_dropout: float = 0.05
50 | lora_weight_path: str = ""
51 | lora_bias: str = "none"
--------------------------------------------------------------------------------
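
These dataclasses are consumed by the train scripts through `transformers.HfArgumentParser`. The sketch below shows the parsing step in isolation; all paths are placeholders, and the `+`-joined `--datasets` string is an assumption about how keys of `CONVERSATION_DATA` (in `vary/utils/constants.py`) are combined.

```python
# Sketch: how train_opt.py / train_qwen_vary.py turn CLI flags into these dataclasses.
# Paths are placeholders; the datasets join syntax is an assumption.
import transformers
from vary.utils.arguments import ModelArguments, DataArguments, TrainingArguments

parser = transformers.HfArgumentParser((ModelArguments, DataArguments, TrainingArguments))
model_args, data_args, training_args = parser.parse_args_into_dataclasses(args=[
    "--model_name_or_path", "/path/to/Qwen-1_8B",          # placeholder
    "--vision_tower", "/path/to/clip-vit-large-patch14",   # placeholder
    "--freeze_vision_tower", "False",
    "--use_im_start_end", "True",
    "--datasets", "pdf+docvqa_train",                      # assumed key-join syntax
    "--model_max_length", "2048",
    "--bf16", "True",
    "--output_dir", "./checkpoints/vary-toy",              # placeholder
])
print(model_args.vision_tower, data_args.datasets, training_args.model_max_length)
```
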
/Vary-master/vary/utils/constants.py:
--------------------------------------------------------------------------------
1 | CONTROLLER_HEART_BEAT_EXPIRATION = 30
2 | WORKER_HEART_BEAT_INTERVAL = 15
3 |
4 | LOGDIR = "log"
5 |
6 | IGNORE_INDEX = -100
7 | # DEFAULT_PAD_TOKEN = "[PAD]"
8 |
9 | DEFAULT_PAD_TOKEN = "<|endoftext|>"
10 | DEFAULT_EOS_TOKEN = "</s>"
11 | DEFAULT_BOS_TOKEN = "<s>"
12 | DEFAULT_UNK_TOKEN = "<unk>"
13 | DEFAULT_IMAGE_TOKEN = "<image>"
14 | DEFAULT_BOX_TOKEN = "<box>"
15 |
16 | DEFAULT_IMAGE_PATCH_TOKEN = '<imgpad>'
17 |
18 | DEFAULT_IM_START_TOKEN = '<img>'
19 | DEFAULT_IM_END_TOKEN = '</img>'
20 |
21 |
22 | ROOT_PATH = '/data/public/ucaswei/data/'
23 |
24 | CONVERSATION_DATA = {
25 |
26 | # pair 4m
27 | 'laion-coco-4m': {
28 | 'images': '',
29 | 'annotations': '',
30 | },
31 |
32 | 'cc665k': {
33 | 'images': "/path_to/LLaVA1.5/images/",
34 | 'annotations': "/path_to/LLaVA1.5/llava_v1_5_66k.json",
35 | },
36 |
37 | 'pdf': {
38 | 'images': "",
39 | 'annotations': "",
40 | },
41 |
42 | 'docvqa_train': {
43 | 'images': "",
44 | 'annotations': "",
45 | },
46 |
47 | 'chartqa_train': {
48 | 'images': "",
49 | 'annotations': "",
50 | },
51 |
52 |
53 |
54 | }
--------------------------------------------------------------------------------
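
The image placeholder tokens defined above are what the model files splice image embeddings into: `forward()` expects `<img>`, exactly `image_token_len` copies of `<imgpad>`, then `</img>` (hard-coded there as ids 151857/151859/151858). A small sketch of assembling that span:

```python
# Sketch: assembling the image placeholder span expected by forward() in the model files.
# 256 matches the image_token_len set in initialize_vision_modules; the question text is arbitrary.
from vary.utils.constants import (
    DEFAULT_IM_START_TOKEN,
    DEFAULT_IMAGE_PATCH_TOKEN,
    DEFAULT_IM_END_TOKEN,
)

image_token_len = 256
image_span = DEFAULT_IM_START_TOKEN + DEFAULT_IMAGE_PATCH_TOKEN * image_token_len + DEFAULT_IM_END_TOKEN
prompt = image_span + "\nProvide the OCR results of this image."
```
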
/Vary-master/vary/utils/conversation.py:
--------------------------------------------------------------------------------
1 | import dataclasses
2 | from enum import auto, Enum
3 | from typing import List, Tuple
4 |
5 |
6 | class SeparatorStyle(Enum):
7 | """Different separator style."""
8 | SINGLE = auto()
9 | TWO = auto()
10 | MPT = auto()
11 |
12 |
13 |
14 |
15 | @dataclasses.dataclass
16 | class Conversation:
17 | """A class that keeps all conversation history."""
18 | system: str
19 | roles: List[str]
20 | messages: List[List[str]]
21 | offset: int
22 | sep_style: SeparatorStyle = SeparatorStyle.SINGLE
23 | sep: str = "<|im_end|>"
24 | sep2: str = None
25 | version: str = "Unknown"
26 |
27 | skip_next: bool = False
28 |
29 | def get_prompt(self):
30 | if self.sep_style == SeparatorStyle.SINGLE:
31 | ret = self.system + self.sep + '\n'
32 | for role, message in self.messages:
33 | if message:
34 | if type(message) is tuple:
35 | message, _, _ = message
36 | ret += role + ": " + message + self.sep
37 | else:
38 | ret += role + ":"
39 | return ret
40 | elif self.sep_style == SeparatorStyle.TWO:
41 | seps = [self.sep, self.sep2]
42 | ret = self.system + seps[0]
43 | for i, (role, message) in enumerate(self.messages):
44 | if message:
45 | if type(message) is tuple:
46 | message, _, _ = message
47 | ret += role + ": " + message + seps[i % 2]
48 | else:
49 | ret += role + ":"
50 | return ret
51 | elif self.sep_style == SeparatorStyle.MPT:
52 | if self.system:
53 | ret = self.system + self.sep
54 | else:
55 | ret = ''
56 | for role, message in self.messages:
57 | if message:
58 | if type(message) is tuple:
59 | message, _, _ = message
60 | ret += role + message + self.sep
61 | else:
62 | ret += role
63 | return ret
64 | else:
65 | raise ValueError(f"Invalid style: {self.sep_style}")
66 |
67 |
68 | def append_message(self, role, message):
69 | self.messages.append([role, message])
70 |
71 | def get_images(self, return_pil=False):
72 | images = []
73 | for i, (role, msg) in enumerate(self.messages[self.offset:]):
74 | if i % 2 == 0:
75 | if type(msg) is tuple:
76 | import base64
77 | from io import BytesIO
78 | from PIL import Image
79 | msg, image, image_process_mode = msg
80 | if image_process_mode == "Pad":
81 | def expand2square(pil_img, background_color=(122, 116, 104)):
82 | width, height = pil_img.size
83 | if width == height:
84 | return pil_img
85 | elif width > height:
86 | result = Image.new(pil_img.mode, (width, width), background_color)
87 | # result.paste(pil_img, (0, (width - height) // 2))
88 | result.paste(pil_img)
89 | return result
90 | else:
91 | result = Image.new(pil_img.mode, (height, height), background_color)
92 | # result.paste(pil_img, ((height - width) // 2, 0))
93 | result.paste(pil_img)
94 | return result
95 | image = expand2square(image)
96 | elif image_process_mode == "Crop":
97 | max_hw, min_hw = max(image.size), min(image.size)
98 | aspect_ratio = max_hw / min_hw
99 | max_len, min_len = 800, 400
100 | shortest_edge = int(min(max_len / aspect_ratio, min_len, min_hw))
101 | longest_edge = int(shortest_edge * aspect_ratio)
102 | W, H = image.size
103 | if H > W:
104 | H, W = longest_edge, shortest_edge
105 | else:
106 | H, W = shortest_edge, longest_edge
107 | image = image.resize((W, H))
108 | elif image_process_mode == "Resize":
109 | image = image.resize((224, 224))
110 | else:
111 | raise ValueError(f"Invalid image_process_mode: {image_process_mode}")
112 |
113 | if return_pil:
114 | images.append(image)
115 | else:
116 | buffered = BytesIO()
117 | image.convert('RGB').save(buffered, format="JPEG")
118 | img_b64_str = base64.b64encode(buffered.getvalue()).decode()
119 | images.append(img_b64_str)
120 | return images
121 |
122 | def to_gradio_chatbot(self):
123 | ret = []
124 | for i, (role, msg) in enumerate(self.messages[self.offset:]):
125 | if i % 2 == 0:
126 | if type(msg) is tuple:
127 | import base64
128 | from io import BytesIO
129 | msg, image, image_process_mode = msg
130 | max_hw, min_hw = max(image.size), min(image.size)
131 | aspect_ratio = max_hw / min_hw
132 | max_len, min_len = 800, 400
133 | shortest_edge = int(min(max_len / aspect_ratio, min_len, min_hw))
134 | longest_edge = int(shortest_edge * aspect_ratio)
135 | W, H = image.size
136 | if H > W:
137 | H, W = longest_edge, shortest_edge
138 | else:
139 | H, W = shortest_edge, longest_edge
140 | image = image.resize((W, H))
141 | # image = image.resize((224, 224))
142 | buffered = BytesIO()
143 | image.save(buffered, format="JPEG")
144 | img_b64_str = base64.b64encode(buffered.getvalue()).decode()
145 |                     img_str = f'<img src="data:image/jpeg;base64,{img_b64_str}" alt="user upload image" />'
146 |                     msg = msg.replace('<image>', img_str)
147 | ret.append([msg, None])
148 | else:
149 | ret[-1][-1] = msg
150 | return ret
151 |
152 | def copy(self):
153 | return Conversation(
154 | system=self.system,
155 | roles=self.roles,
156 | messages=[[x, y] for x, y in self.messages],
157 | offset=self.offset,
158 | sep_style=self.sep_style,
159 | sep=self.sep,
160 |             sep2=self.sep2, version=self.version)
161 |
162 | def dict(self):
163 | if len(self.get_images()) > 0:
164 | return {
165 | "system": self.system,
166 | "roles": self.roles,
167 | "messages": [[x, y[0] if type(y) is tuple else y] for x, y in self.messages],
168 | "offset": self.offset,
169 | "sep": self.sep,
170 | "sep2": self.sep2,
171 | }
172 | return {
173 | "system": self.system,
174 | "roles": self.roles,
175 | "messages": self.messages,
176 | "offset": self.offset,
177 | "sep": self.sep,
178 | "sep2": self.sep2,
179 | }
180 |
181 |
182 | conv_v1 = Conversation(
183 | system="A chat between a curious human and an artificial intelligence assistant. "
184 | "The assistant gives helpful, detailed, and polite answers to the human's questions.",
185 | roles=("Human", "Assistant"),
186 | messages=(
187 | ("Human", "Give three tips for staying healthy."),
188 | ("Assistant",
189 | "Sure, here are three tips for staying healthy:\n"
190 | "1. Exercise regularly: Regular physical activity can help improve your overall health and wellbeing. "
191 | "It can also help reduce your risk of chronic conditions such as obesity, diabetes, heart disease, "
192 | "and certain cancers. Aim for at least 150 minutes of moderate-intensity aerobic exercise or "
193 | "75 minutes of vigorous-intensity aerobic exercise per week, along with muscle-strengthening "
194 | "activities at least two days per week.\n"
195 | "2. Eat a balanced diet: Eating a balanced diet that is rich in fruits, "
196 | "vegetables, whole grains, lean proteins, and healthy fats can help support "
197 | "your overall health. Try to limit your intake of processed and high-sugar foods, "
198 | "and aim to drink plenty of water throughout the day.\n"
199 | "3. Get enough sleep: Getting enough quality sleep is essential for your physical "
200 | "and mental health. Adults should aim for seven to nine hours of sleep per night. "
201 | "Establish a regular sleep schedule and try to create a relaxing bedtime routine to "
202 | "help improve the quality of your sleep.")
203 | ),
204 | offset=2,
205 | sep_style=SeparatorStyle.SINGLE,
206 | sep="###",
207 | )
208 |
209 | conv_v1_2 = Conversation(
210 | system="A chat between a curious human and an artificial intelligence assistant. "
211 | "The assistant gives helpful, detailed, and polite answers to the human's questions.",
212 | roles=("Human", "Assistant"),
213 | messages=(
214 | ("Human", "What are the key differences between renewable and non-renewable energy sources?"),
215 | ("Assistant",
216 | "Renewable energy sources are those that can be replenished naturally in a relatively "
217 | "short amount of time, such as solar, wind, hydro, geothermal, and biomass. "
218 | "Non-renewable energy sources, on the other hand, are finite and will eventually be "
219 | "depleted, such as coal, oil, and natural gas. Here are some key differences between "
220 | "renewable and non-renewable energy sources:\n"
221 | "1. Availability: Renewable energy sources are virtually inexhaustible, while non-renewable "
222 | "energy sources are finite and will eventually run out.\n"
223 | "2. Environmental impact: Renewable energy sources have a much lower environmental impact "
224 | "than non-renewable sources, which can lead to air and water pollution, greenhouse gas emissions, "
225 | "and other negative effects.\n"
226 | "3. Cost: Renewable energy sources can be more expensive to initially set up, but they typically "
227 | "have lower operational costs than non-renewable sources.\n"
228 | "4. Reliability: Renewable energy sources are often more reliable and can be used in more remote "
229 | "locations than non-renewable sources.\n"
230 | "5. Flexibility: Renewable energy sources are often more flexible and can be adapted to different "
231 | "situations and needs, while non-renewable sources are more rigid and inflexible.\n"
232 | "6. Sustainability: Renewable energy sources are more sustainable over the long term, while "
233 | "non-renewable sources are not, and their depletion can lead to economic and social instability.\n")
234 | ),
235 | offset=2,
236 | sep_style=SeparatorStyle.SINGLE,
237 | sep="###",
238 | )
239 |
240 | conv_vicuna_v1_1 = Conversation(
241 | system="A chat between a curious user and an artificial intelligence assistant. "
242 | "The assistant gives helpful, detailed, and polite answers to the user's questions.",
243 | roles=("USER", "ASSISTANT"),
244 | version="v1",
245 | messages=(),
246 | offset=0,
247 | sep_style=SeparatorStyle.TWO,
248 | sep=" ",
249 |     sep2="</s>",
250 | )
251 |
252 |
253 |
254 | conv_mpt = Conversation(
255 | system="""<|im_start|>system
256 | You should follow the instructions carefully and explain your answers in detail.""",
257 | # system = None,
258 | roles=("<|im_start|>user\n", "<|im_start|>assistant\n"),
259 | version="mpt",
260 | messages=(),
261 | offset=0,
262 | sep_style=SeparatorStyle.MPT,
263 | sep="<|im_end|>",
264 | )
265 |
266 | conv_mpt_eval = Conversation(
267 | system="",
268 | roles=("<|im_start|>user\n", "<|im_start|>assistant\n"),
269 | version="mpt",
270 | messages=(),
271 | offset=0,
272 | sep_style=SeparatorStyle.MPT,
273 | sep="<|im_end|>",
274 | )
275 |
276 | conv_mpt_text = Conversation(
277 | system="""<|im_start|>system
278 | - You are a helpful assistant chatbot trained by MosaicML.
279 | - You answer questions.
280 | - You are excited to be able to help the user, but will refuse to do anything that could be considered harmful to the user.
281 | - You are more than just an information source, you are also able to write poetry, short stories, and make jokes.""",
282 | roles=("<|im_start|>user\n", "<|im_start|>assistant\n"),
283 | version="mpt",
284 | messages=(),
285 | offset=0,
286 | sep_style=SeparatorStyle.MPT,
287 | sep="<|im_end|>",
288 | )
289 |
290 | conv_bair_v1 = Conversation(
291 | system="BEGINNING OF CONVERSATION:",
292 | roles=("USER", "GPT"),
293 | messages=(),
294 | offset=0,
295 | sep_style=SeparatorStyle.TWO,
296 | sep=" ",
297 |     sep2="</s>",
298 | )
299 |
300 |
301 |
302 |
303 | simple_conv = Conversation(
304 | system="",
305 | roles=("Human", "Assistant"),
306 | messages=(
307 | ),
308 | offset=0,
309 | sep_style=SeparatorStyle.SINGLE,
310 | sep="###",
311 | )
312 |
313 | simple_conv_multimodal = Conversation(
314 |     system="You are vary, a large language and vision assistant trained by Foundation Model Group, Megvii Technology. "
315 |            "You are able to understand the visual content that the user provides, and assist the user with a variety of tasks using natural language. "
316 | "Follow the instructions carefully and explain your answers in detail.",
317 | # system="",
318 | roles=("Human", "Assistant"),
319 | messages=(
320 | ("Human", "Hi!"),
321 | ("Assistant", "Hi there! How can I help you today?\n")
322 | ),
323 | offset=2,
324 | sep_style=SeparatorStyle.SINGLE,
325 | sep="###",
326 | )
327 |
328 | simple_conv_mpt_multimodal = Conversation(
329 | system="""<|im_start|>system
330 | - You are vary, a large language and vision assistant trained by Foundation Model Group, Megvii Technology.
331 | - You are able to understand the visual content that the user provides, and assist the user with a variety of tasks using natural language.
332 | - You should follow the instructions carefully and explain your answers in detail.""",
333 | roles=("<|im_start|>user\n", "<|im_start|>assistant\n"),
334 | version="mpt",
335 | messages=(),
336 | offset=0,
337 | sep_style=SeparatorStyle.MPT,
338 | sep="<|im_end|>",
339 | )
340 |
341 | simple_conv_legacy = Conversation(
342 |     system="You are vary, a large language model trained by Foundation Model Group, Megvii Technology. "
343 |            "You are designed to assist human with a variety of tasks using natural language. "
344 | "Follow the instructions carefully.",
345 | roles=("Human", "Assistant"),
346 | messages=(
347 | ("Human", "Hi!\n\n### Response:"),
348 | ("Assistant", "Hi there! How can I help you today?\n")
349 | ),
350 | offset=2,
351 | sep_style=SeparatorStyle.SINGLE,
352 | sep="###",
353 | )
354 |
355 | conv_llava_v1 = Conversation(
356 |     system="You are vary, a large language and vision assistant trained by Foundation Model Group, Megvii Technology. "
357 |            "You are able to understand the visual content that the user provides, and assist the user with a variety of tasks using natural language. "
358 | "Follow the instructions carefully and explain your answers in detail.",
359 | roles=("USER", "ASSISTANT"),
360 | version="v1",
361 | messages=(),
362 | offset=0,
363 | sep_style=SeparatorStyle.TWO,
364 | sep=" ",
365 |     sep2="</s>",
366 | )
367 |
368 | default_conversation = conv_mpt
369 | conv_templates = {
370 | "default": simple_conv_multimodal,
371 | "simple": simple_conv,
372 | "simple_legacy": simple_conv_legacy,
373 | "multimodal": simple_conv,
374 | "mpt_multimodal": simple_conv_mpt_multimodal,
375 | "llava_v1": conv_llava_v1,
376 | "mpt_eval": conv_mpt_eval,
377 | # fastchat
378 | "v1": conv_vicuna_v1_1,
379 | "baichuan": conv_vicuna_v1_1,
380 | "bair_v1": conv_bair_v1,
381 | "vicuna_v1_1": conv_vicuna_v1_1,
382 | "mpt": conv_mpt,
383 | "qwen": conv_mpt,
384 | "mpt_text": conv_mpt_text,
385 | }
386 |
387 |
388 | if __name__ == "__main__":
389 | print(default_conversation.get_prompt())
390 |
--------------------------------------------------------------------------------
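A minimal usage sketch for the conversation templates above (assumptions: the module is imported as `vary.utils.conversation`, matching this repo's layout, and image-token insertion is elided here): copy a template, append one entry per role, leave the assistant turn open, and call `get_prompt()` to obtain the string that is fed to the tokenizer.

```python
from vary.utils.conversation import conv_templates

conv = conv_templates["mpt"].copy()       # conv_mpt, also registered as "qwen"
conv.append_message(conv.roles[0], "Describe this document page in detail.")
conv.append_message(conv.roles[1], None)  # leave the assistant turn open

prompt = conv.get_prompt()
# With SeparatorStyle.MPT this serializes to:
#   <|im_start|>system\n...<|im_end|><|im_start|>user\n...<|im_end|><|im_start|>assistant\n
print(prompt)
```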
/Vary-master/vary/utils/llama_flash_attn_monkey_patch.py:
--------------------------------------------------------------------------------
1 | # Adapted from https://github.com/lm-sys/FastChat.
2 | from typing import List, Optional, Tuple
3 |
4 | import torch
5 | from torch import nn
6 |
7 | import transformers
8 | from transformers.models.llama.modeling_llama import apply_rotary_pos_emb
9 |
10 | from einops import rearrange
11 |
12 | from flash_attn.flash_attn_interface import flash_attn_unpadded_qkvpacked_func
13 | from flash_attn.bert_padding import unpad_input, pad_input
14 |
15 | def forward(
16 | self,
17 | hidden_states: torch.Tensor,
18 | past_key_value: Optional[Tuple[torch.Tensor]] = None,
19 | attention_mask: Optional[torch.Tensor] = None,
20 | output_attentions: bool = False,
21 | use_cache: bool = False,
22 | ) -> Tuple[torch.Tensor, Optional[torch.Tensor],
23 | Optional[Tuple[torch.Tensor]]]:
24 | """Input shape: Batch x Time x Channel
25 |
26 | attention_mask: [bsz, q_len]
27 | """
28 | bsz, q_len, _ = hidden_states.size()
29 |
30 | query_states = self.q_proj(hidden_states).view(
31 | bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
32 | key_states = self.k_proj(hidden_states).view(
33 | bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
34 | value_states = self.v_proj(hidden_states).view(
35 | bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
36 | # [bsz, q_len, nh, hd]
37 | # [bsz, nh, q_len, hd]
38 |
39 | kv_seq_len = key_states.shape[-2]
40 | offset = 0
41 | if past_key_value is not None:
42 | offset = past_key_value[0].shape[-2]
43 | kv_seq_len += offset
44 | cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
45 | query_states, key_states = apply_rotary_pos_emb(query_states,
46 | key_states,
47 | cos,
48 | sin,
49 | offset=offset)
50 | # [bsz, nh, t, hd]
51 | assert not output_attentions, "output_attentions is not supported"
52 | assert not use_cache, "use_cache is not supported"
53 | assert past_key_value is None, "past_key_value is not supported"
54 |
55 | # Flash attention codes from
56 | # https://github.com/HazyResearch/flash-attention/blob/main/flash_attn/flash_attention.py
57 |
58 | # transform the data into the format required by flash attention
59 | qkv = torch.stack([query_states, key_states, value_states], dim=2) # [bsz, nh, 3, q_len, hd]
60 | qkv = qkv.transpose(1, 3) # [bsz, q_len, 3, nh, hd]
61 | # We have disabled _prepare_decoder_attention_mask in LlamaModel
62 | # the attention_mask should be the same as the key_padding_mask
63 | key_padding_mask = attention_mask
64 |
65 |
66 | if key_padding_mask is None:
67 | qkv = rearrange(qkv, 'b s ... -> (b s) ...')
68 | max_s = q_len
69 | cu_q_lens = torch.arange(0, (bsz + 1) * q_len, step=q_len, dtype=torch.int32,
70 | device=qkv.device)
71 | output = flash_attn_unpadded_qkvpacked_func(
72 | qkv, cu_q_lens, max_s, 0.0,
73 | softmax_scale=None, causal=True
74 | )
75 | output = rearrange(output, '(b s) ... -> b s ...', b=bsz)
76 | else:
77 | nheads = qkv.shape[-2]
78 | x = rearrange(qkv, 'b s three h d -> b s (three h d)')
79 | x_unpad, indices, cu_q_lens, max_s = unpad_input(x, key_padding_mask)
80 | x_unpad = rearrange(x_unpad, 'nnz (three h d) -> nnz three h d', three=3, h=nheads)
81 | output_unpad = flash_attn_unpadded_qkvpacked_func(
82 | x_unpad, cu_q_lens, max_s, 0.0,
83 | softmax_scale=None, causal=True
84 | )
85 | output = rearrange(pad_input(rearrange(output_unpad, 'nnz h d -> nnz (h d)'),
86 | indices, bsz, q_len),
87 | 'b s (h d) -> b s h d', h=nheads)
88 | return self.o_proj(rearrange(output,
89 | 'b s h d -> b s (h d)')), None, None
90 |
91 |
92 | # Disable the transformation of the attention mask in LlamaModel as the flash attention
93 | # requires the attention mask to be the same as the key_padding_mask
94 | def _prepare_decoder_attention_mask(self, attention_mask, input_shape,
95 | inputs_embeds, past_key_values_length):
96 | # [bsz, seq_len]
97 | return attention_mask
98 |
99 |
100 | def replace_llama_attn_with_flash_attn():
101 | transformers.models.llama.modeling_llama.LlamaModel._prepare_decoder_attention_mask = _prepare_decoder_attention_mask
102 | transformers.models.llama.modeling_llama.LlamaAttention.forward = forward
103 |
--------------------------------------------------------------------------------
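A minimal sketch of how this monkey patch is applied, assuming flash-attn is installed and using a placeholder checkpoint name: calling the function once at startup rebinds `LlamaAttention.forward` and `LlamaModel._prepare_decoder_attention_mask` at the class level, so any LLaMA model built afterwards takes the flash-attention path.

```python
from transformers import AutoModelForCausalLM

from vary.utils.llama_flash_attn_monkey_patch import replace_llama_attn_with_flash_attn

# Rebind the attention forward at the class level before building the model.
replace_llama_attn_with_flash_attn()

model = AutoModelForCausalLM.from_pretrained(
    "path/to/llama-checkpoint",  # placeholder path
    torch_dtype="auto",
)
# Note: the patched forward asserts use_cache=False and past_key_value=None,
# so it is meant for training, not for incremental decoding.
```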
/Vary-master/vary/utils/utils.py:
--------------------------------------------------------------------------------
1 | import datetime
2 | import logging
3 | import logging.handlers
4 | import os
5 | import sys
6 | import torch
7 | import requests
8 |
9 | from transformers import StoppingCriteria
10 | from vary.utils.constants import LOGDIR
11 |
12 | server_error_msg = "**NETWORK ERROR DUE TO HIGH TRAFFIC. PLEASE REGENERATE OR REFRESH THIS PAGE.**"
13 | moderation_msg = "YOUR INPUT VIOLATES OUR CONTENT MODERATION GUIDELINES. PLEASE TRY AGAIN."
14 |
15 | handler = None
16 |
17 |
18 | def build_logger(logger_name, logger_filename):
19 | global handler
20 |
21 | formatter = logging.Formatter(
22 | fmt="%(asctime)s | %(levelname)s | %(name)s | %(message)s",
23 | datefmt="%Y-%m-%d %H:%M:%S",
24 | )
25 |
26 | # Set the format of root handlers
27 | if not logging.getLogger().handlers:
28 | logging.basicConfig(level=logging.INFO)
29 | logging.getLogger().handlers[0].setFormatter(formatter)
30 |
31 | # Redirect stdout and stderr to loggers
32 | stdout_logger = logging.getLogger("stdout")
33 | stdout_logger.setLevel(logging.INFO)
34 | sl = StreamToLogger(stdout_logger, logging.INFO)
35 | sys.stdout = sl
36 |
37 | stderr_logger = logging.getLogger("stderr")
38 | stderr_logger.setLevel(logging.ERROR)
39 | sl = StreamToLogger(stderr_logger, logging.ERROR)
40 | sys.stderr = sl
41 |
42 | # Get logger
43 | logger = logging.getLogger(logger_name)
44 | logger.setLevel(logging.INFO)
45 |
46 | # Add a file handler for all loggers
47 | if handler is None:
48 | os.makedirs(LOGDIR, exist_ok=True)
49 | filename = os.path.join(LOGDIR, logger_filename)
50 | handler = logging.handlers.TimedRotatingFileHandler(
51 | filename, when='D', utc=True)
52 | handler.setFormatter(formatter)
53 |
54 | for name, item in logging.root.manager.loggerDict.items():
55 | if isinstance(item, logging.Logger):
56 | item.addHandler(handler)
57 |
58 | return logger
59 |
60 |
61 | class StreamToLogger(object):
62 | """
63 | Fake file-like stream object that redirects writes to a logger instance.
64 | """
65 | def __init__(self, logger, log_level=logging.INFO):
66 | self.terminal = sys.stdout
67 | self.logger = logger
68 | self.log_level = log_level
69 | self.linebuf = ''
70 |
71 | def __getattr__(self, attr):
72 | return getattr(self.terminal, attr)
73 |
74 | def write(self, buf):
75 | temp_linebuf = self.linebuf + buf
76 | self.linebuf = ''
77 | for line in temp_linebuf.splitlines(True):
78 | # From the io.TextIOWrapper docs:
79 | # On output, if newline is None, any '\n' characters written
80 | # are translated to the system default line separator.
81 | # By default sys.stdout.write() expects '\n' newlines and then
82 | # translates them so this is still cross platform.
83 | if line[-1] == '\n':
84 | self.logger.log(self.log_level, line.rstrip())
85 | else:
86 | self.linebuf += line
87 |
88 | def flush(self):
89 | if self.linebuf != '':
90 | self.logger.log(self.log_level, self.linebuf.rstrip())
91 | self.linebuf = ''
92 |
93 |
94 | def disable_torch_init():
95 | """
96 | Disable the redundant torch default initialization to accelerate model creation.
97 | """
98 | import torch
99 | setattr(torch.nn.Linear, "reset_parameters", lambda self: None)
100 | setattr(torch.nn.LayerNorm, "reset_parameters", lambda self: None)
101 |
102 |
103 | def violates_moderation(text):
104 | """
105 | Check whether the text violates OpenAI moderation API.
106 | """
107 | url = "https://api.openai.com/v1/moderations"
108 | headers = {"Content-Type": "application/json",
109 | "Authorization": "Bearer " + os.environ["OPENAI_API_KEY"]}
110 | text = text.replace("\n", "")
111 | data = "{" + '"input": ' + f'"{text}"' + "}"
112 | data = data.encode("utf-8")
113 | try:
114 | ret = requests.post(url, headers=headers, data=data, timeout=5)
115 | flagged = ret.json()["results"][0]["flagged"]
116 | except requests.exceptions.RequestException as e:
117 | flagged = False
118 | except KeyError as e:
119 | flagged = False
120 |
121 | return flagged
122 |
123 |
124 | def pretty_print_semaphore(semaphore):
125 | if semaphore is None:
126 | return "None"
127 | return f"Semaphore(value={semaphore._value}, locked={semaphore.locked()})"
128 |
129 |
130 | class KeywordsStoppingCriteria(StoppingCriteria):
131 | def __init__(self, keywords, tokenizer, input_ids):
132 | self.keywords = keywords
133 | self.keyword_ids = [tokenizer(keyword).input_ids for keyword in keywords]
134 | self.keyword_ids = [keyword_id[0] for keyword_id in self.keyword_ids if type(keyword_id) is list and len(keyword_id) == 1]
135 | self.tokenizer = tokenizer
136 | self.start_len = None
137 | self.input_ids = input_ids
138 |
139 | def __call__(self, output_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
140 | if self.start_len is None:
141 | self.start_len = self.input_ids.shape[1]
142 | else:
143 | for keyword_id in self.keyword_ids:
144 | if output_ids[0, -1] == keyword_id:
145 | return True
146 | outputs = self.tokenizer.batch_decode(output_ids[:, self.start_len:], skip_special_tokens=True)[0]
147 | for keyword in self.keywords:
148 | if keyword in outputs:
149 | return True
150 | return False
151 |
152 |
153 | def smart_tokenizer_and_embedding_resize(special_tokens_dict, tokenizer, model):
154 | """Resize tokenizer and embedding.
155 |
156 | Note: This is the unoptimized version that may make your embedding size not be divisible by 64.
157 | """
158 | # num_new_tokens = tokenizer.add_special_tokens(special_tokens_dict)
159 | # # num_new_tokens = 1
160 | # # tokenizer.add_tokens(special_tokens_dict, special_tokens=True)
161 | # model.resize_token_embeddings(len(tokenizer))
162 |
163 | num_new_tokens = tokenizer.add_special_tokens(special_tokens_dict)
164 | model.resize_token_embeddings(len(tokenizer))
165 |
166 | if num_new_tokens > 0:
167 | input_embeddings = model.get_input_embeddings().weight.data
168 | output_embeddings = model.get_output_embeddings().weight.data
169 |
170 | input_embeddings_avg = input_embeddings[:-num_new_tokens].mean(
171 | dim=0, keepdim=True)
172 | output_embeddings_avg = output_embeddings[:-num_new_tokens].mean(
173 | dim=0, keepdim=True)
174 |
175 | input_embeddings[-num_new_tokens:] = input_embeddings_avg
176 | output_embeddings[-num_new_tokens:] = output_embeddings_avg
177 |
178 |
179 | def maybe_zero_3(param, ignore_status=False, name=None):
180 | from deepspeed import zero
181 | from deepspeed.runtime.zero.partition_parameters import ZeroParamStatus
182 | if hasattr(param, "ds_id"):
183 | if param.ds_status == ZeroParamStatus.NOT_AVAILABLE:
184 | if not ignore_status:
185 | logging.warning(f"{name}: param.ds_status != ZeroParamStatus.NOT_AVAILABLE: {param.ds_status}")
186 | with zero.GatheredParameters([param]):
187 | param = param.data.detach().cpu().clone()
188 | else:
189 | param = param.detach().cpu().clone()
190 | return param
191 |
192 |
193 | # Borrowed from peft.utils.get_peft_model_state_dict
194 | def get_peft_state_maybe_zero_3(named_params, bias):
195 | if bias == "none":
196 | to_return = {k: t for k, t in named_params if "lora_" in k}
197 | elif bias == "all":
198 | to_return = {k: t for k, t in named_params if "lora_" in k or "bias" in k}
199 | elif bias == "lora_only":
200 | to_return = {}
201 | maybe_lora_bias = {}
202 | lora_bias_names = set()
203 | for k, t in named_params:
204 | if "lora_" in k:
205 | to_return[k] = t
206 | bias_name = k.split("lora_")[0] + "bias"
207 | lora_bias_names.add(bias_name)
208 | elif "bias" in k:
209 | maybe_lora_bias[k] = t
210 |         for k, t in maybe_lora_bias.items():
211 |             if k in lora_bias_names:
212 |                 to_return[k] = t
213 | else:
214 | raise NotImplementedError
215 | to_return = {k: maybe_zero_3(v, name=k) for k, v in to_return.items()}
216 | return to_return
217 |
218 |
219 | def get_peft_state_non_lora_maybe_zero_3(named_params, require_grad_only=True):
220 | to_return = {k: t for k, t in named_params if "lora_" not in k}
221 | if require_grad_only:
222 | to_return = {k: t for k, t in to_return.items() if t.requires_grad}
223 | to_return = {k: maybe_zero_3(v, ignore_status=True).cpu() for k, v in to_return.items()}
224 | return to_return
225 |
226 |
227 | def find_all_linear_names(model):
228 | cls = torch.nn.Linear
229 | lora_module_names = set()
230 | for name, module in model.named_modules():
231 |         if isinstance(module, cls) and 'vision_model' not in name and 'mm_projector' not in name and 'vision_encoder' not in name and 'conv_final' not in name and 'lm_head' not in name:
232 | lora_module_names.add(name)
233 |
234 | print(lora_module_names)
235 | return list(lora_module_names)
--------------------------------------------------------------------------------
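Of the helpers above, `KeywordsStoppingCriteria` is the one inference code typically touches directly. A minimal sketch, using a placeholder checkpoint path and `<|im_end|>` (the conv_mpt separator) as the stop keyword; `disable_torch_init` is called before loading to skip redundant default weight initialization:

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, StoppingCriteriaList

from vary.utils.utils import KeywordsStoppingCriteria, disable_torch_init

# Placeholder checkpoint path; any causal LM / tokenizer pair works for this sketch.
disable_torch_init()
tokenizer = AutoTokenizer.from_pretrained("path/to/checkpoint", trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained("path/to/checkpoint", trust_remote_code=True)

input_ids = tokenizer("Describe the image.", return_tensors="pt").input_ids
stop = KeywordsStoppingCriteria(["<|im_end|>"], tokenizer, input_ids)

with torch.inference_mode():
    output_ids = model.generate(
        input_ids,
        max_new_tokens=512,
        stopping_criteria=StoppingCriteriaList([stop]),
    )

# Decode only the newly generated tokens.
text = tokenizer.decode(output_ids[0, input_ids.shape[1]:], skip_special_tokens=True)
```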
/Vary-master/zero_config/zero2.json:
--------------------------------------------------------------------------------
1 | {
2 | "bf16": {
3 | "enabled": true
4 | },
5 | "train_micro_batch_size_per_gpu": "auto",
6 | "zero_optimization": {
7 | "stage": 2,
8 | "overlap_comm": true,
9 | "contiguous_gradients": true,
10 | "sub_group_size": 1e9,
11 | "reduce_bucket_size": "auto"
12 | }
13 | }
--------------------------------------------------------------------------------
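A minimal sketch of how a ZeRO-2 config like this is wired in with the Hugging Face Trainer (the hyperparameter values below are placeholders, not the repo's recommended settings): the JSON path goes into `TrainingArguments.deepspeed`, and the `"auto"` fields such as `train_micro_batch_size_per_gpu` and `reduce_bucket_size` are resolved from the trainer's arguments and the model config when DeepSpeed initializes.

```python
from transformers import TrainingArguments

training_args = TrainingArguments(
    output_dir="./checkpoints/vary",       # placeholder
    per_device_train_batch_size=4,         # fills train_micro_batch_size_per_gpu="auto"
    learning_rate=2e-5,                    # placeholder
    bf16=True,                             # must agree with "bf16": {"enabled": true}
    deepspeed="zero_config/zero2.json",    # this config file
)
```

In this repo, arguments like these would be consumed by the training entry points under `Vary-master/vary/train/`; the exact launch command is not shown here.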
/assets/vary-toy-logo.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Ucas-HaoranWei/Vary-toy/c6e405977aa6e88d4807a6b79f394299623fc6f3/assets/vary-toy-logo.jpg
--------------------------------------------------------------------------------