├── LICENSE
├── README.md
├── cerule
│   ├── constants.py
│   ├── conversation.py
│   ├── model
│   │   ├── __init__.py
│   │   ├── builder.py
│   │   ├── cerule_arch.py
│   │   ├── language_model
│   │   │   ├── cerule_gemma.py
│   │   │   └── gemma
│   │   │       ├── __init__.py
│   │   │       ├── configuration_gemma.py
│   │   │       └── modeling_gemma.py
│   │   ├── multimodal_encoder
│   │   │   ├── builder.py
│   │   │   ├── clip
│   │   │   │   └── clip_encoder.py
│   │   │   ├── eva_clip
│   │   │   │   ├── eva_clip_encoder.py
│   │   │   │   ├── eva_clip_processors.py
│   │   │   │   └── eva_vit.py
│   │   │   └── siglip
│   │   │       └── siglip_encoder.py
│   │   └── multimodal_projector
│   │       └── builder.py
│   ├── serve
│   │   ├── cli.py
│   │   ├── controller.py
│   │   ├── examples
│   │   │   ├── example_1.png
│   │   │   └── example_2.png
│   │   ├── gradio_web_server.py
│   │   ├── model_worker.py
│   │   └── register_worker.py
│   ├── train
│   │   ├── cerule_trainer.py
│   │   └── train.py
│   └── util
│       ├── data_utils.py
│       ├── mm_utils.py
│       └── utils.py
├── examples
│   ├── YHyRn8r.png
│   ├── astronaut.png
│   ├── bridge.png
│   ├── design.jpg
│   ├── extreme_ironing.jpg
│   ├── google.png
│   ├── graph.jpg
│   ├── graph1.jpg
│   ├── image.png
│   ├── mario.png
│   └── sting.png
├── pyproject.toml
└── script
    ├── deepspeed
    │   ├── scripts
    │   │   └── zero3_offload.json
    │   ├── zero2.json
    │   └── zero3.json
    ├── merge_lora_weights.py
    └── train
        ├── finetune_full.sh
        ├── finetune_lora.sh
        └── pretrain.sh
/LICENSE:
--------------------------------------------------------------------------------
1 | Apache License
2 | Version 2.0, January 2004
3 | http://www.apache.org/licenses/
4 |
5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
6 |
7 | 1. Definitions.
8 |
9 | "License" shall mean the terms and conditions for use, reproduction,
10 | and distribution as defined by Sections 1 through 9 of this document.
11 |
12 | "Licensor" shall mean the copyright owner or entity authorized by
13 | the copyright owner that is granting the License.
14 |
15 | "Legal Entity" shall mean the union of the acting entity and all
16 | other entities that control, are controlled by, or are under common
17 | control with that entity. For the purposes of this definition,
18 | "control" means (i) the power, direct or indirect, to cause the
19 | direction or management of such entity, whether by contract or
20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the
21 | outstanding shares, or (iii) beneficial ownership of such entity.
22 |
23 | "You" (or "Your") shall mean an individual or Legal Entity
24 | exercising permissions granted by this License.
25 |
26 | "Source" form shall mean the preferred form for making modifications,
27 | including but not limited to software source code, documentation
28 | source, and configuration files.
29 |
30 | "Object" form shall mean any form resulting from mechanical
31 | transformation or translation of a Source form, including but
32 | not limited to compiled object code, generated documentation,
33 | and conversions to other media types.
34 |
35 | "Work" shall mean the work of authorship, whether in Source or
36 | Object form, made available under the License, as indicated by a
37 | copyright notice that is included in or attached to the work
38 | (an example is provided in the Appendix below).
39 |
40 | "Derivative Works" shall mean any work, whether in Source or Object
41 | form, that is based on (or derived from) the Work and for which the
42 | editorial revisions, annotations, elaborations, or other modifications
43 | represent, as a whole, an original work of authorship. For the purposes
44 | of this License, Derivative Works shall not include works that remain
45 | separable from, or merely link (or bind by name) to the interfaces of,
46 | the Work and Derivative Works thereof.
47 |
48 | "Contribution" shall mean any work of authorship, including
49 | the original version of the Work and any modifications or additions
50 | to that Work or Derivative Works thereof, that is intentionally
51 | submitted to Licensor for inclusion in the Work by the copyright owner
52 | or by an individual or Legal Entity authorized to submit on behalf of
53 | the copyright owner. For the purposes of this definition, "submitted"
54 | means any form of electronic, verbal, or written communication sent
55 | to the Licensor or its representatives, including but not limited to
56 | communication on electronic mailing lists, source code control systems,
57 | and issue tracking systems that are managed by, or on behalf of, the
58 | Licensor for the purpose of discussing and improving the Work, but
59 | excluding communication that is conspicuously marked or otherwise
60 | designated in writing by the copyright owner as "Not a Contribution."
61 |
62 | "Contributor" shall mean Licensor and any individual or Legal Entity
63 | on behalf of whom a Contribution has been received by Licensor and
64 | subsequently incorporated within the Work.
65 |
66 | 2. Grant of Copyright License. Subject to the terms and conditions of
67 | this License, each Contributor hereby grants to You a perpetual,
68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
69 | copyright license to reproduce, prepare Derivative Works of,
70 | publicly display, publicly perform, sublicense, and distribute the
71 | Work and such Derivative Works in Source or Object form.
72 |
73 | 3. Grant of Patent License. Subject to the terms and conditions of
74 | this License, each Contributor hereby grants to You a perpetual,
75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
76 | (except as stated in this section) patent license to make, have made,
77 | use, offer to sell, sell, import, and otherwise transfer the Work,
78 | where such license applies only to those patent claims licensable
79 | by such Contributor that are necessarily infringed by their
80 | Contribution(s) alone or by combination of their Contribution(s)
81 | with the Work to which such Contribution(s) was submitted. If You
82 | institute patent litigation against any entity (including a
83 | cross-claim or counterclaim in a lawsuit) alleging that the Work
84 | or a Contribution incorporated within the Work constitutes direct
85 | or contributory patent infringement, then any patent licenses
86 | granted to You under this License for that Work shall terminate
87 | as of the date such litigation is filed.
88 |
89 | 4. Redistribution. You may reproduce and distribute copies of the
90 | Work or Derivative Works thereof in any medium, with or without
91 | modifications, and in Source or Object form, provided that You
92 | meet the following conditions:
93 |
94 | (a) You must give any other recipients of the Work or
95 | Derivative Works a copy of this License; and
96 |
97 | (b) You must cause any modified files to carry prominent notices
98 | stating that You changed the files; and
99 |
100 | (c) You must retain, in the Source form of any Derivative Works
101 | that You distribute, all copyright, patent, trademark, and
102 | attribution notices from the Source form of the Work,
103 | excluding those notices that do not pertain to any part of
104 | the Derivative Works; and
105 |
106 | (d) If the Work includes a "NOTICE" text file as part of its
107 | distribution, then any Derivative Works that You distribute must
108 | include a readable copy of the attribution notices contained
109 | within such NOTICE file, excluding those notices that do not
110 | pertain to any part of the Derivative Works, in at least one
111 | of the following places: within a NOTICE text file distributed
112 | as part of the Derivative Works; within the Source form or
113 | documentation, if provided along with the Derivative Works; or,
114 | within a display generated by the Derivative Works, if and
115 | wherever such third-party notices normally appear. The contents
116 | of the NOTICE file are for informational purposes only and
117 | do not modify the License. You may add Your own attribution
118 | notices within Derivative Works that You distribute, alongside
119 | or as an addendum to the NOTICE text from the Work, provided
120 | that such additional attribution notices cannot be construed
121 | as modifying the License.
122 |
123 | You may add Your own copyright statement to Your modifications and
124 | may provide additional or different license terms and conditions
125 | for use, reproduction, or distribution of Your modifications, or
126 | for any such Derivative Works as a whole, provided Your use,
127 | reproduction, and distribution of the Work otherwise complies with
128 | the conditions stated in this License.
129 |
130 | 5. Submission of Contributions. Unless You explicitly state otherwise,
131 | any Contribution intentionally submitted for inclusion in the Work
132 | by You to the Licensor shall be under the terms and conditions of
133 | this License, without any additional terms or conditions.
134 | Notwithstanding the above, nothing herein shall supersede or modify
135 | the terms of any separate license agreement you may have executed
136 | with Licensor regarding such Contributions.
137 |
138 | 6. Trademarks. This License does not grant permission to use the trade
139 | names, trademarks, service marks, or product names of the Licensor,
140 | except as required for reasonable and customary use in describing the
141 | origin of the Work and reproducing the content of the NOTICE file.
142 |
143 | 7. Disclaimer of Warranty. Unless required by applicable law or
144 | agreed to in writing, Licensor provides the Work (and each
145 | Contributor provides its Contributions) on an "AS IS" BASIS,
146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
147 | implied, including, without limitation, any warranties or conditions
148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
149 | PARTICULAR PURPOSE. You are solely responsible for determining the
150 | appropriateness of using or redistributing the Work and assume any
151 | risks associated with Your exercise of permissions under this License.
152 |
153 | 8. Limitation of Liability. In no event and under no legal theory,
154 | whether in tort (including negligence), contract, or otherwise,
155 | unless required by applicable law (such as deliberate and grossly
156 | negligent acts) or agreed to in writing, shall any Contributor be
157 | liable to You for damages, including any direct, indirect, special,
158 | incidental, or consequential damages of any character arising as a
159 | result of this License or out of the use or inability to use the
160 | Work (including but not limited to damages for loss of goodwill,
161 | work stoppage, computer failure or malfunction, or any and all
162 | other commercial damages or losses), even if such Contributor
163 | has been advised of the possibility of such damages.
164 |
165 | 9. Accepting Warranty or Additional Liability. While redistributing
166 | the Work or Derivative Works thereof, You may choose to offer,
167 | and charge a fee for, acceptance of support, warranty, indemnity,
168 | or other liability obligations and/or rights consistent with this
169 | License. However, in accepting such obligations, You may act only
170 | on Your own behalf and on Your sole responsibility, not on behalf
171 | of any other Contributor, and only if You agree to indemnify,
172 | defend, and hold each Contributor harmless for any liability
173 | incurred by, or claims asserted against, such Contributor by reason
174 | of your accepting any such warranty or additional liability.
175 |
176 | END OF TERMS AND CONDITIONS
177 |
178 | APPENDIX: How to apply the Apache License to your work.
179 |
180 | To apply the Apache License to your work, attach the following
181 | boilerplate notice, with the fields enclosed by brackets "[]"
182 | replaced with your own identifying information. (Don't include
183 | the brackets!) The text should be enclosed in the appropriate
184 | comment syntax for the file format. We also recommend that a
185 | file or class name and description of purpose be included on the
186 | same "printed page" as the copyright notice for easier
187 | identification within third-party archives.
188 |
189 | Copyright [yyyy] [name of copyright owner]
190 |
191 | Licensed under the Apache License, Version 2.0 (the "License");
192 | you may not use this file except in compliance with the License.
193 | You may obtain a copy of the License at
194 |
195 | http://www.apache.org/licenses/LICENSE-2.0
196 |
197 | Unless required by applicable law or agreed to in writing, software
198 | distributed under the License is distributed on an "AS IS" BASIS,
199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
200 | See the License for the specific language governing permissions and
201 | limitations under the License.
202 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | ---
2 | license: apache-2.0
3 | language:
4 | - en
5 | pipeline_tag: image-text-to-text
6 | ---
7 |
8 | # Cerule - A TinyMightyVisionModel
9 | ### Based on Google's Gemma-2b + SigLIP
10 |
11 |
12 |
13 | ```
14 | ██████╗███████╗██████╗ ██╗ ██╗██╗ ███████╗
15 | ██╔════╝██╔════╝██╔══██╗██║ ██║██║ ██╔════╝
16 | ██║ █████╗ ██████╔╝██║ ██║██║ █████╗
17 | ██║ ██╔══╝ ██╔══██╗██║ ██║██║ ██╔══╝
18 | ╚██████╗███████╗██║ ██║╚██████╔╝███████╗███████╗
19 | ╚═════╝╚══════╝╚═╝ ╚═╝ ╚═════╝ ╚══════╝╚══════╝
20 | ```
21 |
22 |
23 |
24 |
25 |
26 | We train and release "Cerule", a tiny yet powerful Vision Language Model based on Google's newly released [Gemma-2b](https://huggingface.co/google/gemma-2b) and Google's [SigLIP](https://huggingface.co/google/siglip-so400m-patch14-384).
27 |
28 |
29 | Training was done on `4x A100 (80GB)` GPUs and took ~6 hours to pretrain and ~13 hours to finetune. We modify and adapt the training code from [Bunny](https://github.com/BAAI-DCAI/Bunny).
30 |
31 | ---
32 | | Image | Example |
33 | |-------|---------|
34 | |  | **Describe the image** The image is a playful and surreal depiction of a man in a space suit, sitting on a chair and holding a green beer bottle. The man is wearing a white space suit, complete with a helmet and gloves. His feet are clad in black and white shoes, and he is placed on a sandy surface. The background features a large, blue planet, with a moon and a star visible in the sky. |
35 | |  | **Who are the characters in the image?** The image features three characters, two of them are Mario and Luigi, and the third one is Yoshi. <br> **Describe the actions of the characters** The Mario and Luigi characters are holding their arms out, as if they are waving. Yoshi is standing on its own, with its arms folded. |
36 | |  | **What's funny about this image?** The image is quite humorous as it depicts a man ironing clothes on the back of a yellow taxi cab. This is not a typical sight you'd expect to see in everyday life. |
37 | ---
38 |
39 |
40 | ## Training
41 |
42 | Before running the training, you need to install the following dependencies:
43 |
44 | * Create a conda env:
45 | ```
46 | conda create -n cerule python=3.10
47 | conda activate cerule
48 | ```
49 | * Basic requirements
50 | ```
51 | pip install --upgrade pip
52 | pip install transformers
53 | pip install torch torchvision xformers --index-url https://download.pytorch.org/whl/cu118
54 | ```
55 |
56 | * Install Apex. Please install it from source, as the package on PyPI is unrelated to this project.
57 | ```
58 | pip install ninja
59 | git clone https://github.com/NVIDIA/apex
60 | cd apex
61 | pip install -v --disable-pip-version-check --no-cache-dir --no-build-isolation --global-option="--cpp_ext" --global-option="--cuda_ext" ./
62 | # https://github.com/Dao-AILab/flash-attention?tab=readme-ov-file#installation-and-features
63 | ```
64 | * Install flash-attention
65 | ```
66 | pip install packaging
67 | pip install flash-attn --no-build-isolation
68 | ```
69 | * Install Cerule and other requirements
70 | ```
71 | git clone https://github.com/Tensoic-AI/Cerule
72 | cd Cerule
73 | pip install -e .
74 | ```
75 |
76 | ### Pretrain
77 |
78 | * Data preparation
79 | We use the following dataset, prepared by the amazing folks at the [Beijing Academy of Artificial Intelligence](https://huggingface.co/BAAI).
80 | The dataset is available [here](https://www.modelscope.cn/datasets/BoyaWu10/Bunny-v1.0-data).
81 |
82 | Pretrain Dataset format:
83 | ```
84 | {
85 | "conversations": [
86 | {
87 | "from": "human",
88 | "value": "\nProvide a brief description of the given image."
89 | },
90 | {
91 | "from": "gpt",
92 | "value": "A set of three chrome and bubble glass table lamp bases. H.50cm - Image 4 of 10"
93 | }
94 | ],
95 | "id": "0006418798",
96 | "image": "0006418798.jpg"
97 | },
98 | ```
99 |
100 | * Run
101 |
102 | Update `--model_name_or_path` and `--vision_tower` to the paths of the LLM and vision encoder, respectively. Update `MODEL_TYPE` and `OUTPUT_DIR` accordingly.
103 |
104 | ```shell
105 | sh script/train/pretrain.sh
106 | ```
107 |
108 | ### Visual Instruction Tuning
109 |
110 | * Data preparation
111 |
112 | For finetuning, we also utilize Bunny-695K, a modified version of [SVIT-mix-665K](https://arxiv.org/abs/2307.04087) prepared by BAAI.
113 | The dataset is available [here](https://www.modelscope.cn/datasets/BoyaWu10/Bunny-v1.0-data).
114 |
115 | * Run
116 |
117 | Update `--model_name_or_path` and `--vision_tower` to the paths of the LLM and vision encoder, respectively. Update `MODEL_TYPE`, `PRETRAIN_DIR` and `OUTPUT_DIR` accordingly. The global batch size is 128.
118 |
119 | ```shell
120 | # full-parameter tuning
121 | sh script/train/finetune_full.sh
122 |
123 | # LoRA tuning
124 | sh script/train/finetune_lora.sh
125 | ```
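
After LoRA tuning, the adapter has to be merged into the base model before serving; the repository ships `script/merge_lora_weights.py` for this, and `cerule/model/builder.py` performs the same merge on the fly when it detects a LoRA checkpoint. Below is a minimal sketch of the latter path; the checkpoint paths are placeholders and should be adjusted to your `OUTPUT_DIR` and base model.

```python
from cerule.model.builder import load_pretrained_model

# Sketch only: the paths below are placeholders.
# Because model_name contains "lora" and model_base is set, the builder loads
# the base Gemma weights, restores the non-LoRA trainables, applies the LoRA
# adapter with PEFT, and merges it via merge_and_unload().
tokenizer, model, image_processor, context_len = load_pretrained_model(
    model_path="checkpoints/cerule-gemma-2b-lora",  # LoRA finetune output (placeholder)
    model_base="google/gemma-2b",                   # base LLM used during finetuning
    model_name="cerule-gemma-2b-lora",              # must contain "lora" to trigger the merge
    model_type="gemma",
)
```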
126 |
127 |
128 | ## Inference
129 | #### For CLI-based inference:
130 | ```
131 | python3 -m cerule.serve.cli \
132 | --model-path Tensoic/Cerule-v0.1 \
133 | --image-file examples/astronaut.png
134 | ```
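
The checkpoint can also be loaded programmatically with `load_pretrained_model` from `cerule/model/builder.py`. A minimal sketch, assuming `Tensoic/Cerule-v0.1` is a full (non-LoRA) checkpoint; the complete image preprocessing and generation loop lives in `cerule/serve/cli.py`:

```python
from cerule.model.builder import load_pretrained_model

# Sketch only: loads the language model, the vision tower, and the matching
# image processor in float16; see cerule/serve/cli.py for prompt construction
# and the generation loop.
tokenizer, model, image_processor, context_len = load_pretrained_model(
    model_path="Tensoic/Cerule-v0.1",
    model_base=None,
    model_name="cerule-v0.1",   # any name without "lora" loads the checkpoint directly
    model_type="gemma",
    device="cuda",
)
```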
135 |
136 | ## License
137 | Licensed under the Apache License, Version 2.0 http://www.apache.org/licenses/LICENSE-2.0. This file may not be copied, modified, or distributed except according to those terms.
138 |
139 | ## Contribution
140 | Unless you explicitly state otherwise, any contribution intentionally submitted for inclusion in the work by you, as defined in the Apache-2.0 license, shall be licensed as above, without any additional terms or conditions.
141 |
142 | ## Acknowledgements
143 | We sincerely thank the amazing teams at Google, LLaVA, and BAAI, without whom this project would not have been possible!
144 |
145 | ## Star History
146 |
--------------------------------------------------------------------------------
/cerule/constants.py:
--------------------------------------------------------------------------------
1 | # Model Constants
2 | IGNORE_INDEX = -100
3 | IMAGE_TOKEN_INDEX = -200
4 | DEFAULT_IMAGE_TOKEN = "<image>"
5 | CONTROLLER_HEART_BEAT_EXPIRATION = 30
6 | LOGDIR = "gradio-logs"
7 | WORKER_HEART_BEAT_INTERVAL = 15
8 | IGNORE_INDEX = -100
9 | IMAGE_TOKEN_INDEX = -200
10 | DEFAULT_IMAGE_TOKEN = "<image>"
11 | DEFAULT_IMAGE_PATCH_TOKEN = "<im_patch>"
12 | DEFAULT_IM_START_TOKEN = "<im_start>"
13 | DEFAULT_IM_END_TOKEN = "<im_end>"
14 | IMAGE_PLACEHOLDER = "<image-placeholder>"
--------------------------------------------------------------------------------
/cerule/conversation.py:
--------------------------------------------------------------------------------
1 | import dataclasses
2 | from enum import auto, Enum
3 | from typing import List, Tuple
4 | import base64
5 | from io import BytesIO
6 | from PIL import Image
7 |
8 |
9 | class SeparatorStyle(Enum):
10 | """Different separator style."""
11 | TWO = auto()
12 | PLAIN = auto()
13 | GEMMA = auto()
14 |
15 |
16 | @dataclasses.dataclass
17 | class Conversation:
18 | """A class that keeps all conversation history."""
19 | system: str
20 | roles: List[str]
21 | messages: List[List[str]]
22 | offset: int
23 | sep_style: SeparatorStyle
24 | sep: str = "###"
25 | sep2: str = None
26 | version: str = "Unknown"
27 |
28 | skip_next: bool = False
29 |
30 | def get_prompt(self):
31 | messages = self.messages
32 | if len(messages) > 0 and type(messages[0][1]) is tuple:
33 | messages = self.messages.copy()
34 | init_role, init_msg = messages[0].copy()
35 | init_msg = init_msg[0].replace("<image>", "").strip()
36 | if 'mmtag' in self.version:
37 | messages[0] = (init_role, init_msg)
38 | messages.insert(0, (self.roles[0], "<Image><image></Image>"))
39 | messages.insert(1, (self.roles[1], "Received."))
40 | else:
41 | messages[0] = (init_role, "<image>\n" + init_msg)
42 |
43 | if self.sep_style == SeparatorStyle.TWO:
44 | seps = [self.sep, self.sep2]
45 | ret = self.system + seps[0]
46 | for i, (role, message) in enumerate(messages):
47 | if message:
48 | if type(message) is tuple:
49 | message, _, _ = message
50 | ret += role + ": " + message + seps[i % 2]
51 | else:
52 | ret += role + ":"
53 |
54 | elif self.sep_style == SeparatorStyle.PLAIN:
55 | seps = [self.sep, self.sep2]
56 | ret = self.system
57 | for i, (role, message) in enumerate(messages):
58 | if message:
59 | if type(message) is tuple:
60 | message, _, _ = message
61 | ret += message + seps[i % 2]
62 | else:
63 | ret += ""
64 | elif self.sep_style == SeparatorStyle.GEMMA:
65 | ret = ""
66 | for i, (role, message) in enumerate(messages):
67 | assert role == self.roles[i % 2], "Conversation should alternate user/assistant/user/assistant/..."
68 | if message:
69 | if type(message) is tuple:
70 | message, _, _ = message
71 | ret += role + message + self.sep
72 | else:
73 | ret += role
74 | else:
75 | raise ValueError(f"Invalid style: {self.sep_style}")
76 |
77 | return ret
78 |
79 | def append_message(self, role, message):
80 | self.messages.append([role, message])
81 |
82 | def get_images(self, return_pil=False):
83 | images = []
84 | for i, (role, msg) in enumerate(self.messages[self.offset:]):
85 | if i % 2 == 0:
86 | if type(msg) is tuple:
87 | import base64
88 | from io import BytesIO
89 | from PIL import Image
90 | msg, image, image_process_mode = msg
91 | if image_process_mode == "Pad":
92 | def expand2square(pil_img, background_color=(122, 116, 104)):
93 | width, height = pil_img.size
94 | if width == height:
95 | return pil_img
96 | elif width > height:
97 | result = Image.new(pil_img.mode, (width, width), background_color)
98 | result.paste(pil_img, (0, (width - height) // 2))
99 | return result
100 | else:
101 | result = Image.new(pil_img.mode, (height, height), background_color)
102 | result.paste(pil_img, ((height - width) // 2, 0))
103 | return result
104 |
105 | image = expand2square(image)
106 | elif image_process_mode in ["Default", "Crop"]:
107 | pass
108 | elif image_process_mode == "Resize":
109 | image = image.resize((336, 336))
110 | else:
111 | raise ValueError(f"Invalid image_process_mode: {image_process_mode}")
112 | max_hw, min_hw = max(image.size), min(image.size)
113 | aspect_ratio = max_hw / min_hw
114 | max_len, min_len = 800, 400
115 | shortest_edge = int(min(max_len / aspect_ratio, min_len, min_hw))
116 | longest_edge = int(shortest_edge * aspect_ratio)
117 | W, H = image.size
118 | if longest_edge != max(image.size):
119 | if H > W:
120 | H, W = longest_edge, shortest_edge
121 | else:
122 | H, W = shortest_edge, longest_edge
123 | image = image.resize((W, H))
124 | if return_pil:
125 | images.append(image)
126 | else:
127 | buffered = BytesIO()
128 | image.save(buffered, format="PNG")
129 | img_b64_str = base64.b64encode(buffered.getvalue()).decode()
130 | images.append(img_b64_str)
131 | return images
132 |
133 | def to_gradio_chatbot(self):
134 | ret = []
135 | for i, (role, msg) in enumerate(self.messages[self.offset:]):
136 | if i % 2 == 0:
137 | if type(msg) is tuple:
138 | import base64
139 | from io import BytesIO
140 | msg, image, image_process_mode = msg
141 | max_hw, min_hw = max(image.size), min(image.size)
142 | aspect_ratio = max_hw / min_hw
143 | max_len, min_len = 800, 400
144 | shortest_edge = int(min(max_len / aspect_ratio, min_len, min_hw))
145 | longest_edge = int(shortest_edge * aspect_ratio)
146 | W, H = image.size
147 | if H > W:
148 | H, W = longest_edge, shortest_edge
149 | else:
150 | H, W = shortest_edge, longest_edge
151 | image = image.resize((W, H))
152 | buffered = BytesIO()
153 | image.save(buffered, format="JPEG")
154 | img_b64_str = base64.b64encode(buffered.getvalue()).decode()
155 | img_str = f'<img src="data:image/jpeg;base64,{img_b64_str}" alt="user upload image" />'
156 | msg = img_str + msg.replace('<image>', '').strip()
157 | ret.append([msg, None])
158 | else:
159 | ret.append([msg, None])
160 | else:
161 | ret[-1][-1] = msg
162 | return ret
163 |
164 | def copy(self):
165 | return Conversation(
166 | system=self.system,
167 | roles=self.roles,
168 | messages=[[x, y] for x, y in self.messages],
169 | offset=self.offset,
170 | sep_style=self.sep_style,
171 | sep=self.sep,
172 | sep2=self.sep2,
173 | version=self.version)
174 |
175 | def dict(self):
176 | if len(self.get_images()) > 0:
177 | return {
178 | "system": self.system,
179 | "roles": self.roles,
180 | "messages": [[x, y[0] if type(y) is tuple else y] for x, y in self.messages],
181 | "offset": self.offset,
182 | "sep": self.sep,
183 | "sep2": self.sep2,
184 | }
185 | return {
186 | "system": self.system,
187 | "roles": self.roles,
188 | "messages": self.messages,
189 | "offset": self.offset,
190 | "sep": self.sep,
191 | "sep2": self.sep2,
192 | }
193 |
194 | conv_plain = Conversation(
195 | system="",
196 | roles=("", ""),
197 | messages=(
198 | ),
199 | offset=0,
200 | sep_style=SeparatorStyle.PLAIN,
201 | sep="\n",
202 | )
203 |
204 | conv_gemma_instruct = Conversation(
205 | system="",
206 | roles=("user\n", "model\n"),
207 | version="gemma",
208 | messages=(),
209 | offset=0,
210 | sep_style=SeparatorStyle.GEMMA,
211 | sep="\n"
212 | )
213 |
214 | default_conversation = conv_gemma_instruct
215 | conv_templates = {
216 | "default": conv_gemma_instruct,
217 | "plain": conv_plain,
218 | "gemma_instruct": conv_gemma_instruct,
219 | }
220 |
221 | if __name__ == "__main__":
222 | print(default_conversation.get_prompt())
223 |
--------------------------------------------------------------------------------
/cerule/model/__init__.py:
--------------------------------------------------------------------------------
1 | from .language_model.cerule_gemma import CeruleGemmaForCausalLM, CeruleGemmaConfig
2 |
--------------------------------------------------------------------------------
/cerule/model/builder.py:
--------------------------------------------------------------------------------
1 | import os
2 | import warnings
3 |
4 | warnings.filterwarnings("ignore")
5 |
6 | from transformers import AutoTokenizer, AutoConfig, BitsAndBytesConfig
7 | import torch
8 | from cerule.model import *
9 |
10 | # can add more models. just load, as done below for gemma
11 |
12 | def load_pretrained_model(model_path, model_base, model_name, model_type, load_8bit=False, load_4bit=False,
13 | device_map="auto", device="cuda", use_flash_attn=False, **kwargs):
14 | if model_type not in {'gemma'}:
15 | raise ValueError(f"Unknown Model Type {model_type}")
16 |
17 | kwargs = {"device_map": device_map, **kwargs}
18 |
19 | if device != "cuda":
20 | kwargs['device_map'] = {"": device}
21 |
22 | if load_8bit:
23 | kwargs['load_in_8bit'] = True
24 | elif load_4bit:
25 | kwargs['load_in_4bit'] = True
26 | kwargs['quantization_config'] = BitsAndBytesConfig(
27 | load_in_4bit=True,
28 | bnb_4bit_compute_dtype=torch.float16,
29 | bnb_4bit_use_double_quant=True,
30 | bnb_4bit_quant_type='nf4'
31 | )
32 | else:
33 | kwargs['torch_dtype'] = torch.float16
34 |
35 | # Load Cerule model
36 | if 'lora' in model_name.lower() and model_base is None:
37 | warnings.warn(
38 | 'There is `lora` in model name but no `model_base` is provided. If you are loading a LoRA model, please provide the `model_base` argument.')
39 | if 'lora' in model_name.lower() and model_base is not None:
40 | lora_cfg_pretrained = AutoConfig.from_pretrained(model_path)
41 |
42 | print('Loading Cerule from base model...')
43 | if model_type == 'gemma':
44 | tokenizer = AutoTokenizer.from_pretrained(model_base, use_fast=False, trust_remote_code=True)
45 | model = CeruleGemmaForCausalLM.from_pretrained(model_base, low_cpu_mem_usage=True,
46 | config=lora_cfg_pretrained, **kwargs)
47 |
48 | token_num, token_dim = model.lm_head.out_features, model.lm_head.in_features
49 | if model.lm_head.weight.shape[0] != token_num:
50 | model.lm_head.weight = torch.nn.Parameter(
51 | torch.empty(token_num, token_dim, device=model.device, dtype=model.dtype))
52 | model.model.embed_tokens.weight = torch.nn.Parameter(
53 | torch.empty(token_num, token_dim, device=model.device, dtype=model.dtype))
54 |
55 | print('Loading additional Cerule weights...')
56 | if os.path.exists(os.path.join(model_path, 'non_lora_trainables.bin')):
57 | non_lora_trainables = torch.load(os.path.join(model_path, 'non_lora_trainables.bin'), map_location='cpu')
58 | else:
59 | # this is probably from HF Hub
60 | from huggingface_hub import hf_hub_download
61 | def load_from_hf(repo_id, filename, subfolder=None):
62 | cache_file = hf_hub_download(
63 | repo_id=repo_id,
64 | filename=filename,
65 | subfolder=subfolder)
66 | return torch.load(cache_file, map_location='cpu')
67 |
68 | non_lora_trainables = load_from_hf(model_path, 'non_lora_trainables.bin')
69 |
70 | non_lora_trainables = {(k[11:] if k.startswith('base_model.') else k): v for k, v in
71 | non_lora_trainables.items()}
72 | if any(k.startswith('model.model.') for k in non_lora_trainables):
73 | non_lora_trainables = {(k[6:] if k.startswith('model.') else k): v for k, v in
74 | non_lora_trainables.items()}
75 | model.load_state_dict(non_lora_trainables, strict=False)
76 |
77 | from peft import PeftModel
78 | print('Loading LoRA weights...')
79 | model = PeftModel.from_pretrained(model, model_path)
80 | print('Merging LoRA weights...')
81 | model = model.merge_and_unload()
82 | print('Model is loaded...')
83 | elif model_base is not None:
84 | print('Loading Cerule from base model...')
85 |
86 | cfg_pretrained = AutoConfig.from_pretrained(model_path)
87 | if model_type == 'gemma':
88 | tokenizer = AutoTokenizer.from_pretrained(model_base, use_fast=False, trust_remote_code=True)
89 | model = CeruleGemmaForCausalLM.from_pretrained(model_base, low_cpu_mem_usage=True,
90 | config=cfg_pretrained, **kwargs)
91 |
92 | mm_projector_weights = torch.load(os.path.join(model_path, 'mm_projector.bin'), map_location='cpu')
93 | mm_projector_weights = {k: v.to(torch.float16) for k, v in mm_projector_weights.items()}
94 | model.load_state_dict(mm_projector_weights, strict=False)
95 | else:
96 | if model_type == 'gemma':
97 | tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=False)
98 | model = CeruleGemmaForCausalLM.from_pretrained(model_path, low_cpu_mem_usage=True, **kwargs)
99 |
100 | model.resize_token_embeddings(len(tokenizer))
101 |
102 | vision_tower = model.get_vision_tower()
103 | if not vision_tower.is_loaded:
104 | vision_tower.load_model()
105 | vision_tower.to(device=device, dtype=torch.float16)
106 | image_processor = vision_tower.image_processor
107 |
108 | if hasattr(model.config, "max_sequence_length"):
109 | context_len = model.config.max_sequence_length
110 | else:
111 | context_len = 2048
112 |
113 | if model.config.pad_token_id is None:
114 | model.config.pad_token_id = model.config.eos_token_id
115 |
116 | return tokenizer, model, image_processor, context_len
117 |
--------------------------------------------------------------------------------
/cerule/model/cerule_arch.py:
--------------------------------------------------------------------------------
1 | from abc import ABC, abstractmethod
2 |
3 | import torch
4 |
5 | from .multimodal_encoder.builder import build_vision_tower
6 | from .multimodal_projector.builder import build_vision_projector
7 |
8 | from cerule.constants import IGNORE_INDEX, IMAGE_TOKEN_INDEX
9 |
10 |
11 | class CeruleMetaModel:
12 |
13 | def __init__(self, config):
14 | super(CeruleMetaModel, self).__init__(config)
15 |
16 | if hasattr(config, "mm_vision_tower"):
17 | self.vision_tower = build_vision_tower(config, delay_load=True)
18 | self.mm_projector = build_vision_projector(config)
19 |
20 | def get_vision_tower(self):
21 | vision_tower = getattr(self, 'vision_tower', None)
22 | if type(vision_tower) is list:
23 | vision_tower = vision_tower[0]
24 | return vision_tower
25 |
26 | def initialize_vision_modules(self, model_args):
27 | vision_tower = model_args.vision_tower
28 |
29 | pretrain_mm_mlp_adapter = model_args.pretrain_mm_mlp_adapter
30 |
31 | self.config.mm_vision_tower = vision_tower
32 |
33 | if self.get_vision_tower() is None:
34 | vision_tower = build_vision_tower(model_args)
35 | self.vision_tower = vision_tower
36 | else:
37 | vision_tower = self.vision_tower
38 | vision_tower.load_model()
39 |
40 | self.config.use_mm_proj = True
41 | self.config.mm_projector_type = getattr(model_args, 'mm_projector_type')
42 | self.config.mm_hidden_size = vision_tower.hidden_size
43 |
44 | if getattr(self, 'mm_projector', None) is None:
45 | self.mm_projector = build_vision_projector(self.config)
46 | else:
47 | # In case it is frozen by LoRA
48 | for p in self.mm_projector.parameters():
49 | p.requires_grad = True
50 |
51 | if pretrain_mm_mlp_adapter is not None:
52 | mm_projector_weights = torch.load(pretrain_mm_mlp_adapter, map_location='cpu')
53 |
54 | def get_w(weights, keyword):
55 | return {k.split(keyword + '.')[1]: v for k, v in weights.items() if keyword in k}
56 |
57 | self.mm_projector.load_state_dict(get_w(mm_projector_weights, 'mm_projector'))
58 |
59 |
60 | class CeruleMetaForCausalLM(ABC):
61 |
62 | @abstractmethod
63 | def get_model(self):
64 | pass
65 |
66 | def get_vision_tower(self):
67 | return self.get_model().get_vision_tower()
68 |
69 | def encode_images(self, images):
70 | image_features = self.get_model().get_vision_tower()(images)
71 | image_features = self.get_model().mm_projector(image_features)
72 | return image_features
73 |
74 | def prepare_inputs_labels_for_multimodal(
75 | self, input_ids, position_ids, attention_mask, past_key_values, labels, images
76 | ):
77 | vision_tower = self.get_vision_tower()
78 | if vision_tower is None or images is None or input_ids.shape[1] == 1:
79 | if past_key_values is not None and vision_tower is not None and images is not None and input_ids.shape[
80 | 1] == 1:
81 | target_shape = past_key_values[-1][-1].shape[-2] + 1
82 | attention_mask = torch.cat((attention_mask, torch.ones(
83 | (attention_mask.shape[0], target_shape - attention_mask.shape[1]),
84 | dtype=attention_mask.dtype,
85 | device=attention_mask.device
86 | )), dim=1)
87 | position_ids = torch.sum(attention_mask, dim=1).unsqueeze(-1) - 1
88 | return input_ids, position_ids, attention_mask, past_key_values, None, labels
89 |
90 | if type(images) is list or images.ndim == 5:
91 | concat_images = torch.cat([image for image in images], dim=0)
92 | image_features = self.encode_images(concat_images)
93 | split_sizes = [image.shape[0] for image in images]
94 | image_features = torch.split(image_features, split_sizes, dim=0)
95 | image_features = [x.flatten(0, 1).to(self.device) for x in image_features]
96 | else:
97 | image_features = self.encode_images(images).to(self.device)
98 |
99 | # Let's just add dummy tensors if they do not exist,
100 | # it is a headache to deal with None all the time.
101 | # But it is not ideal, and if you have a better idea,
102 | # please open an issue / submit a PR, thanks.
103 | _labels = labels
104 | _position_ids = position_ids
105 | _attention_mask = attention_mask
106 | if attention_mask is None:
107 | attention_mask = torch.ones_like(input_ids, dtype=torch.bool)
108 | else:
109 | attention_mask = attention_mask.bool()
110 | if position_ids is None:
111 | position_ids = torch.arange(0, input_ids.shape[1], dtype=torch.long, device=input_ids.device)
112 | if labels is None:
113 | labels = torch.full_like(input_ids, IGNORE_INDEX)
114 |
115 | # remove the padding using attention_mask -- TODO: double check
116 | input_ids = [cur_input_ids[cur_attention_mask] for cur_input_ids, cur_attention_mask in
117 | zip(input_ids, attention_mask)]
118 | labels = [cur_labels[cur_attention_mask] for cur_labels, cur_attention_mask in zip(labels, attention_mask)]
119 |
120 | new_input_embeds = []
121 | new_labels = []
122 | cur_image_idx = 0
123 | for batch_idx, cur_input_ids in enumerate(input_ids):
124 | num_images = (cur_input_ids == IMAGE_TOKEN_INDEX).sum()
125 | if num_images == 0:
126 | cur_image_features = image_features[cur_image_idx]
127 | cur_input_embeds_1 = self.get_model().embed_tokens(cur_input_ids)
128 | cur_input_embeds = torch.cat([cur_input_embeds_1, cur_image_features[0:0]], dim=0)
129 | new_input_embeds.append(cur_input_embeds)
130 | new_labels.append(labels[batch_idx])
131 | cur_image_idx += 1
132 | continue
133 |
134 | image_token_indices = [-1] + torch.where(cur_input_ids == IMAGE_TOKEN_INDEX)[0].tolist() + [
135 | cur_input_ids.shape[0]]
136 | cur_input_ids_noim = []
137 | cur_labels = labels[batch_idx]
138 | cur_labels_noim = []
139 | for i in range(len(image_token_indices) - 1):
140 | cur_input_ids_noim.append(cur_input_ids[image_token_indices[i] + 1:image_token_indices[i + 1]])
141 | cur_labels_noim.append(cur_labels[image_token_indices[i] + 1:image_token_indices[i + 1]])
142 | split_sizes = [x.shape[0] for x in cur_labels_noim]
143 | cur_input_embeds = self.get_model().embed_tokens(torch.cat(cur_input_ids_noim))
144 | cur_input_embeds_no_im = torch.split(cur_input_embeds, split_sizes, dim=0)
145 | cur_new_input_embeds = []
146 | cur_new_labels = []
147 |
148 | for i in range(num_images + 1):
149 | cur_new_input_embeds.append(cur_input_embeds_no_im[i])
150 | cur_new_labels.append(cur_labels_noim[i])
151 | if i < num_images:
152 | cur_image_features = image_features[cur_image_idx]
153 | cur_image_idx += 1
154 | cur_new_input_embeds.append(cur_image_features)
155 | cur_new_labels.append(
156 | torch.full((cur_image_features.shape[0],), IGNORE_INDEX, device=cur_labels.device,
157 | dtype=cur_labels.dtype))
158 |
159 | cur_new_input_embeds = torch.cat(cur_new_input_embeds)
160 | cur_new_labels = torch.cat(cur_new_labels)
161 |
162 | new_input_embeds.append(cur_new_input_embeds)
163 | new_labels.append(cur_new_labels)
164 |
165 | # Truncate sequences to max length as image embeddings can make the sequence longer
166 | tokenizer_model_max_length = getattr(self.config, 'tokenizer_model_max_length', None)
167 | if tokenizer_model_max_length is not None:
168 | new_input_embeds = [x[:tokenizer_model_max_length] for x in new_input_embeds]
169 | new_labels = [x[:tokenizer_model_max_length] for x in new_labels]
170 |
171 | # Combine them
172 | max_len = max(x.shape[0] for x in new_input_embeds)
173 | batch_size = len(new_input_embeds)
174 |
175 | new_input_embeds_padded = []
176 | new_labels_padded = torch.full((batch_size, max_len), IGNORE_INDEX, dtype=new_labels[0].dtype,
177 | device=new_labels[0].device)
178 | attention_mask = torch.zeros((batch_size, max_len), dtype=attention_mask.dtype, device=attention_mask.device)
179 | position_ids = torch.zeros((batch_size, max_len), dtype=position_ids.dtype, device=position_ids.device)
180 |
181 | for i, (cur_new_embed, cur_new_labels) in enumerate(zip(new_input_embeds, new_labels)):
182 | cur_len = cur_new_embed.shape[0]
183 | if getattr(self.config, 'tokenizer_padding_side', 'right') == "left":
184 | new_input_embeds_padded.append(torch.cat((
185 | torch.zeros((max_len - cur_len, cur_new_embed.shape[1]), dtype=cur_new_embed.dtype,
186 | device=cur_new_embed.device),
187 | cur_new_embed
188 | ), dim=0))
189 | if cur_len > 0:
190 | new_labels_padded[i, -cur_len:] = cur_new_labels
191 | attention_mask[i, -cur_len:] = True
192 | position_ids[i, -cur_len:] = torch.arange(0, cur_len, dtype=position_ids.dtype,
193 | device=position_ids.device)
194 | else:
195 | new_input_embeds_padded.append(torch.cat((
196 | cur_new_embed,
197 | torch.zeros((max_len - cur_len, cur_new_embed.shape[1]), dtype=cur_new_embed.dtype,
198 | device=cur_new_embed.device)
199 | ), dim=0))
200 | if cur_len > 0:
201 | new_labels_padded[i, :cur_len] = cur_new_labels
202 | attention_mask[i, :cur_len] = True
203 | position_ids[i, :cur_len] = torch.arange(0, cur_len, dtype=position_ids.dtype,
204 | device=position_ids.device)
205 |
206 | new_input_embeds = torch.stack(new_input_embeds_padded, dim=0)
207 |
208 | if _labels is None:
209 | new_labels = None
210 | else:
211 | new_labels = new_labels_padded
212 |
213 | if _attention_mask is None:
214 | attention_mask = None
215 | else:
216 | attention_mask = attention_mask.to(dtype=_attention_mask.dtype)
217 |
218 | if _position_ids is None:
219 | position_ids = None
220 |
221 | return None, position_ids, attention_mask, past_key_values, new_input_embeds, new_labels
222 |
--------------------------------------------------------------------------------
/cerule/model/language_model/cerule_gemma.py:
--------------------------------------------------------------------------------
1 | from typing import List, Optional, Tuple, Union
2 |
3 | import torch
4 | import torch.nn as nn
5 | from transformers import AutoConfig, AutoModelForCausalLM, GemmaConfig, GemmaModel, GemmaForCausalLM
6 |
7 | from transformers.modeling_outputs import CausalLMOutputWithPast
8 |
9 | from ..cerule_arch import CeruleMetaModel, CeruleMetaForCausalLM
10 |
11 |
12 | class CeruleGemmaConfig(GemmaConfig):
13 | model_type = "cerule-gemma"
14 |
15 |
16 | class CeruleGemmaModel(CeruleMetaModel, GemmaModel):
17 | config_class = CeruleGemmaConfig
18 |
19 | def __init__(self, config: GemmaConfig):
20 | super(CeruleGemmaModel, self).__init__(config)
21 |
22 |
23 | class CeruleGemmaForCausalLM(GemmaForCausalLM, CeruleMetaForCausalLM):
24 | config_class = CeruleGemmaConfig
25 |
26 | def __init__(self, config):
27 | super(GemmaForCausalLM, self).__init__(config)
28 | self.model = CeruleGemmaModel(config)
29 | self.vocab_size = config.vocab_size
30 | self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
31 |
32 | # Initialize weights and apply final processing
33 | self.post_init()
34 |
35 | def get_model(self):
36 | return self.model
37 |
38 | def forward(
39 | self,
40 | input_ids: torch.LongTensor = None,
41 | attention_mask: Optional[torch.Tensor] = None,
42 | position_ids: Optional[torch.LongTensor] = None,
43 | past_key_values: Optional[List[torch.FloatTensor]] = None,
44 | inputs_embeds: Optional[torch.FloatTensor] = None,
45 | labels: Optional[torch.LongTensor] = None,
46 | use_cache: Optional[bool] = None,
47 | output_attentions: Optional[bool] = None,
48 | output_hidden_states: Optional[bool] = None,
49 | images: Optional[torch.FloatTensor] = None,
50 | return_dict: Optional[bool] = None,
51 | cache_position=None,
52 | ) -> Union[Tuple, CausalLMOutputWithPast]:
53 |
54 | if inputs_embeds is None:
55 | (
56 | input_ids,
57 | position_ids,
58 | attention_mask,
59 | past_key_values,
60 | inputs_embeds,
61 | labels
62 | ) = self.prepare_inputs_labels_for_multimodal(
63 | input_ids,
64 | position_ids,
65 | attention_mask,
66 | past_key_values,
67 | labels,
68 | images
69 | )
70 |
71 | return super().forward(
72 | input_ids=input_ids,
73 | attention_mask=attention_mask,
74 | position_ids=position_ids,
75 | past_key_values=past_key_values,
76 | inputs_embeds=inputs_embeds,
77 | labels=labels,
78 | use_cache=use_cache,
79 | output_attentions=output_attentions,
80 | output_hidden_states=output_hidden_states,
81 | return_dict=return_dict
82 | )
83 |
84 | def prepare_inputs_for_generation(self, input_ids, past_key_values=None, inputs_embeds=None, attention_mask=None,
85 | **kwargs):
86 | images = kwargs.pop("images", None)
87 |
88 | _inputs = super().prepare_inputs_for_generation(
89 | input_ids, past_key_values=past_key_values, inputs_embeds=inputs_embeds, attention_mask=attention_mask,
90 | **kwargs
91 | )
92 |
93 | if images is not None:
94 | _inputs['images'] = images
95 | return _inputs
96 |
97 |
98 | AutoConfig.register("cerule-gemma", CeruleGemmaConfig)
99 | AutoModelForCausalLM.register(CeruleGemmaConfig, CeruleGemmaForCausalLM)
100 |
--------------------------------------------------------------------------------
/cerule/model/language_model/gemma/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 The HuggingFace Inc. team. All rights reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | from typing import TYPE_CHECKING
15 |
16 | from transformers.utils import (
17 | OptionalDependencyNotAvailable,
18 | _LazyModule,
19 | is_flax_available,
20 | is_sentencepiece_available,
21 | is_tokenizers_available,
22 | is_torch_available,
23 | )
24 |
25 |
26 | _import_structure = {
27 | "configuration_gemma": ["GEMMA_PRETRAINED_CONFIG_ARCHIVE_MAP", "GemmaConfig"],
28 | }
29 |
30 | try:
31 | if not is_sentencepiece_available():
32 | raise OptionalDependencyNotAvailable()
33 | except OptionalDependencyNotAvailable:
34 | pass
35 | else:
36 | _import_structure["tokenization_gemma"] = ["GemmaTokenizer"]
37 |
38 | try:
39 | if not is_tokenizers_available():
40 | raise OptionalDependencyNotAvailable()
41 | except OptionalDependencyNotAvailable:
42 | pass
43 | else:
44 | _import_structure["tokenization_gemma_fast"] = ["GemmaTokenizerFast"]
45 |
46 |
47 | try:
48 | if not is_torch_available():
49 | raise OptionalDependencyNotAvailable()
50 | except OptionalDependencyNotAvailable:
51 | pass
52 | else:
53 | _import_structure["modeling_gemma"] = [
54 | "GemmaForCausalLM",
55 | "GemmaModel",
56 | "GemmaPreTrainedModel",
57 | "GemmaForSequenceClassification",
58 | ]
59 |
60 | try:
61 | if not is_flax_available():
62 | raise OptionalDependencyNotAvailable()
63 | except OptionalDependencyNotAvailable:
64 | pass
65 | else:
66 | _import_structure["modeling_flax_gemma"] = [
67 | "FlaxGemmaForCausalLM",
68 | "FlaxGemmaModel",
69 | "FlaxGemmaPreTrainedModel",
70 | ]
71 |
72 |
73 | if TYPE_CHECKING:
74 | from .configuration_gemma import GEMMA_PRETRAINED_CONFIG_ARCHIVE_MAP, GemmaConfig
75 |
76 | try:
77 | if not is_sentencepiece_available():
78 | raise OptionalDependencyNotAvailable()
79 | except OptionalDependencyNotAvailable:
80 | pass
81 | else:
82 | from .tokenization_gemma import GemmaTokenizer
83 |
84 | try:
85 | if not is_tokenizers_available():
86 | raise OptionalDependencyNotAvailable()
87 | except OptionalDependencyNotAvailable:
88 | pass
89 | else:
90 | from .tokenization_gemma_fast import GemmaTokenizerFast
91 |
92 | try:
93 | if not is_torch_available():
94 | raise OptionalDependencyNotAvailable()
95 | except OptionalDependencyNotAvailable:
96 | pass
97 | else:
98 | from .modeling_gemma import (
99 | GemmaForCausalLM,
100 | GemmaForSequenceClassification,
101 | GemmaModel,
102 | GemmaPreTrainedModel,
103 | )
104 |
105 | try:
106 | if not is_flax_available():
107 | raise OptionalDependencyNotAvailable()
108 | except OptionalDependencyNotAvailable:
109 | pass
110 | else:
111 | from .modeling_flax_gemma import (
112 | FlaxGemmaForCausalLM,
113 | FlaxGemmaModel,
114 | FlaxGemmaPreTrainedModel,
115 | )
116 |
117 |
118 | else:
119 | import sys
120 |
121 | sys.modules[__name__] = _LazyModule(__name__, globals()["__file__"], _import_structure, module_spec=__spec__)
--------------------------------------------------------------------------------
/cerule/model/language_model/gemma/configuration_gemma.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2024 The HuggingFace Inc. team. All rights reserved.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | """ Gemma model configuration"""
16 |
17 | from transformers.configuration_utils import PretrainedConfig
18 | from transformers.utils import logging
19 |
20 |
21 | logger = logging.get_logger(__name__)
22 |
23 | GEMMA_PRETRAINED_CONFIG_ARCHIVE_MAP = {}
24 |
25 |
26 | class GemmaConfig(PretrainedConfig):
27 | r"""
28 | This is the configuration class to store the configuration of a [`GemmaModel`]. It is used to instantiate a Gemma
29 | model according to the specified arguments, defining the model architecture. Instantiating a configuration with the
30 | defaults will yield a similar configuration to that of the Gemma-7B.
31 |
32 | e.g. [google/gemma-7b](https://huggingface.co/google/gemma-7b)
33 |
34 | Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
35 | documentation from [`PretrainedConfig`] for more information.
36 |
37 |
38 | Args:
39 | vocab_size (`int`, *optional*, defaults to 256000):
40 | Vocabulary size of the Gemma model. Defines the number of different tokens that can be represented by the
41 | `inputs_ids` passed when calling [`GemmaModel`]
42 | hidden_size (`int`, *optional*, defaults to 3072):
43 | Dimension of the hidden representations.
44 | intermediate_size (`int`, *optional*, defaults to 24576):
45 | Dimension of the MLP representations.
46 | num_hidden_layers (`int`, *optional*, defaults to 28):
47 | Number of hidden layers in the Transformer decoder.
48 | num_attention_heads (`int`, *optional*, defaults to 16):
49 | Number of attention heads for each attention layer in the Transformer decoder.
50 | num_key_value_heads (`int`, *optional*, defaults to 16):
51 | This is the number of key_value heads that should be used to implement Grouped Query Attention. If
52 | `num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA), if
53 | `num_key_value_heads=1`, the model will use Multi Query Attention (MQA); otherwise GQA is used. When
54 | converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed
55 | by meanpooling all the original heads within that group. For more details checkout [this
56 | paper](https://arxiv.org/pdf/2305.13245.pdf). If it is not specified, will default to
57 | `num_attention_heads`.
58 | head_dim (`int`, *optional*, defaults to 256):
59 | The attention head dimension.
60 | hidden_act (`str` or `function`, *optional*, defaults to `"gelu"`):
61 | The non-linear activation function (function or string) in the decoder.
62 | max_position_embeddings (`int`, *optional*, defaults to 8192):
63 | The maximum sequence length that this model might ever be used with.
64 | initializer_range (`float`, *optional*, defaults to 0.02):
65 | The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
66 | rms_norm_eps (`float`, *optional*, defaults to 1e-06):
67 | The epsilon used by the rms normalization layers.
68 | use_cache (`bool`, *optional*, defaults to `True`):
69 | Whether or not the model should return the last key/values attentions (not used by all models). Only
70 | relevant if `config.is_decoder=True`.
71 | pad_token_id (`int`, *optional*, defaults to 0):
72 | Padding token id.
73 | eos_token_id (`int`, *optional*, defaults to 1):
74 | End of stream token id.
75 | bos_token_id (`int`, *optional*, defaults to 2):
76 | Beginning of stream token id.
77 | tie_word_embeddings (`bool`, *optional*, defaults to `True`):
78 | Whether to tie weight embeddings
79 | rope_theta (`float`, *optional*, defaults to 10000.0):
80 | The base period of the RoPE embeddings.
81 | attention_bias (`bool`, *optional*, defaults to `False`):
82 | Whether to use a bias in the query, key, value and output projection layers during self-attention.
83 | attention_dropout (`float`, *optional*, defaults to 0.0):
84 | The dropout ratio for the attention probabilities.
85 |
86 | ```python
87 | >>> from transformers import GemmaModel, GemmaConfig
88 |
89 | >>> # Initializing a Gemma gemma-7b style configuration
90 | >>> configuration = GemmaConfig()
91 |
92 | >>> # Initializing a model from the gemma-7b style configuration
93 | >>> model = GemmaModel(configuration)
94 |
95 | >>> # Accessing the model configuration
96 | >>> configuration = model.config
97 | ```"""
98 |
99 | model_type = "gemma"
100 | keys_to_ignore_at_inference = ["past_key_values"]
101 |
102 | def __init__(
103 | self,
104 | vocab_size=256000,
105 | hidden_size=3072,
106 | intermediate_size=24576,
107 | num_hidden_layers=28,
108 | num_attention_heads=16,
109 | num_key_value_heads=16,
110 | head_dim=256,
111 | hidden_act="gelu",
112 | max_position_embeddings=8192,
113 | initializer_range=0.02,
114 | rms_norm_eps=1e-6,
115 | use_cache=True,
116 | pad_token_id=0,
117 | eos_token_id=1,
118 | bos_token_id=2,
119 | tie_word_embeddings=True,
120 | rope_theta=10000.0,
121 | attention_bias=False,
122 | attention_dropout=0.0,
123 | **kwargs,
124 | ):
125 | self.vocab_size = vocab_size
126 | self.max_position_embeddings = max_position_embeddings
127 | self.hidden_size = hidden_size
128 | self.intermediate_size = intermediate_size
129 | self.num_hidden_layers = num_hidden_layers
130 | self.num_attention_heads = num_attention_heads
131 | self.head_dim = head_dim
132 | self.num_key_value_heads = num_key_value_heads
133 | self.hidden_act = hidden_act
134 | self.initializer_range = initializer_range
135 | self.rms_norm_eps = rms_norm_eps
136 | self.use_cache = use_cache
137 | self.rope_theta = rope_theta
138 | self.attention_bias = attention_bias
139 | self.attention_dropout = attention_dropout
140 |
141 | super().__init__(
142 | pad_token_id=pad_token_id,
143 | bos_token_id=bos_token_id,
144 | eos_token_id=eos_token_id,
145 | tie_word_embeddings=tie_word_embeddings,
146 | **kwargs,
147 | )
--------------------------------------------------------------------------------
/cerule/model/multimodal_encoder/builder.py:
--------------------------------------------------------------------------------
1 | import os
2 | from .eva_clip.eva_clip_encoder import EvaClipVisionTower
3 | from .siglip.siglip_encoder import SigLipVisionTower
4 | from .clip.clip_encoder import CLIPVisionTower
5 |
6 |
7 | def build_vision_tower(vision_tower_cfg, **kwargs):
8 | vision_tower = getattr(vision_tower_cfg, 'mm_vision_tower', getattr(vision_tower_cfg, 'vision_tower', None))
9 |
10 | if 'sig' in vision_tower.lower():
11 | return SigLipVisionTower(vision_tower, vision_tower_cfg=vision_tower_cfg, **kwargs)
12 |
13 | elif 'eva' in vision_tower.lower():
14 | return EvaClipVisionTower(vision_tower, args=vision_tower_cfg, **kwargs)
15 |
16 | elif 'clip' in vision_tower.lower():
17 | return CLIPVisionTower(vision_tower, args=vision_tower_cfg, **kwargs)
18 |
19 | else:
20 | raise ValueError(f'Unknown vision tower: {vision_tower}')
21 |
--------------------------------------------------------------------------------
/cerule/model/multimodal_encoder/clip/clip_encoder.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 |
4 | from transformers import CLIPVisionModel, CLIPImageProcessor, CLIPVisionConfig
5 |
6 |
7 | class CLIPVisionTower(nn.Module):
8 | def __init__(self, vision_tower, args, delay_load=False):
9 | super().__init__()
10 |
11 | self.is_loaded = False
12 |
13 | self.vision_tower_name = vision_tower
14 | self.select_layer = -2
15 |
16 | if not delay_load:
17 | self.load_model()
18 | else:
19 | self.cfg_only = CLIPVisionConfig.from_pretrained(self.vision_tower_name)
20 |
21 | def load_model(self):
22 | self.image_processor = CLIPImageProcessor.from_pretrained(self.vision_tower_name)
23 | self.vision_tower = CLIPVisionModel.from_pretrained(self.vision_tower_name)
24 | self.vision_tower.requires_grad_(False)
25 |
26 | self.is_loaded = True
27 |
28 | def feature_select(self, image_forward_outs):
29 | image_features = image_forward_outs.hidden_states[self.select_layer]
30 |
31 | image_features = image_features[:, 1:]
32 |
33 | return image_features
34 |
35 | @torch.no_grad()
36 | def forward(self, images):
37 | if type(images) is list:
38 | image_features = []
39 | for image in images:
40 | image_forward_out = self.vision_tower(image.to(device=self.device, dtype=self.dtype).unsqueeze(0),
41 | output_hidden_states=True)
42 | image_feature = self.feature_select(image_forward_out).to(image.dtype)
43 | image_features.append(image_feature)
44 | else:
45 | image_forward_outs = self.vision_tower(images.to(device=self.device, dtype=self.dtype),
46 | output_hidden_states=True)
47 | image_features = self.feature_select(image_forward_outs).to(images.dtype)
48 |
49 | return image_features
50 |
51 | @property
52 | def dummy_feature(self):
53 | return torch.zeros(1, self.hidden_size, device=self.device, dtype=self.dtype)
54 |
55 | @property
56 | def dtype(self):
57 | return self.vision_tower.dtype
58 |
59 | @property
60 | def device(self):
61 | return self.vision_tower.device
62 |
63 | @property
64 | def config(self):
65 | if self.is_loaded:
66 | return self.vision_tower.config
67 | else:
68 | return self.cfg_only
69 |
70 | @property
71 | def hidden_size(self):
72 | return self.config.hidden_size
73 |
74 | @property
75 | def num_patches(self):
76 | return (self.config.image_size // self.config.patch_size) ** 2
77 |
--------------------------------------------------------------------------------
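A sketch of how the CLIP tower above is driven end to end: load it, preprocess an image with its bundled `CLIPImageProcessor`, and take the penultimate-layer patch features (`select_layer = -2`, CLS token dropped). The checkpoint name is an assumption for illustration; constructing the tower downloads its weights.

```python
# Sketch: encode one image with CLIPVisionTower (assumed checkpoint, downloads on first use).
import torch
from types import SimpleNamespace
from PIL import Image
from cerule.model.multimodal_encoder.clip.clip_encoder import CLIPVisionTower

tower = CLIPVisionTower("openai/clip-vit-large-patch14-336", args=SimpleNamespace())
image = Image.new("RGB", (336, 336))  # stand-in image
pixel_values = tower.image_processor(images=image, return_tensors="pt")["pixel_values"]
features = tower(pixel_values)        # forward() moves inputs to the tower's device/dtype
print(features.shape)                 # (1, 576, 1024) for a 336px, patch-14 ViT-L
```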
/cerule/model/multimodal_encoder/eva_clip/eva_clip_encoder.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 |
4 | from .eva_clip_processors import EvaClipImageTrainProcessor
5 | from .eva_vit import Eva2LargePlusEncoder
6 |
7 |
8 | class EvaClipVisionTower(nn.Module):
9 | def __init__(self, vision_tower, args, delay_load=False):
10 | super().__init__()
11 |
12 | self.is_loaded = False
13 |
14 | self.vision_tower_path = vision_tower
15 | self.config = VisionTowerConfig()
16 |
17 | if not delay_load:
18 | self.load_model()
19 | else:
20 | self.cfg_only = self.config
21 |
22 | def load_model(self):
23 | self.image_processor = EvaClipImageTrainProcessor(self.config.image_size)
24 | self.vision_tower = Eva2LargePlusEncoder(self.vision_tower_path)
25 | self.vision_tower.requires_grad_(False)
26 |
27 | self.is_loaded = True
28 |
29 | @torch.no_grad()
30 | def forward(self, images):
31 | if type(images) is list:
32 | image_features = []
33 | for image in images:
34 | image_feature = self.vision_tower(image.to(device=self.device, dtype=self.dtype).unsqueeze(0)).to(
35 | image.dtype)
36 | image_features.append(image_feature)
37 | else:
38 | image_features = self.vision_tower(images.to(device=self.device, dtype=self.dtype)).to(images.dtype)
39 |
40 | return image_features
41 |
42 | @property
43 | def dtype(self):
44 | return self.vision_tower.dtype
45 |
46 | @property
47 | def device(self):
48 | return self.vision_tower.device
49 |
50 | @property
51 | def hidden_size(self):
52 | return self.config.hidden_size
53 |
54 | @property
55 | def num_patches(self):
56 | return (self.config.image_size // self.config.patch_size) ** 2
57 |
58 |
59 | class VisionTowerConfig():
60 | def __init__(self):
61 | self.image_size = 336
62 | self.patch_size = 14
63 | self.hidden_size = 1024
64 |
--------------------------------------------------------------------------------
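The EVA tower keeps its geometry in the hard-coded `VisionTowerConfig` above (336px images, patch size 14, 1024-dim features) and expects a local EVA-02 checkpoint path. A small sketch using `delay_load=True`, so nothing is read from disk; the path is a placeholder.

```python
# Sketch: inspect the EVA tower's geometry without loading weights.
from types import SimpleNamespace
from cerule.model.multimodal_encoder.eva_clip.eva_clip_encoder import EvaClipVisionTower

tower = EvaClipVisionTower("/path/to/eva02_clip_large_336.pt",  # placeholder checkpoint path
                           args=SimpleNamespace(), delay_load=True)
print(tower.config.image_size, tower.config.patch_size)  # 336 14
print(tower.num_patches, tower.hidden_size)              # 576 1024
```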
/cerule/model/multimodal_encoder/eva_clip/eva_clip_processors.py:
--------------------------------------------------------------------------------
1 | '''
2 | # Adapted from https://github.com/baaivision/EVA/tree/master/EVA-CLIP
3 | '''
4 |
5 | from torchvision import transforms
6 | from torchvision.transforms.functional import InterpolationMode
7 | from transformers.image_processing_utils import BatchFeature
8 | from PIL import Image
9 | from transformers.image_transforms import convert_to_rgb
10 |
11 |
12 | class BaseProcessor:
13 | def __init__(self):
14 | self.transform = lambda x: x
15 | return
16 |
17 | def __call__(self, item):
18 | return self.transform(item)
19 |
20 |
21 | class EvaClipImageBaseProcessor(BaseProcessor):
22 | def __init__(self, mean=None, std=None):
23 | self.mean = (0.48145466, 0.4578275, 0.40821073) if mean is None else mean
24 | self.std = (0.26862954, 0.26130258, 0.27577711) if std is None else std
25 |
26 | self.normalize = transforms.Normalize(self.mean, self.std)
27 |
28 | @property
29 | def image_mean(self):
30 | return self.mean
31 |
32 |
33 | class EvaClipImageTrainProcessor(EvaClipImageBaseProcessor):
34 | def __init__(self, image_size=224, mean=None, std=None, min_scale=0.5, max_scale=1.0):
35 | super().__init__(mean=mean, std=std)
36 |
37 | self.transform = transforms.Compose(
38 | [
39 | convert_to_rgb,
40 | transforms.Resize(
41 | image_size,
42 | interpolation=InterpolationMode.BICUBIC,
43 | ),
44 | transforms.CenterCrop(image_size),
45 | transforms.ToTensor(),
46 | self.normalize,
47 | ]
48 | )
49 |
50 | self.image_size = image_size
51 |
52 | def preprocess(self, images, return_tensors):
53 | if isinstance(images, Image.Image):
54 | images = [images]
55 | else:
56 | assert isinstance(images, list)
57 |
58 | transformed_images = [self.transform(image).numpy() for image in images]
59 | data = {"pixel_values": transformed_images}
60 |
61 | return BatchFeature(data=data, tensor_type=return_tensors)
62 |
63 | def __call__(self, item):
64 | return self.transform(item)
65 |
66 | @property
67 | def crop_size(self):
68 | return {'height': self.image_size, 'width': self.image_size}
69 |
--------------------------------------------------------------------------------
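A quick sketch of the processor above: it resizes and center-crops to `image_size`, normalizes with CLIP statistics, and returns an HF `BatchFeature`, so it can be dropped in where a `CLIPImageProcessor` is expected.

```python
# Sketch: preprocess a single PIL image with EvaClipImageTrainProcessor.
from PIL import Image
from cerule.model.multimodal_encoder.eva_clip.eva_clip_processors import EvaClipImageTrainProcessor

processor = EvaClipImageTrainProcessor(image_size=336)
image = Image.new("RGB", (500, 400))               # stand-in image
batch = processor.preprocess(image, return_tensors="pt")
print(batch["pixel_values"].shape)                 # torch.Size([1, 3, 336, 336])
print(processor.crop_size)                         # {'height': 336, 'width': 336}
```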
/cerule/model/multimodal_encoder/siglip/siglip_encoder.py:
--------------------------------------------------------------------------------
1 | '''
2 | # Adapted from https://huggingface.co/MILVLG/imp-v1-3b/blob/main/vision_encoder.py
3 | '''
4 |
5 | from typing import Optional, Tuple, Union, Dict
6 | from dataclasses import dataclass
7 | from functools import partial, reduce
8 | from PIL import Image
9 | import torch
10 | import torch.utils.checkpoint
11 | from torch import nn
12 | import os
13 | from transformers.image_processing_utils import BatchFeature, get_size_dict
14 | from transformers.image_transforms import (convert_to_rgb, normalize, rescale, resize, to_channel_dimension_format, )
15 | from transformers.image_utils import (ChannelDimension, PILImageResampling, to_numpy_array, )
16 | from transformers.activations import ACT2FN
17 | from transformers.modeling_outputs import BaseModelOutput, BaseModelOutputWithPooling
18 | from transformers.modeling_utils import PreTrainedModel
19 | from transformers import PretrainedConfig
20 | from transformers.utils import ModelOutput, logging
21 | logger = logging.get_logger(__name__)
22 |
23 | class SigLipImageProcessor:
24 | def __init__(self,
25 | image_mean=(0.5, 0.5, 0.5),
26 | image_std=(0.5, 0.5, 0.5),
27 | size=(384, 384),
28 | crop_size: Dict[str, int] = None,
29 | resample=PILImageResampling.BICUBIC,
30 | rescale_factor=1 / 255,
31 | data_format=ChannelDimension.FIRST):
32 | crop_size = crop_size if crop_size is not None else {"height": 384, "width": 384}
33 | crop_size = get_size_dict(crop_size, default_to_square=True, param_name="crop_size")
34 |
35 | self.image_mean = image_mean
36 | self.image_std = image_std
37 | self.size = size
38 | self.resample = resample
39 | self.rescale_factor = rescale_factor
40 | self.data_format = data_format
41 | self.crop_size = crop_size
42 |
43 | def preprocess(self, images, return_tensors):
44 | if isinstance(images, Image.Image):
45 | images = [images]
46 | else:
47 | assert isinstance(images, list)
48 |
49 | transforms = [
50 | convert_to_rgb,
51 | to_numpy_array,
52 | partial(resize, size=self.size, resample=self.resample, data_format=self.data_format),
53 | partial(rescale, scale=self.rescale_factor, data_format=self.data_format),
54 | partial(normalize, mean=self.image_mean, std=self.image_std, data_format=self.data_format),
55 | partial(to_channel_dimension_format, channel_dim=self.data_format, input_channel_dim=self.data_format),
56 | ]
57 |
58 | images = reduce(lambda x, f: [*map(f, x)], transforms, images)
59 | data = {"pixel_values": images}
60 |
61 | return BatchFeature(data=data, tensor_type=return_tensors)
62 |
63 |
64 | class SigLipVisionConfig(PretrainedConfig):
65 | model_type = "siglip_vision_model"
66 |
67 | def __init__(
68 | self,
69 | hidden_size=1152,
70 | image_mean=(0.5, 0.5, 0.5),
71 | intermediate_size=4304,
72 | num_hidden_layers=27,
73 | num_attention_heads=16,
74 | num_channels=3,
75 | image_size=384,
76 | patch_size=14,
77 | hidden_act="gelu_pytorch_tanh",
78 | layer_norm_eps=1e-6,
79 | attention_dropout=0.0,
80 | **kwargs,
81 | ):
82 | super().__init__(**kwargs)
83 |
84 | self.hidden_size = hidden_size
85 | self.intermediate_size = intermediate_size
86 | self.num_hidden_layers = num_hidden_layers
87 | self.num_attention_heads = num_attention_heads
88 | self.num_channels = num_channels
89 | self.patch_size = patch_size
90 | self.image_size = image_size
91 | self.attention_dropout = attention_dropout
92 | self.layer_norm_eps = layer_norm_eps
93 | self.hidden_act = hidden_act
94 | self.image_mean = image_mean
95 |
96 | @classmethod
97 | def from_pretrained(cls, pretrained_model_name_or_path: Union[str, os.PathLike], **kwargs) -> "PretrainedConfig":
98 | cls._set_token_in_kwargs(kwargs)
99 |
100 | config_dict, kwargs = cls.get_config_dict(pretrained_model_name_or_path, **kwargs)
101 |
102 | # get the vision config dict if we are loading from SigLipConfig
103 | if config_dict.get("model_type") == "siglip":
104 | config_dict = config_dict["vision_config"]
105 |
106 | if "model_type" in config_dict and hasattr(cls, "model_type") and config_dict["model_type"] != cls.model_type:
107 | logger.warning(
108 | f"You are using a model of type {config_dict['model_type']} to instantiate a model of type "
109 | f"{cls.model_type}. This is not supported for all configurations of models and can yield errors."
110 | )
111 |
112 | return cls.from_dict(config_dict, **kwargs)
113 |
114 |
115 | @dataclass
116 | # Copied from transformers.models.clip.modeling_clip.CLIPVisionModelOutput with CLIP->SigLip
117 | class SigLipVisionModelOutput(ModelOutput):
118 | """
119 | Base class for vision model's outputs that also contains image embeddings of the pooling of the last hidden states.
120 |
121 | Args:
122 | image_embeds (`torch.FloatTensor` of shape `(batch_size, output_dim)` *optional* returned when model is initialized with `with_projection=True`):
123 | The image embeddings obtained by applying the projection layer to the pooler_output.
124 | last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
125 | Sequence of hidden-states at the output of the last layer of the model.
126 | hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
127 | Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, +
128 | one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`.
129 |
130 | Hidden-states of the model at the output of each layer plus the optional initial embedding outputs.
131 | attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
132 | Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
133 | sequence_length)`.
134 |
135 | Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
136 | heads.
137 | """
138 |
139 | image_embeds: Optional[torch.FloatTensor] = None
140 | last_hidden_state: torch.FloatTensor = None
141 | hidden_states: Optional[Tuple[torch.FloatTensor]] = None
142 | attentions: Optional[Tuple[torch.FloatTensor]] = None
143 |
144 |
145 | class SigLipVisionEmbeddings(nn.Module):
146 | def __init__(self, config: SigLipVisionConfig):
147 | super().__init__()
148 | self.config = config
149 | self.embed_dim = config.hidden_size
150 | self.image_size = config.image_size
151 | self.patch_size = config.patch_size
152 |
153 | self.patch_embedding = nn.Conv2d(
154 | in_channels=config.num_channels,
155 | out_channels=self.embed_dim,
156 | kernel_size=self.patch_size,
157 | stride=self.patch_size,
158 | padding="valid",
159 | )
160 |
161 | self.num_patches = (self.image_size // self.patch_size) ** 2
162 | self.num_positions = self.num_patches
163 | self.position_embedding = nn.Embedding(self.num_positions, self.embed_dim)
164 | self.register_buffer("position_ids", torch.arange(self.num_positions).expand((1, -1)), persistent=False)
165 |
166 | def forward(self, pixel_values: torch.FloatTensor) -> torch.Tensor:
167 | patch_embeds = self.patch_embedding(pixel_values) # shape = [*, width, grid, grid]
168 | embeddings = patch_embeds.flatten(2).transpose(1, 2)
169 |
170 | embeddings = embeddings + self.position_embedding(self.position_ids)
171 | return embeddings
172 |
173 |
174 | class SigLipAttention(nn.Module):
175 | """Multi-headed attention from 'Attention Is All You Need' paper"""
176 |
177 | # Copied from transformers.models.clip.modeling_clip.CLIPAttention.__init__
178 | def __init__(self, config):
179 | super().__init__()
180 | self.config = config
181 | self.embed_dim = config.hidden_size
182 | self.num_heads = config.num_attention_heads
183 | self.head_dim = self.embed_dim // self.num_heads
184 | if self.head_dim * self.num_heads != self.embed_dim:
185 | raise ValueError(
186 | f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim} and `num_heads`:"
187 | f" {self.num_heads})."
188 | )
189 | self.scale = self.head_dim ** -0.5
190 | self.dropout = config.attention_dropout
191 |
192 | self.k_proj = nn.Linear(self.embed_dim, self.embed_dim)
193 | self.v_proj = nn.Linear(self.embed_dim, self.embed_dim)
194 | self.q_proj = nn.Linear(self.embed_dim, self.embed_dim)
195 | self.out_proj = nn.Linear(self.embed_dim, self.embed_dim)
196 |
197 | def forward(
198 | self,
199 | hidden_states: torch.Tensor,
200 | attention_mask: Optional[torch.Tensor] = None,
201 | output_attentions: Optional[bool] = False,
202 | ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
203 | """Input shape: Batch x Time x Channel"""
204 |
205 | batch_size, q_len, _ = hidden_states.size()
206 |
207 | query_states = self.q_proj(hidden_states)
208 | key_states = self.k_proj(hidden_states)
209 | value_states = self.v_proj(hidden_states)
210 |
211 | query_states = query_states.view(batch_size, q_len, self.num_heads, self.head_dim).transpose(1, 2)
212 | key_states = key_states.view(batch_size, q_len, self.num_heads, self.head_dim).transpose(1, 2)
213 | value_states = value_states.view(batch_size, q_len, self.num_heads, self.head_dim).transpose(1, 2)
214 |
215 | k_v_seq_len = key_states.shape[-2]
216 | attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) * self.scale
217 |
218 | if attn_weights.size() != (batch_size, self.num_heads, q_len, k_v_seq_len):
219 | raise ValueError(
220 | f"Attention weights should be of size {(batch_size, self.num_heads, q_len, k_v_seq_len)}, but is"
221 | f" {attn_weights.size()}"
222 | )
223 |
224 | if attention_mask is not None:
225 | if attention_mask.size() != (batch_size, 1, q_len, k_v_seq_len):
226 | raise ValueError(
227 | f"Attention mask should be of size {(batch_size, 1, q_len, k_v_seq_len)}, but is {attention_mask.size()}"
228 | )
229 | attn_weights = attn_weights + attention_mask
230 |
231 | # upcast attention to fp32
232 | attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
233 | attn_weights = nn.functional.dropout(attn_weights, p=self.dropout, training=self.training)
234 | attn_output = torch.matmul(attn_weights, value_states)
235 |
236 | if attn_output.size() != (batch_size, self.num_heads, q_len, self.head_dim):
237 | raise ValueError(
238 | f"`attn_output` should be of size {(batch_size, self.num_heads, q_len, self.head_dim)}, but is"
239 | f" {attn_output.size()}"
240 | )
241 |
242 | attn_output = attn_output.transpose(1, 2).contiguous()
243 | attn_output = attn_output.reshape(batch_size, q_len, self.embed_dim)
244 |
245 | attn_output = self.out_proj(attn_output)
246 |
247 | return attn_output, attn_weights
248 |
249 |
250 | # Copied from transformers.models.clip.modeling_clip.CLIPMLP with CLIP->SigLip
251 | class SigLipMLP(nn.Module):
252 | def __init__(self, config):
253 | super().__init__()
254 | self.config = config
255 | self.activation_fn = ACT2FN[config.hidden_act]
256 | self.fc1 = nn.Linear(config.hidden_size, config.intermediate_size)
257 | self.fc2 = nn.Linear(config.intermediate_size, config.hidden_size)
258 |
259 | def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
260 | hidden_states = self.fc1(hidden_states)
261 | hidden_states = self.activation_fn(hidden_states)
262 | hidden_states = self.fc2(hidden_states)
263 | return hidden_states
264 |
265 |
266 | # Copied from transformers.models.clip.modeling_clip.CLIPEncoderLayer with CLIP->SigLip
267 | class SigLipEncoderLayer(nn.Module):
268 | def __init__(self, config: SigLipVisionConfig):
269 | super().__init__()
270 | self.embed_dim = config.hidden_size
271 | self.self_attn = SigLipAttention(config)
272 | self.layer_norm1 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps)
273 | self.mlp = SigLipMLP(config)
274 | self.layer_norm2 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps)
275 |
276 | # Ignore copy
277 | def forward(
278 | self,
279 | hidden_states: torch.Tensor,
280 | attention_mask: torch.Tensor,
281 | output_attentions: Optional[bool] = False,
282 | ) -> Tuple[torch.FloatTensor]:
283 | """
284 | Args:
285 | hidden_states (`torch.FloatTensor`):
286 | Input to the layer of shape `(batch, seq_len, embed_dim)`.
287 | attention_mask (`torch.FloatTensor`):
288 | Attention mask of shape `(batch, 1, q_len, k_v_seq_len)` where padding elements are indicated by very large negative values.
289 | output_attentions (`bool`, *optional*, defaults to `False`):
290 | Whether or not to return the attentions tensors of all attention layers. See `attentions` under
291 | returned tensors for more detail.
292 | """
293 | residual = hidden_states
294 |
295 | hidden_states = self.layer_norm1(hidden_states)
296 | hidden_states, attn_weights = self.self_attn(
297 | hidden_states=hidden_states,
298 | attention_mask=attention_mask,
299 | output_attentions=output_attentions,
300 | )
301 | hidden_states = residual + hidden_states
302 |
303 | residual = hidden_states
304 | hidden_states = self.layer_norm2(hidden_states)
305 | hidden_states = self.mlp(hidden_states)
306 | hidden_states = residual + hidden_states
307 |
308 | outputs = (hidden_states,)
309 |
310 | if output_attentions:
311 | outputs += (attn_weights,)
312 |
313 | return outputs
314 |
315 |
316 | class SigLipPreTrainedModel(PreTrainedModel):
317 | """
318 | An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
319 | models.
320 | """
321 |
322 | config_class = SigLipVisionConfig
323 | base_model_prefix = "siglip"
324 | supports_gradient_checkpointing = True
325 |
326 | def _init_weights(self, module):
327 | """Initialize the weights"""
328 | pass
329 |
330 |
331 | # Copied from transformers.models.clip.modeling_clip.CLIPEncoder with CLIP->SigLip
332 | class SigLipEncoder(nn.Module):
333 | """
334 | Transformer encoder consisting of `config.num_hidden_layers` self attention layers. Each layer is a
335 | [`SigLipEncoderLayer`].
336 |
337 | Args:
338 | config: SigLipVisionConfig
339 | """
340 |
341 | def __init__(self, config: SigLipVisionConfig):
342 | super().__init__()
343 | self.config = config
344 | self.layers = nn.ModuleList([SigLipEncoderLayer(config) for _ in range(config.num_hidden_layers)])
345 | self.gradient_checkpointing = False
346 |
347 | # Ignore copy
348 | def forward(
349 | self,
350 | inputs_embeds,
351 | attention_mask: Optional[torch.Tensor] = None,
352 | output_attentions: Optional[bool] = None,
353 | output_hidden_states: Optional[bool] = None,
354 | return_dict: Optional[bool] = None,
355 | ) -> Union[Tuple, BaseModelOutput]:
356 | r"""
357 | Args:
358 | inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
359 | Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation.
360 | This is useful if you want more control over how to convert `input_ids` indices into associated vectors
361 | than the model's internal embedding lookup matrix.
362 | attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
363 | Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
364 |
365 | - 1 for tokens that are **not masked**,
366 | - 0 for tokens that are **masked**.
367 |
368 | [What are attention masks?](../glossary#attention-mask)
369 | output_attentions (`bool`, *optional*):
370 | Whether or not to return the attentions tensors of all attention layers. See `attentions` under
371 | returned tensors for more detail.
372 | output_hidden_states (`bool`, *optional*):
373 | Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors
374 | for more detail.
375 | return_dict (`bool`, *optional*):
376 | Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
377 | """
378 | output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
379 | output_hidden_states = (
380 | output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
381 | )
382 | return_dict = return_dict if return_dict is not None else self.config.use_return_dict
383 |
384 | encoder_states = () if output_hidden_states else None
385 | all_attentions = () if output_attentions else None
386 |
387 | hidden_states = inputs_embeds
388 | for encoder_layer in self.layers:
389 | if output_hidden_states:
390 | encoder_states = encoder_states + (hidden_states,)
391 | if self.gradient_checkpointing and self.training:
392 | layer_outputs = self._gradient_checkpointing_func(
393 | encoder_layer.__call__,
394 | hidden_states,
395 | attention_mask,
396 | output_attentions,
397 | )
398 | else:
399 | layer_outputs = encoder_layer(
400 | hidden_states,
401 | attention_mask,
402 | output_attentions=output_attentions,
403 | )
404 |
405 | hidden_states = layer_outputs[0]
406 |
407 | if output_attentions:
408 | all_attentions = all_attentions + (layer_outputs[1],)
409 |
410 | if output_hidden_states:
411 | encoder_states = encoder_states + (hidden_states,)
412 |
413 | if not return_dict:
414 | return tuple(v for v in [hidden_states, encoder_states, all_attentions] if v is not None)
415 | return BaseModelOutput(
416 | last_hidden_state=hidden_states, hidden_states=encoder_states, attentions=all_attentions
417 | )
418 |
419 |
420 | class SigLipVisionTransformer(nn.Module):
421 | def __init__(self, config: SigLipVisionConfig):
422 | super().__init__()
423 | self.config = config
424 | embed_dim = config.hidden_size
425 |
426 | self.embeddings = SigLipVisionEmbeddings(config)
427 | self.encoder = SigLipEncoder(config)
428 | self.post_layernorm = nn.LayerNorm(embed_dim, eps=config.layer_norm_eps)
429 | self.head = SigLipMultiheadAttentionPoolingHead(config)
430 |
431 | def forward(
432 | self,
433 | pixel_values,
434 | output_attentions: Optional[bool] = None,
435 | output_hidden_states: Optional[bool] = None,
436 | return_dict: Optional[bool] = None,
437 | ) -> Union[Tuple, BaseModelOutputWithPooling]:
438 | r"""
439 | Returns:
440 |
441 | """
442 | output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
443 | output_hidden_states = (
444 | output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
445 | )
446 | return_dict = return_dict if return_dict is not None else self.config.use_return_dict
447 |
448 | hidden_states = self.embeddings(pixel_values)
449 |
450 | encoder_outputs = self.encoder(
451 | inputs_embeds=hidden_states,
452 | output_attentions=output_attentions,
453 | output_hidden_states=output_hidden_states,
454 | return_dict=return_dict,
455 | )
456 |
457 | last_hidden_state = encoder_outputs[0]
458 | last_hidden_state = self.post_layernorm(last_hidden_state)
459 |
460 | pooled_output = self.head(last_hidden_state)
461 |
462 | if not return_dict:
463 | return (last_hidden_state, pooled_output) + encoder_outputs[1:]
464 |
465 | return BaseModelOutputWithPooling(
466 | last_hidden_state=last_hidden_state,
467 | pooler_output=pooled_output,
468 | hidden_states=encoder_outputs.hidden_states,
469 | attentions=encoder_outputs.attentions,
470 | )
471 |
472 |
473 | class SigLipMultiheadAttentionPoolingHead(nn.Module):
474 | """Multihead Attention Pooling."""
475 |
476 | def __init__(self, config: SigLipVisionConfig):
477 | super().__init__()
478 |
479 | self.probe = nn.Parameter(torch.randn(1, 1, config.hidden_size))
480 | self.attention = torch.nn.MultiheadAttention(config.hidden_size, config.num_attention_heads, batch_first=True)
481 | self.layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
482 | self.mlp = SigLipMLP(config)
483 |
484 | def forward(self, hidden_state):
485 | batch_size = hidden_state.shape[0]
486 | probe = self.probe.repeat(batch_size, 1, 1)
487 |
488 | hidden_state = self.attention(probe, hidden_state, hidden_state)[0]
489 |
490 | residual = hidden_state
491 | hidden_state = self.layernorm(hidden_state)
492 | hidden_state = residual + self.mlp(hidden_state)
493 |
494 | return hidden_state[:, 0]
495 |
496 |
497 | class SigLipVisionModel(SigLipPreTrainedModel):
498 | config_class = SigLipVisionConfig
499 | main_input_name = "pixel_values"
500 | _no_split_modules = ["SigLipEncoderLayer"]
501 |
502 | def __init__(self, config: SigLipVisionConfig):
503 | super().__init__(config)
504 |
505 | self.vision_model = SigLipVisionTransformer(config)
506 |
507 | # Initialize weights and apply final processing
508 | self.post_init()
509 |
510 | def get_input_embeddings(self) -> nn.Module:
511 | return self.vision_model.embeddings.patch_embedding
512 |
513 | def forward(
514 | self,
515 | pixel_values,
516 | output_attentions: Optional[bool] = None,
517 | output_hidden_states: Optional[bool] = None,
518 | return_dict: Optional[bool] = None,
519 | ) -> Union[Tuple, BaseModelOutputWithPooling]:
520 | r"""
521 | Returns:
522 |
523 | Examples:
524 |
525 | ```python
526 | >>> from PIL import Image
527 | >>> import requests
528 | >>> from transformers import AutoProcessor, SigLipVisionModel
529 |
530 | >>> model = SigLipVisionModel.from_pretrained("google/siglip-base-patch16-224")
531 | >>> processor = AutoProcessor.from_pretrained("google/siglip-base-patch16-224")
532 |
533 | >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
534 | >>> image = Image.open(requests.get(url, stream=True).raw)
535 |
536 | >>> inputs = processor(images=image, return_tensors="pt")
537 |
538 | >>> outputs = model(**inputs)
539 | >>> last_hidden_state = outputs.last_hidden_state
540 | >>> pooled_output = outputs.pooler_output # pooled features
541 | ```"""
542 | return_dict = return_dict if return_dict is not None else self.config.use_return_dict
543 |
544 | return self.vision_model(
545 | pixel_values=pixel_values,
546 | output_attentions=output_attentions,
547 | output_hidden_states=output_hidden_states,
548 | return_dict=return_dict,
549 | )
550 |
551 |
552 | class SigLipVisionTower(nn.Module):
553 | def __init__(self, vision_tower, vision_tower_cfg, delay_load=False):
554 | super().__init__()
555 |
556 | self.is_loaded = False
557 |
558 | self.config = SigLipVisionConfig()
559 |
560 | self.vision_tower_name = vision_tower
561 |
562 | self.image_processor = SigLipImageProcessor()
563 |
564 | if not delay_load:
565 | self.load_model()
566 | else:
567 | self.cfg_only = self.config
568 |
569 | def load_model(self):
570 | if self.is_loaded:
571 | return
572 |
573 | self.vision_tower = SigLipVisionModel.from_pretrained(self.vision_tower_name)
574 |
575 | del self.vision_tower.vision_model.encoder.layers[-1:]
576 | self.vision_tower.vision_model.head = nn.Identity()
577 | self.vision_tower.requires_grad_(False)
578 | self.vision_tower.eval()
579 |
580 | self.is_loaded = True
581 |
582 | @torch.no_grad()
583 | def forward(self, images):
584 | if type(images) is list:
585 | image_features = []
586 | for image in images:
587 | image_forward_out = self.vision_tower(image.to(device=self.device, dtype=self.dtype).unsqueeze(0),
588 | output_hidden_states=True)
589 | image_feature = image_forward_out.hidden_states[-1].to(image.dtype)
590 |                 assert image_feature.shape[-2] == 729
591 | image_features.append(image_feature)
592 | else:
593 | image_forward_outs = self.vision_tower(images.to(device=self.device, dtype=self.dtype),
594 | output_hidden_states=True)
595 | image_features = image_forward_outs.hidden_states[-1].to(images.dtype)
596 | assert image_features.shape[-2] == 729
597 |
598 | return image_features
599 |
600 | @property
601 | def dummy_feature(self):
602 | return torch.zeros(1, self.hidden_size, device=self.device, dtype=self.dtype)
603 |
604 | @property
605 | def dtype(self):
606 | for p in self.vision_tower.parameters():
607 | return p.dtype
608 |
609 | @property
610 | def device(self):
611 | for p in self.vision_tower.parameters():
612 | return p.device
613 |
614 | @property
615 | def hidden_size(self):
616 | return self.config.hidden_size
617 |
618 | @property
619 | def num_patches(self):
620 | return (self.config.image_size // self.config.patch_size) ** 2
621 |
--------------------------------------------------------------------------------
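Pulling the pieces of this file together: the processor resizes to 384x384, `load_model` drops the final encoder block and replaces the pooling head with `nn.Identity()`, and `forward` returns one 1152-dim feature per patch (27 x 27 = 729 tokens, hence the asserts). A sketch with `delay_load=True`; the checkpoint name is an assumption.

```python
# Sketch: SigLip tower preprocessing and geometry (no weights downloaded with delay_load=True).
from types import SimpleNamespace
from PIL import Image
from cerule.model.multimodal_encoder.siglip.siglip_encoder import SigLipVisionTower

tower = SigLipVisionTower("google/siglip-so400m-patch14-384",   # assumed checkpoint name
                          vision_tower_cfg=SimpleNamespace(), delay_load=True)
image = Image.new("RGB", (640, 480))                            # stand-in image
pixel_values = tower.image_processor.preprocess(image, return_tensors="pt")["pixel_values"]
print(pixel_values.shape)                        # torch.Size([1, 3, 384, 384])
print(tower.num_patches, tower.hidden_size)      # 729 1152
# After tower.load_model(), tower(pixel_values) yields features of shape (1, 729, 1152).
```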
/cerule/model/multimodal_projector/builder.py:
--------------------------------------------------------------------------------
1 | import re
2 | import math
3 | from torch import nn
4 | from functools import partial
5 | from timm.layers.norm_act import LayerNormAct2d
6 | from torchvision.ops.misc import SqueezeExcitation as SELayer
7 | from torchvision.models.mobilenetv3 import InvertedResidual, InvertedResidualConfig
8 |
9 |
10 | class IdentityMap(nn.Module):
11 | def __init__(self):
12 | super().__init__()
13 |
14 | def forward(self, x, *args, **kwargs):
15 | return x
16 |
17 | @property
18 | def config(self):
19 | return {"mm_projector_type": 'identity'}
20 |
21 |
22 | class Minigpt(nn.Module):
23 | def __init__(self, config=None):
24 | super(Minigpt, self).__init__()
25 |         # groups of 4 tokens are concatenated, so the linear layer maps mm_hidden_size * 4 -> hidden_size
26 | inc, ouc = config.mm_hidden_size, config.hidden_size
27 | self.linear = nn.Linear(inc * 4, ouc)
28 |
29 | def forward(self, x):
30 | # x is the input tensor with shape [b, num_tokens, c]
31 | b, num_tokens, c = x.shape
32 |
33 | # Check if num_tokens is divisible by 4
34 | if num_tokens % 4 != 0:
35 | raise ValueError("num_tokens must be divisible by 4")
36 |
37 | # Reshape x to [b, num_tokens/4, c*4]
38 | x = x.view(b, num_tokens // 4, c * 4)
39 |
40 | # Apply the linear transformation
41 | x = self.linear(x)
42 | return x
43 |
44 |
45 | class Vanilla(nn.Module):
46 | def __init__(self, config=None):
47 | super(Vanilla, self).__init__()
48 |         # groups of 4 tokens are concatenated, so the linear layer maps mm_hidden_size * 4 -> hidden_size
49 | inc, ouc = config.mm_hidden_size, config.hidden_size
50 | self.linear = nn.Linear(inc * 4, ouc)
51 |
52 | def forward(self, x):
53 | b, num_tokens, c = x.shape
54 |
55 | # Check if num_tokens is divisible by 4
56 | if num_tokens % 4 != 0:
57 | raise ValueError("num_tokens must be divisible by 4")
58 |
59 | # First, reshape to [b, num_tokens//4, 4, c]
60 | x = x.view(b, num_tokens // 4, 4, c)
61 |
62 | # Then, permute to interleave the tokens
63 | x = x.permute(0, 1, 3, 2).contiguous()
64 |
65 | # Finally, reshape to [b, num_tokens//4, c*4] to interleave features of 4 tokens
66 | x = x.view(b, num_tokens // 4, c * 4)
67 |
68 | # Apply the linear transformation
69 | x = self.linear(x)
70 | return x
71 |
72 |
73 | class LDPBlock(nn.Module):
74 | # Lightweight Downsample Projector Block
75 |
76 | def __init__(self, config=None):
77 | super().__init__()
78 |
79 | inc, ouc = config.mm_hidden_size, config.hidden_size
80 | layer_norm = partial(LayerNormAct2d, act_layer=None)
81 | se_layer = partial(SELayer, scale_activation=nn.Hardsigmoid)
82 | self.mlp = nn.Sequential(
83 | nn.Identity(), nn.Linear(inc, ouc), nn.GELU(), nn.Linear(ouc, ouc)
84 | )
85 | self.mb_block = nn.Sequential(
86 | nn.Identity(),
87 | InvertedResidual(InvertedResidualConfig(ouc, 3, ouc, ouc, True, "HS", 1, 1, 1), layer_norm, se_layer),
88 | InvertedResidual(InvertedResidualConfig(ouc, 3, ouc, ouc, True, "HS", 2, 1, 1), layer_norm, se_layer)
89 | )
90 |
91 | def forward(self, x):
92 | b, num_tokens, c = x.shape
93 | h = int(math.sqrt(num_tokens))
94 | x = self.mlp(x)
95 | x = x.permute(0, 2, 1).reshape(b, -1, h, h)
96 | x = self.mb_block(x)
97 | x = x.flatten(2).permute(0, 2, 1)
98 | return x
99 |
100 |
101 | class LDPNetProjector(nn.Module):
102 |
103 | def __init__(self, config=None):
104 | super().__init__()
105 | self.model = LDPBlock(config)
106 |
107 | def forward(self, x):
108 | return self.model(x)
109 |
110 |
111 | class SPP(nn.Module):
112 |
113 | def __init__(self, config=None, projector_type='v1'):
114 | super().__init__()
115 |
116 | self.projector_type = projector_type
117 |
118 | inc, ouc = config.mm_hidden_size, config.hidden_size
119 | self.linear_0 = nn.Linear(inc, inc)
120 |
121 | self.linear_1 = nn.Linear(inc, ouc)
122 |
123 | self.pooling = nn.AvgPool2d(kernel_size=2)
124 |
125 | self.linear_2 = nn.Linear(ouc, ouc)
126 |
127 | def forward(self, x):
128 | b, num_tokens, c = x.shape
129 | h = int(math.sqrt(num_tokens))
130 | if 'v1' in self.projector_type:
131 | x = self.linear_1(x)
132 | x = x.permute(0, 2, 1).reshape(b, -1, h, h)
133 | x = self.pooling(x)
134 | x = x.flatten(2).permute(0, 2, 1)
135 | x = self.linear_2(x)
136 | elif 'v2' in self.projector_type:
137 | x = self.linear_1(x)
138 | x = self.linear_2(x)
139 | x = x.permute(0, 2, 1).reshape(b, -1, h, h)
140 | x = self.pooling(x)
141 | x = x.flatten(2).permute(0, 2, 1)
142 | elif 'v3' in self.projector_type:
143 | x = self.linear_0(x)
144 | x = x.permute(0, 2, 1).reshape(b, -1, h, h)
145 | x = self.pooling(x)
146 | x = x.flatten(2).permute(0, 2, 1)
147 | x = self.linear_1(x)
148 | x = self.linear_2(x)
149 | return x
150 |
151 |
152 | def build_vision_projector(config, delay_load=False, **kwargs):
153 | projector_type = getattr(config, 'mm_projector_type', 'mlp2x_gelu')
154 |
155 | if projector_type == 'linear':
156 | return nn.Linear(config.mm_hidden_size, config.hidden_size)
157 |
158 | elif projector_type.startswith('mlp'):
159 | mlp_gelu_match = re.match(r'^mlp(\d+)x_gelu$', projector_type)
160 | if mlp_gelu_match:
161 | mlp_depth = int(mlp_gelu_match.group(1))
162 | modules = [nn.Linear(config.mm_hidden_size, config.hidden_size)]
163 | for _ in range(1, mlp_depth):
164 | modules.append(nn.GELU())
165 | modules.append(nn.Linear(config.hidden_size, config.hidden_size))
166 | return nn.Sequential(*modules)
167 |
168 | elif projector_type.startswith('spp'):
169 | return SPP(config, projector_type)
170 |
171 | elif projector_type == 'ldp':
172 | return LDPNetProjector(config)
173 |
174 | elif projector_type == 'vanilla':
175 | return Vanilla(config)
176 |
177 | elif projector_type == 'minigpt':
178 | return Minigpt(config)
179 |
180 | elif projector_type == 'identity':
181 | return IdentityMap()
182 |
183 | raise ValueError(f'Unknown projector type: {projector_type}')
184 |
--------------------------------------------------------------------------------
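A sketch of the projector factory with its default `mlp2x_gelu` type, which maps vision features of width `mm_hidden_size` into the language model width `hidden_size` through a two-layer GELU MLP. The sizes are illustrative: 1152 matches the SigLip tower above, 2048 is a placeholder LM width.

```python
# Sketch: build and run the default two-layer GELU projector on dummy features.
import torch
from types import SimpleNamespace
from cerule.model.multimodal_projector.builder import build_vision_projector

cfg = SimpleNamespace(mm_hidden_size=1152, hidden_size=2048,      # illustrative widths
                      mm_projector_type="mlp2x_gelu")
projector = build_vision_projector(cfg)                           # Linear -> GELU -> Linear
features = torch.randn(1, 729, 1152)                              # one image, 729 patch tokens
print(projector(features).shape)                                  # torch.Size([1, 729, 2048])
```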
/cerule/serve/cli.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import torch
3 | import requests
4 |
5 | from PIL import Image
6 | from io import BytesIO
7 | from transformers import TextStreamer
8 |
9 | from cerule.constants import IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN
10 | from cerule.conversation import conv_templates, SeparatorStyle
11 | from cerule.model.builder import load_pretrained_model
12 | from cerule.util.utils import disable_torch_init
13 | from cerule.util.mm_utils import process_images, tokenizer_image_token, get_model_name_from_path, \
14 | KeywordsStoppingCriteria
15 |
16 |
17 | def load_image(image_file):
18 | if image_file.startswith('http://') or image_file.startswith('https://'):
19 | response = requests.get(image_file)
20 | image = Image.open(BytesIO(response.content)).convert('RGB')
21 | else:
22 | image = Image.open(image_file).convert('RGB')
23 | return image
24 |
25 |
26 | def main(args):
27 | # Model
28 | disable_torch_init()
29 |
30 | model_name = get_model_name_from_path(args.model_path)
31 | tokenizer, model, image_processor, context_len = load_pretrained_model(args.model_path, args.model_base, model_name,
32 | args.model_type, args.load_8bit,
33 | args.load_4bit, device=args.device)
34 |
35 | conv_mode = "gemma_instruct"
36 |
37 | if args.conv_mode is not None and conv_mode != args.conv_mode:
38 |         print('[WARNING] the auto-inferred conversation mode is {}, while '
39 |               '`--conv-mode` is {}; using {}'.format(conv_mode,
40 |                                                      args.conv_mode,
41 |                                                      args.conv_mode))
42 | else:
43 | args.conv_mode = conv_mode
44 |
45 | conv = conv_templates[args.conv_mode].copy()
46 | roles = conv.roles
47 |
48 | image = load_image(args.image_file)
49 | # Similar operation in model_worker.py
50 | image_tensor = process_images([image], image_processor, model.config)
51 | if type(image_tensor) is list:
52 | image_tensor = [image.to(model.device, dtype=model.dtype) for image in image_tensor]
53 | else:
54 | image_tensor = image_tensor.to(model.device, dtype=model.dtype)
55 |
56 | while True:
57 | try:
58 | inp = input(f"{roles[0]}: ")
59 | except EOFError:
60 | inp = ""
61 | if not inp:
62 | print("exit...")
63 | break
64 |
65 | print(f"{roles[1]}: ", end="")
66 |
67 | if image is not None:
68 | # first message
69 | inp = DEFAULT_IMAGE_TOKEN + '\n' + inp
70 | conv.append_message(conv.roles[0], inp)
71 | image = None
72 | else:
73 | conv.append_message(conv.roles[0], inp)
74 | conv.append_message(conv.roles[1], None)
75 | prompt = conv.get_prompt()
76 |
77 | input_ids = tokenizer_image_token(prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors='pt').unsqueeze(0).to(
78 | model.device)
79 | stop_str = conv.sep if conv.sep_style != SeparatorStyle.TWO else conv.sep2
80 | keywords = [stop_str]
81 | stopping_criteria = KeywordsStoppingCriteria(keywords, tokenizer, input_ids)
82 | streamer = TextStreamer(tokenizer, skip_prompt=True, skip_special_tokens=False)
83 |
84 | with torch.inference_mode():
85 | output_ids = model.generate(
86 | input_ids,
87 | images=image_tensor,
88 | do_sample=True if args.temperature > 0 else False,
89 | temperature=args.temperature,
90 | max_new_tokens=args.max_new_tokens,
91 | streamer=streamer,
92 |                 use_cache=False,  # keep use_cache=False; use_cache=True currently raises an error with this model
93 | stopping_criteria=[stopping_criteria])
94 |
95 | outputs = tokenizer.decode(output_ids[0, input_ids.shape[1]:]).strip()
96 | conv.messages[-1][-1] = outputs
97 |
98 | if args.debug:
99 | print("\n", {"prompt": prompt, "outputs": outputs}, "\n")
100 |
101 |
102 | if __name__ == "__main__":
103 | parser = argparse.ArgumentParser()
104 | parser.add_argument("--model-path", type=str, default=None)
105 | parser.add_argument("--model-base", type=str, default=None)
106 | parser.add_argument("--model-type", type=str, default="gemma")
107 | parser.add_argument("--image-file", type=str, required=True)
108 | parser.add_argument("--device", type=str, default="cuda")
109 | parser.add_argument("--conv-mode", type=str, default=None)
110 | parser.add_argument("--temperature", type=float, default=0.2)
111 | parser.add_argument("--max-new-tokens", type=int, default=512)
112 | parser.add_argument("--load-8bit", action="store_true")
113 | parser.add_argument("--load-4bit", action="store_true")
114 | parser.add_argument("--debug", action="store_true")
115 | args = parser.parse_args()
116 | main(args)
117 |
--------------------------------------------------------------------------------
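For reference, the prompt the CLI assembles for the first turn can be reproduced without loading any weights; a minimal sketch using the same `gemma_instruct` template and image token the script selects above (the question text is illustrative):

```python
# Sketch: rebuild the first-turn prompt exactly as cli.py does.
from cerule.constants import DEFAULT_IMAGE_TOKEN
from cerule.conversation import conv_templates

conv = conv_templates["gemma_instruct"].copy()
conv.append_message(conv.roles[0], DEFAULT_IMAGE_TOKEN + "\n" + "What is shown in this image?")
conv.append_message(conv.roles[1], None)          # leave the assistant slot empty
print(conv.get_prompt())                          # the string passed to tokenizer_image_token
```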
/cerule/serve/controller.py:
--------------------------------------------------------------------------------
1 | """
2 | A controller manages distributed workers.
3 | It sends worker addresses to clients.
4 | """
5 | import argparse
6 | import dataclasses
7 | import threading
8 | import json
9 | import time
10 | import numpy as np
11 | import requests
12 | import uvicorn
13 |
14 | from typing import List
15 | from enum import Enum, auto
16 | from fastapi import FastAPI, Request
17 | from fastapi.responses import StreamingResponse
18 |
19 | from cerule.constants import CONTROLLER_HEART_BEAT_EXPIRATION
20 | from cerule.util.utils import build_logger, server_error_msg
21 |
22 | logger = build_logger("controller", "controller.log")
23 |
24 |
25 | class DispatchMethod(Enum):
26 | LOTTERY = auto()
27 | SHORTEST_QUEUE = auto()
28 |
29 | @classmethod
30 | def from_str(cls, name):
31 | if name == "lottery":
32 | return cls.LOTTERY
33 | elif name == "shortest_queue":
34 | return cls.SHORTEST_QUEUE
35 | else:
36 |             raise ValueError(f"Invalid dispatch method: {name}")
37 |
38 |
39 | @dataclasses.dataclass
40 | class WorkerInfo:
41 | model_names: List[str]
42 | speed: int
43 | queue_length: int
44 | check_heart_beat: bool
45 |     last_heart_beat: float
46 |
47 |
48 | def heart_beat_controller(controller):
49 | while True:
50 | time.sleep(CONTROLLER_HEART_BEAT_EXPIRATION)
51 |         controller.remove_stale_workers_by_expiration()
52 |
53 |
54 | class Controller:
55 | def __init__(self, dispatch_method: str):
56 | # Dict[str -> WorkerInfo]
57 | self.worker_info = {}
58 | self.dispatch_method = DispatchMethod.from_str(dispatch_method)
59 |
60 | self.heart_beat_thread = threading.Thread(
61 | target=heart_beat_controller, args=(self,))
62 | self.heart_beat_thread.start()
63 |
64 | logger.info("Init controller")
65 |
66 | def register_worker(self, worker_name: str, check_heart_beat: bool,
67 | worker_status: dict):
68 | if worker_name not in self.worker_info:
69 | logger.info(f"Register a new worker: {worker_name}")
70 | else:
71 | logger.info(f"Register an existing worker: {worker_name}")
72 |
73 | if not worker_status:
74 | worker_status = self.get_worker_status(worker_name)
75 | if not worker_status:
76 | return False
77 |
78 | self.worker_info[worker_name] = WorkerInfo(
79 | worker_status["model_names"], worker_status["speed"], worker_status["queue_length"],
80 | check_heart_beat, time.time())
81 |
82 | logger.info(f"Register done: {worker_name}, {worker_status}")
83 | return True
84 |
85 | def get_worker_status(self, worker_name: str):
86 | try:
87 | r = requests.post(worker_name + "/worker_get_status", timeout=5)
88 | except requests.exceptions.RequestException as e:
89 | logger.error(f"Get status fails: {worker_name}, {e}")
90 | return None
91 |
92 | if r.status_code != 200:
93 | logger.error(f"Get status fails: {worker_name}, {r}")
94 | return None
95 |
96 | return r.json()
97 |
98 | def remove_worker(self, worker_name: str):
99 | del self.worker_info[worker_name]
100 |
101 | def refresh_all_workers(self):
102 | old_info = dict(self.worker_info)
103 | self.worker_info = {}
104 |
105 | for w_name, w_info in old_info.items():
106 | if not self.register_worker(w_name, w_info.check_heart_beat, None):
107 | logger.info(f"Remove stale worker: {w_name}")
108 |
109 | def list_models(self):
110 | model_names = set()
111 |
112 | for w_name, w_info in self.worker_info.items():
113 | model_names.update(w_info.model_names)
114 |
115 | return list(model_names)
116 |
117 | def get_worker_address(self, model_name: str):
118 | if self.dispatch_method == DispatchMethod.LOTTERY:
119 | worker_names = []
120 | worker_speeds = []
121 | for w_name, w_info in self.worker_info.items():
122 | if model_name in w_info.model_names:
123 | worker_names.append(w_name)
124 | worker_speeds.append(w_info.speed)
125 | worker_speeds = np.array(worker_speeds, dtype=np.float32)
126 | norm = np.sum(worker_speeds)
127 | if norm < 1e-4:
128 | return ""
129 | worker_speeds = worker_speeds / norm
130 |
131 | pt = np.random.choice(np.arange(len(worker_names)), p=worker_speeds)
132 | worker_name = worker_names[pt]
133 | return worker_name
134 |
135 | elif self.dispatch_method == DispatchMethod.SHORTEST_QUEUE:
136 | worker_names = []
137 | worker_qlen = []
138 | for w_name, w_info in self.worker_info.items():
139 | if model_name in w_info.model_names:
140 | worker_names.append(w_name)
141 | worker_qlen.append(w_info.queue_length / w_info.speed)
142 | if len(worker_names) == 0:
143 | return ""
144 | min_index = np.argmin(worker_qlen)
145 | w_name = worker_names[min_index]
146 | self.worker_info[w_name].queue_length += 1
147 | logger.info(f"names: {worker_names}, queue_lens: {worker_qlen}, ret: {w_name}")
148 | return w_name
149 | else:
150 | raise ValueError(f"Invalid dispatch method: {self.dispatch_method}")
151 |
152 | def receive_heart_beat(self, worker_name: str, queue_length: int):
153 | if worker_name not in self.worker_info:
154 | logger.info(f"Receive unknown heart beat. {worker_name}")
155 | return False
156 |
157 | self.worker_info[worker_name].queue_length = queue_length
158 | self.worker_info[worker_name].last_heart_beat = time.time()
159 | logger.info(f"Receive heart beat. {worker_name}")
160 | return True
161 |
162 |     def remove_stale_workers_by_expiration(self):
163 | expire = time.time() - CONTROLLER_HEART_BEAT_EXPIRATION
164 | to_delete = []
165 | for worker_name, w_info in self.worker_info.items():
166 | if w_info.check_heart_beat and w_info.last_heart_beat < expire:
167 | to_delete.append(worker_name)
168 |
169 | for worker_name in to_delete:
170 | self.remove_worker(worker_name)
171 |
172 | def worker_api_generate_stream(self, params):
173 | worker_addr = self.get_worker_address(params["model"])
174 | if not worker_addr:
175 | logger.info(f"no worker: {params['model']}")
176 | ret = {
177 | "text": server_error_msg,
178 | "error_code": 2,
179 | }
180 |             yield json.dumps(ret).encode() + b"\0"
181 |             return
182 | try:
183 | response = requests.post(worker_addr + "/worker_generate_stream",
184 | json=params, stream=True, timeout=5)
185 | for chunk in response.iter_lines(decode_unicode=False, delimiter=b"\0"):
186 | if chunk:
187 | yield chunk + b"\0"
188 | except requests.exceptions.RequestException as e:
189 | logger.info(f"worker timeout: {worker_addr}")
190 | ret = {
191 | "text": server_error_msg,
192 | "error_code": 3,
193 | }
194 | yield json.dumps(ret).encode() + b"\0"
195 |
196 | # Let the controller act as a worker to achieve hierarchical
197 | # management. This can be used to connect isolated sub networks.
198 | def worker_api_get_status(self):
199 | model_names = set()
200 | speed = 0
201 | queue_length = 0
202 |
203 | for w_name in self.worker_info:
204 | worker_status = self.get_worker_status(w_name)
205 | if worker_status is not None:
206 | model_names.update(worker_status["model_names"])
207 | speed += worker_status["speed"]
208 | queue_length += worker_status["queue_length"]
209 |
210 | return {
211 | "model_names": list(model_names),
212 | "speed": speed,
213 | "queue_length": queue_length,
214 | }
215 |
216 |
217 | app = FastAPI()
218 |
219 |
220 | @app.post("/register_worker")
221 | async def register_worker(request: Request):
222 | data = await request.json()
223 | controller.register_worker(
224 | data["worker_name"], data["check_heart_beat"],
225 | data.get("worker_status", None))
226 |
227 |
228 | @app.post("/refresh_all_workers")
229 | async def refresh_all_workers():
230 | models = controller.refresh_all_workers()
231 |
232 |
233 | @app.post("/list_models")
234 | async def list_models():
235 | models = controller.list_models()
236 | return {"models": models}
237 |
238 |
239 | @app.post("/get_worker_address")
240 | async def get_worker_address(request: Request):
241 | data = await request.json()
242 | addr = controller.get_worker_address(data["model"])
243 | return {"address": addr}
244 |
245 |
246 | @app.post("/receive_heart_beat")
247 | async def receive_heart_beat(request: Request):
248 | data = await request.json()
249 | exist = controller.receive_heart_beat(
250 | data["worker_name"], data["queue_length"])
251 | return {"exist": exist}
252 |
253 |
254 | @app.post("/worker_generate_stream")
255 | async def worker_api_generate_stream(request: Request):
256 | params = await request.json()
257 | generator = controller.worker_api_generate_stream(params)
258 | return StreamingResponse(generator)
259 |
260 |
261 | @app.post("/worker_get_status")
262 | async def worker_api_get_status(request: Request):
263 | return controller.worker_api_get_status()
264 |
265 |
266 | if __name__ == "__main__":
267 | parser = argparse.ArgumentParser()
268 | parser.add_argument("--host", type=str, default="localhost")
269 | parser.add_argument("--port", type=int, default=21001)
270 | parser.add_argument("--dispatch-method", type=str, choices=["lottery", "shortest_queue"], default="shortest_queue")
271 | args = parser.parse_args()
272 | logger.info(f"args: {args}")
273 |
274 | controller = Controller(args.dispatch_method)
275 | log_config = uvicorn.config.LOGGING_CONFIG
276 | log_config['handlers']['default']['stream'] = 'ext://sys.stdout'
277 | uvicorn.run(app, host=args.host, port=args.port, log_level="info")
278 |
--------------------------------------------------------------------------------
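The controller exposes a small HTTP API (the FastAPI routes above). A sketch of how a client such as the Gradio web server talks to it, assuming a controller is already running on the default `localhost:21001`:

```python
# Sketch: query a running controller over its HTTP API.
import requests

CONTROLLER_URL = "http://localhost:21001"   # default --host/--port from the argparse flags above

requests.post(CONTROLLER_URL + "/refresh_all_workers")            # re-poll registered workers
models = requests.post(CONTROLLER_URL + "/list_models").json()["models"]
print(models)

if models:
    # Resolve a worker address for the first model; an empty string means no worker is available.
    addr = requests.post(CONTROLLER_URL + "/get_worker_address",
                         json={"model": models[0]}).json()["address"]
    print(addr)
```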
/cerule/serve/examples/example_1.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/tensoic/Cerule/29ae3aa0d55fd76f8db2e787f9e23b6ceb9d756c/cerule/serve/examples/example_1.png
--------------------------------------------------------------------------------
/cerule/serve/examples/example_2.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/tensoic/Cerule/29ae3aa0d55fd76f8db2e787f9e23b6ceb9d756c/cerule/serve/examples/example_2.png
--------------------------------------------------------------------------------
/cerule/serve/gradio_web_server.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import datetime
3 | import json
4 | import os
5 | import time
6 | import gradio as gr
7 | import requests
8 | import hashlib
9 | import pypandoc
10 | import base64
11 |
12 | from io import BytesIO
13 |
14 | from cerule.conversation import (default_conversation, conv_templates, SeparatorStyle)
15 | from cerule.constants import LOGDIR
16 | from cerule.util.utils import (build_logger, server_error_msg, violates_moderation, moderation_msg)
17 |
18 | logger = build_logger("gradio_web_server", "gradio_web_server.log")
19 |
20 | headers = {"User-Agent": "cerule Client"}
21 |
22 | no_change_btn = gr.update()
23 | enable_btn = gr.update(interactive=True)
24 | disable_btn = gr.update(interactive=False)
25 |
26 | priority = {
27 | "cerule": "aaaaaaa",
28 | }
29 |
30 |
31 | def get_conv_log_filename():
32 | t = datetime.datetime.now()
33 | name = os.path.join(LOGDIR, f"{t.year}-{t.month:02d}-{t.day:02d}-conv.json")
34 | return name
35 |
36 |
37 | def get_model_list():
38 | ret = requests.post(args.controller_url + "/refresh_all_workers")
39 | assert ret.status_code == 200
40 | ret = requests.post(args.controller_url + "/list_models")
41 | models = ret.json()["models"]
42 | models.sort(key=lambda x: priority.get(x, x))
43 | logger.info(f"Models: {models}")
44 | return models
45 |
46 |
47 | get_window_url_params = """
48 | function() {
49 | const params = new URLSearchParams(window.location.search);
50 | url_params = Object.fromEntries(params);
51 | console.log(url_params);
52 | return url_params;
53 | }
54 | """
55 |
56 |
57 | def load_demo(url_params, request: gr.Request):
58 | logger.info(f"load_demo. ip: {request.client.host}. params: {url_params}")
59 |
60 | dropdown_update = gr.update(visible=True)
61 | if "model" in url_params:
62 | model = url_params["model"]
63 | if model in models:
64 | dropdown_update = gr.update(
65 | value=model, visible=True)
66 |
67 | state = default_conversation.copy()
68 | return state, dropdown_update
69 |
70 |
71 | def load_demo_refresh_model_list(request: gr.Request):
72 | logger.info(f"load_demo. ip: {request.client.host}")
73 | models = get_model_list()
74 | state = default_conversation.copy()
75 | dropdown_update = gr.update(
76 | choices=models,
77 | value=models[0] if len(models) > 0 else ""
78 | )
79 | return state, dropdown_update
80 |
81 |
82 | def vote_last_response(state, vote_type, model_selector, request: gr.Request):
83 | with open(get_conv_log_filename(), "a") as fout:
84 | data = {
85 | "tstamp": round(time.time(), 4),
86 | "type": vote_type,
87 | "model": model_selector,
88 | "state": state.dict(),
89 | "ip": request.client.host,
90 | }
91 | fout.write(json.dumps(data) + "\n")
92 |
93 |
94 | def upvote_last_response(state, model_selector, request: gr.Request):
95 | logger.info(f"upvote. ip: {request.client.host}")
96 | vote_last_response(state, "upvote", model_selector, request)
97 | return ("",) + (disable_btn,) * 3
98 |
99 |
100 | def downvote_last_response(state, model_selector, request: gr.Request):
101 | logger.info(f"downvote. ip: {request.client.host}")
102 | vote_last_response(state, "downvote", model_selector, request)
103 | return ("",) + (disable_btn,) * 3
104 |
105 |
106 | def flag_last_response(state, model_selector, request: gr.Request):
107 | logger.info(f"flag. ip: {request.client.host}")
108 | vote_last_response(state, "flag", model_selector, request)
109 | return ("",) + (disable_btn,) * 3
110 |
111 |
112 | def regenerate(state, image_process_mode, request: gr.Request):
113 | logger.info(f"regenerate. ip: {request.client.host}")
114 | state.messages[-1][-1] = None
115 | prev_human_msg = state.messages[-2]
116 | if type(prev_human_msg[1]) in (tuple, list):
117 | prev_human_msg[1] = (*prev_human_msg[1][:2], image_process_mode)
118 | state.skip_next = False
119 | return (state, state.to_gradio_chatbot(), "", None) + (disable_btn,) * 3
120 |
121 |
122 | def clear_history(request: gr.Request):
123 | logger.info(f"clear_history. ip: {request.client.host}")
124 | state = default_conversation.copy()
125 | return (state, state.to_gradio_chatbot(), "", None) + (disable_btn,) * 3
126 |
127 |
128 | def save_conversation(conversation):
129 |     print("save_conversation is called")
130 | html_content = ""
131 |
132 | for role, message in conversation.messages:
133 | if isinstance(message, str): # only text
134 | html_content += f"