├── LICENSE ├── README.md ├── cerule ├── constants.py ├── conversation.py ├── model │ ├── __init__.py │ ├── builder.py │ ├── cerule_arch.py │ ├── language_model │ │ ├── cerule_gemma.py │ │ └── gemma │ │ │ ├── __init__.py │ │ │ ├── configuration_gemma.py │ │ │ └── modeling_gemma.py │ ├── multimodal_encoder │ │ ├── builder.py │ │ ├── clip │ │ │ └── clip_encoder.py │ │ ├── eva_clip │ │ │ ├── eva_clip_encoder.py │ │ │ ├── eva_clip_processors.py │ │ │ └── eva_vit.py │ │ └── siglip │ │ │ └── siglip_encoder.py │ └── multimodal_projector │ │ └── builder.py ├── serve │ ├── cli.py │ ├── controller.py │ ├── examples │ │ ├── example_1.png │ │ └── example_2.png │ ├── gradio_web_server.py │ ├── model_worker.py │ └── register_worker.py ├── train │ ├── cerule_trainer.py │ └── train.py └── util │ ├── data_utils.py │ ├── mm_utils.py │ └── utils.py ├── examples ├── YHyRn8r.png ├── astronaut.png ├── bridge.png ├── design.jpg ├── extreme_ironing.jpg ├── google.png ├── graph.jpg ├── graph1.jpg ├── image.png ├── mario.png └── sting.png ├── pyproject.toml └── script ├── deepspeed ├── scripts │ └── zero3_offload.json ├── zero2.json └── zero3.json ├── merge_lora_weights.py └── train ├── finetune_full.sh ├── finetune_lora.sh └── pretrain.sh /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. 
For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. 
You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. 
You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 
202 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | --- 2 | license: apache-2.0 3 | language: 4 | - en 5 | pipeline_tag: image-text-to-text 6 | --- 7 | 8 | # Cerule - A Tiny Mighty Vision Model 9 | ### Based on Google's Gemma-2b + SigLIP 10 | 11 | 12 | 13 | ``` 14 | ██████╗███████╗██████╗ ██╗ ██╗██╗ ███████╗ 15 | ██╔════╝██╔════╝██╔══██╗██║ ██║██║ ██╔════╝ 16 | ██║ █████╗ ██████╔╝██║ ██║██║ █████╗ 17 | ██║ ██╔══╝ ██╔══██╗██║ ██║██║ ██╔══╝ 18 | ╚██████╗███████╗██║ ██║╚██████╔╝███████╗███████╗ 19 | ╚═════╝╚══════╝╚═╝ ╚═╝ ╚═════╝ ╚══════╝╚══════╝ 20 | ``` 21 | 22 | 23 | 24 | 25 | 26 | We train and release "Cerule", a tiny yet powerful Vision Language Model built on Google's newly released [Gemma-2b](https://huggingface.co/google/gemma-2b) and [SigLIP](https://huggingface.co/google/siglip-so400m-patch14-384). 27 | 28 | 29 | Training ran on `4x A100 (80GB)` GPUs and took ~6 hours for pretraining and ~13 hours for finetuning. We adapt the training code from [Bunny](https://github.com/BAAI-DCAI/Bunny). 30 | 31 | --- 32 | | Image | Example | 33 | |-------|---------| 34 | | ![astronaut](examples/astronaut.png) | **Describe the image**
The image is a playful and surreal depiction of a man in a space suit, sitting on a chair and holding a green beer bottle. The man is wearing a white space suit, complete with a helmet and gloves. His feet are clad in black and white shoes, and he is placed on a sandy surface. The background features a large, blue planet, with a moon and a star visible in the sky. | 35 | | ![mario](examples/mario.png) | **Who are the characters in the image?**
The image features three characters, two of them are Mario and Luigi, and the third one is Yoshi.

**Describe the actions of the characters**
The Mario and Luigi characters are holding their arms out, as if they are waving. Yoshi is standing on its own, with its arms folded. | 36 | | ![extreme_ironing](examples/extreme_ironing.jpg) | **What's funny about this image?**
The image is quite humorous as it depicts a man ironing clothes on the back of a yellow taxi cab. This is not a typical sight you'd expect to see in everyday life. | 37 | --- 38 | 39 | 40 | ## Training 41 | 42 | Before running the training, install the following dependencies: 43 | 44 | * Create a conda env: 45 | ``` 46 | conda create -n cerule python=3.10 47 | conda activate cerule 48 | ``` 49 | * Basic requirements 50 | ``` 51 | pip install --upgrade pip 52 | pip install transformers 53 | pip install torch torchvision xformers --index-url https://download.pytorch.org/whl/cu118 54 | ``` 55 | 56 | * Install Apex. Build it from source; the `apex` package on PyPI is unrelated to NVIDIA Apex. 57 | ``` 58 | pip install ninja 59 | git clone https://github.com/NVIDIA/apex 60 | cd apex 61 | pip install -v --disable-pip-version-check --no-cache-dir --no-build-isolation --global-option="--cpp_ext" --global-option="--cuda_ext" ./ 62 | # https://github.com/Dao-AILab/flash-attention?tab=readme-ov-file#installation-and-features 63 | ``` 64 | * Install flash-attention 65 | ``` 66 | pip install packaging 67 | pip install flash-attn --no-build-isolation 68 | ``` 69 | * Install Cerule and other requirements 70 | ``` 71 | git clone https://github.com/Tensoic-AI/Cerule 72 | cd Cerule 73 | pip install -e . 74 | ``` 75 | 76 | ### Pretrain 77 | 78 | * Data preparation 79 | We use the pretraining dataset prepared by the amazing folks at the [Beijing Academy of Artificial Intelligence (BAAI)](https://huggingface.co/BAAI). 80 | The dataset is available [here](https://www.modelscope.cn/datasets/BoyaWu10/Bunny-v1.0-data). 81 | 82 | Pretrain dataset format: 83 | ``` 84 | { 85 | "conversations": [ 86 | { 87 | "from": "human", 88 | "value": "<image>\nProvide a brief description of the given image." 89 | }, 90 | { 91 | "from": "gpt", 92 | "value": "A set of three chrome and bubble glass table lamp bases. H.50cm - Image 4 of 10" 93 | } 94 | ], 95 | "id": "0006418798", 96 | "image": "0006418798.jpg" 97 | }, 98 | ``` 99 | 100 | * Run 101 | 102 | Update `--model_name_or_path` and `--vision_tower` to the paths of the LLM and vision encoder, respectively. Update `MODEL_TYPE` and `OUTPUT_DIR` accordingly. 103 | 104 | ```shell 105 | sh script/train/pretrain.sh 106 | ``` 107 | 108 | ### Visual Instruction Tuning 109 | 110 | * Data preparation 111 | 112 | For finetuning, we use Bunny-695K, a modified version of [SVIT-mix-665K](https://arxiv.org/abs/2307.04087) prepared by BAAI. 113 | The dataset is available [here](https://www.modelscope.cn/datasets/BoyaWu10/Bunny-v1.0-data). 114 | 115 | * Run 116 | 117 | Update `--model_name_or_path` and `--vision_tower` to the paths of the LLM and vision encoder, respectively. Update `MODEL_TYPE`, `PRETRAIN_DIR` and `OUTPUT_DIR` accordingly. The global batch size is 128. 118 | 119 | ```shell 120 | # full-parameter tuning 121 | sh script/train/finetune_full.sh 122 | 123 | # LoRA tuning 124 | sh script/train/finetune_lora.sh 125 | ``` 126 | 127 | 128 | ## Inference 129 | #### For CLI-based inference: 130 | ``` 131 | python3 -m cerule.serve.cli \ 132 | --model-path Tensoic/Cerule-v0.1 \ 133 | --image-file examples/astronaut.png 134 | ``` A Python-level inference sketch is also shown further below. 135 | 136 | ## License 137 | Licensed under the Apache License, Version 2.0 http://www.apache.org/licenses/LICENSE-2.0. This file may not be copied, modified, or distributed except according to those terms. 
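
#### Python inference (sketch)

The checkpoint can also be driven from Python. The snippet below is a minimal, untested sketch assembled only from APIs visible in this repository (`cerule.model.builder.load_pretrained_model`, `cerule.conversation.conv_templates`, and the image processor's `preprocess`); the `<image>` placeholder string, the hand-rolled image-token splicing, and the generation settings are assumptions, and `cerule.serve.cli` / `cerule.util.mm_utils` remain the authoritative path.

```python
import torch
from PIL import Image

from cerule.constants import IMAGE_TOKEN_INDEX
from cerule.conversation import conv_templates
from cerule.model.builder import load_pretrained_model

# Load the released checkpoint; model_type must be "gemma" for Cerule.
tokenizer, model, image_processor, context_len = load_pretrained_model(
    model_path="Tensoic/Cerule-v0.1",
    model_base=None,
    model_name="cerule-v0.1",
    model_type="gemma",
)

# Build a Gemma-style prompt with an image placeholder in the user turn.
conv = conv_templates["gemma_instruct"].copy()
conv.append_message(conv.roles[0], "<image>\nDescribe the image.")
conv.append_message(conv.roles[1], None)
prompt = conv.get_prompt()

# Splice IMAGE_TOKEN_INDEX into the token ids where "<image>" appears.
# (cerule.util.mm_utils presumably ships a helper for this; done by hand here.)
before, after = prompt.split("<image>")
ids = (tokenizer(before).input_ids
       + [IMAGE_TOKEN_INDEX]
       + tokenizer(after, add_special_tokens=False).input_ids)
input_ids = torch.tensor([ids], device=model.device)

# Preprocess the image with the processor returned by the builder.
image = Image.open("examples/astronaut.png").convert("RGB")
pixel_values = image_processor.preprocess(image, return_tensors="pt")["pixel_values"]
pixel_values = pixel_values.to(device=model.device, dtype=model.dtype)

with torch.inference_mode():
    output_ids = model.generate(input_ids, images=pixel_values, max_new_tokens=256)

# The prompt tokens (including the image placeholder id) are echoed back by
# generate, so decode only the newly generated portion.
new_tokens = output_ids[0, input_ids.shape[1]:]
print(tokenizer.decode(new_tokens, skip_special_tokens=True))
```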
138 | 139 | ## Contribution 140 | Unless you explicitly state otherwise, any contribution intentionally submitted for inclusion in the work by you, as defined in the Apache-2.0 license, shall be licensed as above, without any additional terms or conditions. 141 | 142 | ## Acknowledgements 143 | We sincerely thank the Amazing teams at Google, LLaVA, and BAAI without which this project would not have been possible! 144 | 145 | ## Star History 146 | 147 | 148 | 149 | 150 | 151 | Star History Chart 152 | 153 | 154 | 155 | 156 | 157 | 158 | 159 | 160 | -------------------------------------------------------------------------------- /cerule/constants.py: -------------------------------------------------------------------------------- 1 | # Model Constants 2 | IGNORE_INDEX = -100 3 | IMAGE_TOKEN_INDEX = -200 4 | DEFAULT_IMAGE_TOKEN = "" 5 | CONTROLLER_HEART_BEAT_EXPIRATION = 30 6 | LOGDIR = "gradio-logs" 7 | WORKER_HEART_BEAT_INTERVAL = 15 8 | IGNORE_INDEX = -100 9 | IMAGE_TOKEN_INDEX = -200 10 | DEFAULT_IMAGE_TOKEN = "" 11 | DEFAULT_IMAGE_PATCH_TOKEN = "" 12 | DEFAULT_IM_START_TOKEN = "" 13 | DEFAULT_IM_END_TOKEN = "" 14 | IMAGE_PLACEHOLDER = "" -------------------------------------------------------------------------------- /cerule/conversation.py: -------------------------------------------------------------------------------- 1 | import dataclasses 2 | from enum import auto, Enum 3 | from typing import List, Tuple 4 | import base64 5 | from io import BytesIO 6 | from PIL import Image 7 | 8 | 9 | class SeparatorStyle(Enum): 10 | """Different separator style.""" 11 | TWO = auto() 12 | PLAIN = auto() 13 | GEMMA = auto() 14 | 15 | 16 | @dataclasses.dataclass 17 | class Conversation: 18 | """A class that keeps all conversation history.""" 19 | system: str 20 | roles: List[str] 21 | messages: List[List[str]] 22 | offset: int 23 | sep_style: SeparatorStyle 24 | sep: str = "###" 25 | sep2: str = None 26 | version: str = "Unknown" 27 | 28 | skip_next: bool = False 29 | 30 | def get_prompt(self): 31 | messages = self.messages 32 | if len(messages) > 0 and type(messages[0][1]) is tuple: 33 | messages = self.messages.copy() 34 | init_role, init_msg = messages[0].copy() 35 | init_msg = init_msg[0].replace("", "").strip() 36 | if 'mmtag' in self.version: 37 | messages[0] = (init_role, init_msg) 38 | messages.insert(0, (self.roles[0], "")) 39 | messages.insert(1, (self.roles[1], "Received.")) 40 | else: 41 | messages[0] = (init_role, "\n" + init_msg) 42 | 43 | if self.sep_style == SeparatorStyle.TWO: 44 | seps = [self.sep, self.sep2] 45 | ret = self.system + seps[0] 46 | for i, (role, message) in enumerate(messages): 47 | if message: 48 | if type(message) is tuple: 49 | message, _, _ = message 50 | ret += role + ": " + message + seps[i % 2] 51 | else: 52 | ret += role + ":" 53 | 54 | elif self.sep_style == SeparatorStyle.PLAIN: 55 | seps = [self.sep, self.sep2] 56 | ret = self.system 57 | for i, (role, message) in enumerate(messages): 58 | if message: 59 | if type(message) is tuple: 60 | message, _, _ = message 61 | ret += message + seps[i % 2] 62 | else: 63 | ret += "" 64 | elif self.sep_style == SeparatorStyle.GEMMA: 65 | ret = "" 66 | for i, (role, message) in enumerate(messages): 67 | assert role == self.roles[i % 2], "Conversation should alternate user/assistant/user/assistant/..." 
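# Each turn is rendered as role header + message text + separator; tuple-valued
# messages are (text, image, image_process_mode) and only the text is emitted here.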
68 | if message: 69 | if type(message) is tuple: 70 | message, _, _ = message 71 | ret += role + message + self.sep 72 | else: 73 | ret += role 74 | else: 75 | raise ValueError(f"Invalid style: {self.sep_style}") 76 | 77 | return ret 78 | 79 | def append_message(self, role, message): 80 | self.messages.append([role, message]) 81 | 82 | def get_images(self, return_pil=False): 83 | images = [] 84 | for i, (role, msg) in enumerate(self.messages[self.offset:]): 85 | if i % 2 == 0: 86 | if type(msg) is tuple: 87 | import base64 88 | from io import BytesIO 89 | from PIL import Image 90 | msg, image, image_process_mode = msg 91 | if image_process_mode == "Pad": 92 | def expand2square(pil_img, background_color=(122, 116, 104)): 93 | width, height = pil_img.size 94 | if width == height: 95 | return pil_img 96 | elif width > height: 97 | result = Image.new(pil_img.mode, (width, width), background_color) 98 | result.paste(pil_img, (0, (width - height) // 2)) 99 | return result 100 | else: 101 | result = Image.new(pil_img.mode, (height, height), background_color) 102 | result.paste(pil_img, ((height - width) // 2, 0)) 103 | return result 104 | 105 | image = expand2square(image) 106 | elif image_process_mode in ["Default", "Crop"]: 107 | pass 108 | elif image_process_mode == "Resize": 109 | image = image.resize((336, 336)) 110 | else: 111 | raise ValueError(f"Invalid image_process_mode: {image_process_mode}") 112 | max_hw, min_hw = max(image.size), min(image.size) 113 | aspect_ratio = max_hw / min_hw 114 | max_len, min_len = 800, 400 115 | shortest_edge = int(min(max_len / aspect_ratio, min_len, min_hw)) 116 | longest_edge = int(shortest_edge * aspect_ratio) 117 | W, H = image.size 118 | if longest_edge != max(image.size): 119 | if H > W: 120 | H, W = longest_edge, shortest_edge 121 | else: 122 | H, W = shortest_edge, longest_edge 123 | image = image.resize((W, H)) 124 | if return_pil: 125 | images.append(image) 126 | else: 127 | buffered = BytesIO() 128 | image.save(buffered, format="PNG") 129 | img_b64_str = base64.b64encode(buffered.getvalue()).decode() 130 | images.append(img_b64_str) 131 | return images 132 | 133 | def to_gradio_chatbot(self): 134 | ret = [] 135 | for i, (role, msg) in enumerate(self.messages[self.offset:]): 136 | if i % 2 == 0: 137 | if type(msg) is tuple: 138 | import base64 139 | from io import BytesIO 140 | msg, image, image_process_mode = msg 141 | max_hw, min_hw = max(image.size), min(image.size) 142 | aspect_ratio = max_hw / min_hw 143 | max_len, min_len = 800, 400 144 | shortest_edge = int(min(max_len / aspect_ratio, min_len, min_hw)) 145 | longest_edge = int(shortest_edge * aspect_ratio) 146 | W, H = image.size 147 | if H > W: 148 | H, W = longest_edge, shortest_edge 149 | else: 150 | H, W = shortest_edge, longest_edge 151 | image = image.resize((W, H)) 152 | buffered = BytesIO() 153 | image.save(buffered, format="JPEG") 154 | img_b64_str = base64.b64encode(buffered.getvalue()).decode() 155 | img_str = f'user upload image' 156 | msg = img_str + msg.replace('', '').strip() 157 | ret.append([msg, None]) 158 | else: 159 | ret.append([msg, None]) 160 | else: 161 | ret[-1][-1] = msg 162 | return ret 163 | 164 | def copy(self): 165 | return Conversation( 166 | system=self.system, 167 | roles=self.roles, 168 | messages=[[x, y] for x, y in self.messages], 169 | offset=self.offset, 170 | sep_style=self.sep_style, 171 | sep=self.sep, 172 | sep2=self.sep2, 173 | version=self.version) 174 | 175 | def dict(self): 176 | if len(self.get_images()) > 0: 177 | return { 178 | "system": 
self.system, 179 | "roles": self.roles, 180 | "messages": [[x, y[0] if type(y) is tuple else y] for x, y in self.messages], 181 | "offset": self.offset, 182 | "sep": self.sep, 183 | "sep2": self.sep2, 184 | } 185 | return { 186 | "system": self.system, 187 | "roles": self.roles, 188 | "messages": self.messages, 189 | "offset": self.offset, 190 | "sep": self.sep, 191 | "sep2": self.sep2, 192 | } 193 | 194 | conv_plain = Conversation( 195 | system="", 196 | roles=("", ""), 197 | messages=( 198 | ), 199 | offset=0, 200 | sep_style=SeparatorStyle.PLAIN, 201 | sep="\n", 202 | ) 203 | 204 | conv_gemma_instruct = Conversation( 205 | system="", 206 | roles=("user\n", "model\n"), 207 | version="gemma", 208 | messages=(), 209 | offset=0, 210 | sep_style=SeparatorStyle.GEMMA, 211 | sep="\n" 212 | ) 213 | 214 | default_conversation = conv_gemma_instruct 215 | conv_templates = { 216 | "default": conv_gemma_instruct, 217 | "plain": conv_plain, 218 | "gemma_instruct": conv_gemma_instruct, 219 | } 220 | 221 | if __name__ == "__main__": 222 | print(default_conversation.get_prompt()) 223 | -------------------------------------------------------------------------------- /cerule/model/__init__.py: -------------------------------------------------------------------------------- 1 | from .language_model.cerule_gemma import CeruleGemmaForCausalLM, CeruleGemmaConfig 2 | -------------------------------------------------------------------------------- /cerule/model/builder.py: -------------------------------------------------------------------------------- 1 | import os 2 | import warnings 3 | 4 | warnings.filterwarnings("ignore") 5 | 6 | from transformers import AutoTokenizer, AutoConfig, BitsAndBytesConfig 7 | import torch 8 | from cerule.model import * 9 | 10 | # can add more models. just load, as done below for gemma 11 | 12 | def load_pretrained_model(model_path, model_base, model_name, model_type, load_8bit=False, load_4bit=False, 13 | device_map="auto", device="cuda", use_flash_attn=False, **kwargs): 14 | if model_type not in {'gemma'}: 15 | raise ValueError(f"Unknown Model Type {model_type}") 16 | 17 | kwargs = {"device_map": device_map, **kwargs} 18 | 19 | if device != "cuda": 20 | kwargs['device_map'] = {"": device} 21 | 22 | if load_8bit: 23 | kwargs['load_in_8bit'] = True 24 | elif load_4bit: 25 | kwargs['load_in_4bit'] = True 26 | kwargs['quantization_config'] = BitsAndBytesConfig( 27 | load_in_4bit=True, 28 | bnb_4bit_compute_dtype=torch.float16, 29 | bnb_4bit_use_double_quant=True, 30 | bnb_4bit_quant_type='nf4' 31 | ) 32 | else: 33 | kwargs['torch_dtype'] = torch.float16 34 | 35 | # Load Cerule model 36 | if 'lora' in model_name.lower() and model_base is None: 37 | warnings.warn( 38 | 'There is `lora` in model name but no `model_base` is provided. 
If you are loading a LoRA model, please provide the `model_base` argument.') 39 | if 'lora' in model_name.lower() and model_base is not None: 40 | lora_cfg_pretrained = AutoConfig.from_pretrained(model_path) 41 | 42 | print('Loading Cerule from base model...') 43 | if model_type == 'gemma': 44 | tokenizer = AutoTokenizer.from_pretrained(model_base, use_fast=False, trust_remote_code=True) 45 | model = CeruleGemmaForCausalLM.from_pretrained(model_base, low_cpu_mem_usage=True, 46 | config=lora_cfg_pretrained, **kwargs) 47 | 48 | token_num, tokem_dim = model.lm_head.out_features, model.lm_head.in_features 49 | if model.lm_head.weight.shape[0] != token_num: 50 | model.lm_head.weight = torch.nn.Parameter( 51 | torch.empty(token_num, tokem_dim, device=model.device, dtype=model.dtype)) 52 | model.model.embed_tokens.weight = torch.nn.Parameter( 53 | torch.empty(token_num, tokem_dim, device=model.device, dtype=model.dtype)) 54 | 55 | print('Loading additional Cerule weights...') 56 | if os.path.exists(os.path.join(model_path, 'non_lora_trainables.bin')): 57 | non_lora_trainables = torch.load(os.path.join(model_path, 'non_lora_trainables.bin'), map_location='cpu') 58 | else: 59 | # this is probably from HF Hub 60 | from huggingface_hub import hf_hub_download 61 | def load_from_hf(repo_id, filename, subfolder=None): 62 | cache_file = hf_hub_download( 63 | repo_id=repo_id, 64 | filename=filename, 65 | subfolder=subfolder) 66 | return torch.load(cache_file, map_location='cpu') 67 | 68 | non_lora_trainables = load_from_hf(model_path, 'non_lora_trainables.bin') 69 | 70 | non_lora_trainables = {(k[11:] if k.startswith('base_model.') else k): v for k, v in 71 | non_lora_trainables.items()} 72 | if any(k.startswith('model.model.') for k in non_lora_trainables): 73 | non_lora_trainables = {(k[6:] if k.startswith('model.') else k): v for k, v in 74 | non_lora_trainables.items()} 75 | model.load_state_dict(non_lora_trainables, strict=False) 76 | 77 | from peft import PeftModel 78 | print('Loading LoRA weights...') 79 | model = PeftModel.from_pretrained(model, model_path) 80 | print('Merging LoRA weights...') 81 | model = model.merge_and_unload() 82 | print('Model is loaded...') 83 | elif model_base is not None: 84 | print('Loading Cerule from base model...') 85 | 86 | cfg_pretrained = AutoConfig.from_pretrained(model_path) 87 | if model_type == 'gemma': 88 | tokenizer = AutoTokenizer.from_pretrained(model_base, use_fast=False, trust_remote_code=True) 89 | model = CeruleGemmaForCausalLM.from_pretrained(model_base, low_cpu_mem_usage=True, 90 | config=cfg_pretrained, **kwargs) 91 | 92 | mm_projector_weights = torch.load(os.path.join(model_path, 'mm_projector.bin'), map_location='cpu') 93 | mm_projector_weights = {k: v.to(torch.float16) for k, v in mm_projector_weights.items()} 94 | model.load_state_dict(mm_projector_weights, strict=False) 95 | else: 96 | if model_type == 'gemma': 97 | tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=False) 98 | model = CeruleGemmaForCausalLM.from_pretrained(model_path, low_cpu_mem_usage=True, **kwargs) 99 | 100 | model.resize_token_embeddings(len(tokenizer)) 101 | 102 | vision_tower = model.get_vision_tower() 103 | if not vision_tower.is_loaded: 104 | vision_tower.load_model() 105 | vision_tower.to(device=device, dtype=torch.float16) 106 | image_processor = vision_tower.image_processor 107 | 108 | if hasattr(model.config, "max_sequence_length"): 109 | context_len = model.config.max_sequence_length 110 | else: 111 | context_len = 2048 112 | 113 | if 
model.config.pad_token_id is None: 114 | model.config.pad_token_id = model.config.eos_token_id 115 | 116 | return tokenizer, model, image_processor, context_len 117 | -------------------------------------------------------------------------------- /cerule/model/cerule_arch.py: -------------------------------------------------------------------------------- 1 | from abc import ABC, abstractmethod 2 | 3 | import torch 4 | 5 | from .multimodal_encoder.builder import build_vision_tower 6 | from .multimodal_projector.builder import build_vision_projector 7 | 8 | from cerule.constants import IGNORE_INDEX, IMAGE_TOKEN_INDEX 9 | 10 | 11 | class CeruleMetaModel: 12 | 13 | def __init__(self, config): 14 | super(CeruleMetaModel, self).__init__(config) 15 | 16 | if hasattr(config, "mm_vision_tower"): 17 | self.vision_tower = build_vision_tower(config, delay_load=True) 18 | self.mm_projector = build_vision_projector(config) 19 | 20 | def get_vision_tower(self): 21 | vision_tower = getattr(self, 'vision_tower', None) 22 | if type(vision_tower) is list: 23 | vision_tower = vision_tower[0] 24 | return vision_tower 25 | 26 | def initialize_vision_modules(self, model_args): 27 | vision_tower = model_args.vision_tower 28 | 29 | pretrain_mm_mlp_adapter = model_args.pretrain_mm_mlp_adapter 30 | 31 | self.config.mm_vision_tower = vision_tower 32 | 33 | if self.get_vision_tower() is None: 34 | vision_tower = build_vision_tower(model_args) 35 | self.vision_tower = vision_tower 36 | else: 37 | vision_tower = self.vision_tower 38 | vision_tower.load_model() 39 | 40 | self.config.use_mm_proj = True 41 | self.config.mm_projector_type = getattr(model_args, 'mm_projector_type') 42 | self.config.mm_hidden_size = vision_tower.hidden_size 43 | 44 | if getattr(self, 'mm_projector', None) is None: 45 | self.mm_projector = build_vision_projector(self.config) 46 | else: 47 | # In case it is frozen by LoRA 48 | for p in self.mm_projector.parameters(): 49 | p.requires_grad = True 50 | 51 | if pretrain_mm_mlp_adapter is not None: 52 | mm_projector_weights = torch.load(pretrain_mm_mlp_adapter, map_location='cpu') 53 | 54 | def get_w(weights, keyword): 55 | return {k.split(keyword + '.')[1]: v for k, v in weights.items() if keyword in k} 56 | 57 | self.mm_projector.load_state_dict(get_w(mm_projector_weights, 'mm_projector')) 58 | 59 | 60 | class CeruleMetaForCausalLM(ABC): 61 | 62 | @abstractmethod 63 | def get_model(self): 64 | pass 65 | 66 | def get_vision_tower(self): 67 | return self.get_model().get_vision_tower() 68 | 69 | def encode_images(self, images): 70 | image_features = self.get_model().get_vision_tower()(images) 71 | image_features = self.get_model().mm_projector(image_features) 72 | return image_features 73 | 74 | def prepare_inputs_labels_for_multimodal( 75 | self, input_ids, position_ids, attention_mask, past_key_values, labels, images 76 | ): 77 | vision_tower = self.get_vision_tower() 78 | if vision_tower is None or images is None or input_ids.shape[1] == 1: 79 | if past_key_values is not None and vision_tower is not None and images is not None and input_ids.shape[ 80 | 1] == 1: 81 | target_shape = past_key_values[-1][-1].shape[-2] + 1 82 | attention_mask = torch.cat((attention_mask, torch.ones( 83 | (attention_mask.shape[0], target_shape - attention_mask.shape[1]), 84 | dtype=attention_mask.dtype, 85 | device=attention_mask.device 86 | )), dim=1) 87 | position_ids = torch.sum(attention_mask, dim=1).unsqueeze(-1) - 1 88 | return input_ids, position_ids, attention_mask, past_key_values, None, labels 89 | 90 | if 
type(images) is list or images.ndim == 5: 91 | concat_images = torch.cat([image for image in images], dim=0) 92 | image_features = self.encode_images(concat_images) 93 | split_sizes = [image.shape[0] for image in images] 94 | image_features = torch.split(image_features, split_sizes, dim=0) 95 | image_features = [x.flatten(0, 1).to(self.device) for x in image_features] 96 | else: 97 | image_features = self.encode_images(images).to(self.device) 98 | 99 | # Let's just add dummy tensors if they do not exist, 100 | # it is a headache to deal with None all the time. 101 | # But it is not ideal, and if you have a better idea, 102 | # please open an issue / submit a PR, thanks. 103 | _labels = labels 104 | _position_ids = position_ids 105 | _attention_mask = attention_mask 106 | if attention_mask is None: 107 | attention_mask = torch.ones_like(input_ids, dtype=torch.bool) 108 | else: 109 | attention_mask = attention_mask.bool() 110 | if position_ids is None: 111 | position_ids = torch.arange(0, input_ids.shape[1], dtype=torch.long, device=input_ids.device) 112 | if labels is None: 113 | labels = torch.full_like(input_ids, IGNORE_INDEX) 114 | 115 | # remove the padding using attention_mask -- TODO: double check 116 | input_ids = [cur_input_ids[cur_attention_mask] for cur_input_ids, cur_attention_mask in 117 | zip(input_ids, attention_mask)] 118 | labels = [cur_labels[cur_attention_mask] for cur_labels, cur_attention_mask in zip(labels, attention_mask)] 119 | 120 | new_input_embeds = [] 121 | new_labels = [] 122 | cur_image_idx = 0 123 | for batch_idx, cur_input_ids in enumerate(input_ids): 124 | num_images = (cur_input_ids == IMAGE_TOKEN_INDEX).sum() 125 | if num_images == 0: 126 | cur_image_features = image_features[cur_image_idx] 127 | cur_input_embeds_1 = self.get_model().embed_tokens(cur_input_ids) 128 | cur_input_embeds = torch.cat([cur_input_embeds_1, cur_image_features[0:0]], dim=0) 129 | new_input_embeds.append(cur_input_embeds) 130 | new_labels.append(labels[batch_idx]) 131 | cur_image_idx += 1 132 | continue 133 | 134 | image_token_indices = [-1] + torch.where(cur_input_ids == IMAGE_TOKEN_INDEX)[0].tolist() + [ 135 | cur_input_ids.shape[0]] 136 | cur_input_ids_noim = [] 137 | cur_labels = labels[batch_idx] 138 | cur_labels_noim = [] 139 | for i in range(len(image_token_indices) - 1): 140 | cur_input_ids_noim.append(cur_input_ids[image_token_indices[i] + 1:image_token_indices[i + 1]]) 141 | cur_labels_noim.append(cur_labels[image_token_indices[i] + 1:image_token_indices[i + 1]]) 142 | split_sizes = [x.shape[0] for x in cur_labels_noim] 143 | cur_input_embeds = self.get_model().embed_tokens(torch.cat(cur_input_ids_noim)) 144 | cur_input_embeds_no_im = torch.split(cur_input_embeds, split_sizes, dim=0) 145 | cur_new_input_embeds = [] 146 | cur_new_labels = [] 147 | 148 | for i in range(num_images + 1): 149 | cur_new_input_embeds.append(cur_input_embeds_no_im[i]) 150 | cur_new_labels.append(cur_labels_noim[i]) 151 | if i < num_images: 152 | cur_image_features = image_features[cur_image_idx] 153 | cur_image_idx += 1 154 | cur_new_input_embeds.append(cur_image_features) 155 | cur_new_labels.append( 156 | torch.full((cur_image_features.shape[0],), IGNORE_INDEX, device=cur_labels.device, 157 | dtype=cur_labels.dtype)) 158 | 159 | cur_new_input_embeds = torch.cat(cur_new_input_embeds) 160 | cur_new_labels = torch.cat(cur_new_labels) 161 | 162 | new_input_embeds.append(cur_new_input_embeds) 163 | new_labels.append(cur_new_labels) 164 | 165 | # Truncate sequences to max length as image embeddings can 
make the sequence longer 166 | tokenizer_model_max_length = getattr(self.config, 'tokenizer_model_max_length', None) 167 | if tokenizer_model_max_length is not None: 168 | new_input_embeds = [x[:tokenizer_model_max_length] for x in new_input_embeds] 169 | new_labels = [x[:tokenizer_model_max_length] for x in new_labels] 170 | 171 | # Combine them 172 | max_len = max(x.shape[0] for x in new_input_embeds) 173 | batch_size = len(new_input_embeds) 174 | 175 | new_input_embeds_padded = [] 176 | new_labels_padded = torch.full((batch_size, max_len), IGNORE_INDEX, dtype=new_labels[0].dtype, 177 | device=new_labels[0].device) 178 | attention_mask = torch.zeros((batch_size, max_len), dtype=attention_mask.dtype, device=attention_mask.device) 179 | position_ids = torch.zeros((batch_size, max_len), dtype=position_ids.dtype, device=position_ids.device) 180 | 181 | for i, (cur_new_embed, cur_new_labels) in enumerate(zip(new_input_embeds, new_labels)): 182 | cur_len = cur_new_embed.shape[0] 183 | if getattr(self.config, 'tokenizer_padding_side', 'right') == "left": 184 | new_input_embeds_padded.append(torch.cat(( 185 | torch.zeros((max_len - cur_len, cur_new_embed.shape[1]), dtype=cur_new_embed.dtype, 186 | device=cur_new_embed.device), 187 | cur_new_embed 188 | ), dim=0)) 189 | if cur_len > 0: 190 | new_labels_padded[i, -cur_len:] = cur_new_labels 191 | attention_mask[i, -cur_len:] = True 192 | position_ids[i, -cur_len:] = torch.arange(0, cur_len, dtype=position_ids.dtype, 193 | device=position_ids.device) 194 | else: 195 | new_input_embeds_padded.append(torch.cat(( 196 | cur_new_embed, 197 | torch.zeros((max_len - cur_len, cur_new_embed.shape[1]), dtype=cur_new_embed.dtype, 198 | device=cur_new_embed.device) 199 | ), dim=0)) 200 | if cur_len > 0: 201 | new_labels_padded[i, :cur_len] = cur_new_labels 202 | attention_mask[i, :cur_len] = True 203 | position_ids[i, :cur_len] = torch.arange(0, cur_len, dtype=position_ids.dtype, 204 | device=position_ids.device) 205 | 206 | new_input_embeds = torch.stack(new_input_embeds_padded, dim=0) 207 | 208 | if _labels is None: 209 | new_labels = None 210 | else: 211 | new_labels = new_labels_padded 212 | 213 | if _attention_mask is None: 214 | attention_mask = None 215 | else: 216 | attention_mask = attention_mask.to(dtype=_attention_mask.dtype) 217 | 218 | if _position_ids is None: 219 | position_ids = None 220 | 221 | return None, position_ids, attention_mask, past_key_values, new_input_embeds, new_labels 222 | -------------------------------------------------------------------------------- /cerule/model/language_model/cerule_gemma.py: -------------------------------------------------------------------------------- 1 | from typing import List, Optional, Tuple, Union 2 | 3 | import torch 4 | import torch.nn as nn 5 | from transformers import AutoConfig, AutoModelForCausalLM, GemmaConfig, GemmaModel, GemmaForCausalLM 6 | 7 | from transformers.modeling_outputs import CausalLMOutputWithPast 8 | 9 | from ..cerule_arch import CeruleMetaModel, CeruleMetaForCausalLM 10 | 11 | 12 | class CeruleGemmaConfig(GemmaConfig): 13 | model_type = "cerule-gemma" 14 | 15 | 16 | class CeruleGemmaModel(CeruleMetaModel, GemmaModel): 17 | config_class = CeruleGemmaConfig 18 | 19 | def __init__(self, config: GemmaConfig): 20 | super(CeruleGemmaModel, self).__init__(config) 21 | 22 | 23 | class CeruleGemmaForCausalLM(GemmaForCausalLM, CeruleMetaForCausalLM): 24 | config_class = CeruleGemmaConfig 25 | 26 | def __init__(self, config): 27 | super(GemmaForCausalLM, self).__init__(config) 28 | 
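# CeruleGemmaModel mixes CeruleMetaModel into GemmaModel, so the language model also
# owns the vision tower and mm_projector that cerule_arch.py builds from the config.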
self.model = CeruleGemmaModel(config) 29 | self.vocab_size = config.vocab_size 30 | self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) 31 | 32 | # Initialize weights and apply final processing 33 | self.post_init() 34 | 35 | def get_model(self): 36 | return self.model 37 | 38 | def forward( 39 | self, 40 | input_ids: torch.LongTensor = None, 41 | attention_mask: Optional[torch.Tensor] = None, 42 | position_ids: Optional[torch.LongTensor] = None, 43 | past_key_values: Optional[List[torch.FloatTensor]] = None, 44 | inputs_embeds: Optional[torch.FloatTensor] = None, 45 | labels: Optional[torch.LongTensor] = None, 46 | use_cache: Optional[bool] = None, 47 | output_attentions: Optional[bool] = None, 48 | output_hidden_states: Optional[bool] = None, 49 | images: Optional[torch.FloatTensor] = None, 50 | return_dict: Optional[bool] = None, 51 | cache_position=None, 52 | ) -> Union[Tuple, CausalLMOutputWithPast]: 53 | 54 | if inputs_embeds is None: 55 | ( 56 | input_ids, 57 | position_ids, 58 | attention_mask, 59 | past_key_values, 60 | inputs_embeds, 61 | labels 62 | ) = self.prepare_inputs_labels_for_multimodal( 63 | input_ids, 64 | position_ids, 65 | attention_mask, 66 | past_key_values, 67 | labels, 68 | images 69 | ) 70 | 71 | return super().forward( 72 | input_ids=input_ids, 73 | attention_mask=attention_mask, 74 | position_ids=position_ids, 75 | past_key_values=past_key_values, 76 | inputs_embeds=inputs_embeds, 77 | labels=labels, 78 | use_cache=use_cache, 79 | output_attentions=output_attentions, 80 | output_hidden_states=output_hidden_states, 81 | return_dict=return_dict 82 | ) 83 | 84 | def prepare_inputs_for_generation(self, input_ids, past_key_values=None, inputs_embeds=None, attention_mask=None, 85 | **kwargs): 86 | images = kwargs.pop("images", None) 87 | 88 | _inputs = super().prepare_inputs_for_generation( 89 | input_ids, past_key_values=past_key_values, inputs_embeds=inputs_embeds, attention_mask=attention_mask, 90 | **kwargs 91 | ) 92 | 93 | if images is not None: 94 | _inputs['images'] = images 95 | return _inputs 96 | 97 | 98 | AutoConfig.register("cerule-gemma", CeruleGemmaConfig) 99 | AutoModelForCausalLM.register(CeruleGemmaConfig, CeruleGemmaForCausalLM) 100 | -------------------------------------------------------------------------------- /cerule/model/language_model/gemma/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 The HuggingFace Inc. team. All rights reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
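# NOTE: this package mirrors transformers' lazy-import pattern: the names registered in
# _import_structure below are only imported on first access via _LazyModule, so optional
# backends (sentencepiece, tokenizers, torch, flax) are not required at import time.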
14 | from typing import TYPE_CHECKING 15 | 16 | from transformers.utils import ( 17 | OptionalDependencyNotAvailable, 18 | _LazyModule, 19 | is_flax_available, 20 | is_sentencepiece_available, 21 | is_tokenizers_available, 22 | is_torch_available, 23 | ) 24 | 25 | 26 | _import_structure = { 27 | "configuration_gemma": ["GEMMA_PRETRAINED_CONFIG_ARCHIVE_MAP", "GemmaConfig"], 28 | } 29 | 30 | try: 31 | if not is_sentencepiece_available(): 32 | raise OptionalDependencyNotAvailable() 33 | except OptionalDependencyNotAvailable: 34 | pass 35 | else: 36 | _import_structure["tokenization_gemma"] = ["GemmaTokenizer"] 37 | 38 | try: 39 | if not is_tokenizers_available(): 40 | raise OptionalDependencyNotAvailable() 41 | except OptionalDependencyNotAvailable: 42 | pass 43 | else: 44 | _import_structure["tokenization_gemma_fast"] = ["GemmaTokenizerFast"] 45 | 46 | 47 | try: 48 | if not is_torch_available(): 49 | raise OptionalDependencyNotAvailable() 50 | except OptionalDependencyNotAvailable: 51 | pass 52 | else: 53 | _import_structure["modeling_gemma"] = [ 54 | "GemmaForCausalLM", 55 | "GemmaModel", 56 | "GemmaPreTrainedModel", 57 | "GemmaForSequenceClassification", 58 | ] 59 | 60 | try: 61 | if not is_flax_available(): 62 | raise OptionalDependencyNotAvailable() 63 | except OptionalDependencyNotAvailable: 64 | pass 65 | else: 66 | _import_structure["modeling_flax_gemma"] = [ 67 | "FlaxGemmaForCausalLM", 68 | "FlaxGemmaModel", 69 | "FlaxGemmaPreTrainedModel", 70 | ] 71 | 72 | 73 | if TYPE_CHECKING: 74 | from .configuration_gemma import GEMMA_PRETRAINED_CONFIG_ARCHIVE_MAP, GemmaConfig 75 | 76 | try: 77 | if not is_sentencepiece_available(): 78 | raise OptionalDependencyNotAvailable() 79 | except OptionalDependencyNotAvailable: 80 | pass 81 | else: 82 | from .tokenization_gemma import GemmaTokenizer 83 | 84 | try: 85 | if not is_tokenizers_available(): 86 | raise OptionalDependencyNotAvailable() 87 | except OptionalDependencyNotAvailable: 88 | pass 89 | else: 90 | from .tokenization_gemma_fast import GemmaTokenizerFast 91 | 92 | try: 93 | if not is_torch_available(): 94 | raise OptionalDependencyNotAvailable() 95 | except OptionalDependencyNotAvailable: 96 | pass 97 | else: 98 | from .modeling_gemma import ( 99 | GemmaForCausalLM, 100 | GemmaForSequenceClassification, 101 | GemmaModel, 102 | GemmaPreTrainedModel, 103 | ) 104 | 105 | try: 106 | if not is_flax_available(): 107 | raise OptionalDependencyNotAvailable() 108 | except OptionalDependencyNotAvailable: 109 | pass 110 | else: 111 | from .modeling_flax_gemma import ( 112 | FlaxGemmaForCausalLM, 113 | FlaxGemmaModel, 114 | FlaxGemmaPreTrainedModel, 115 | ) 116 | 117 | 118 | else: 119 | import sys 120 | 121 | sys.modules[__name__] = _LazyModule(__name__, globals()["__file__"], _import_structure, module_spec=__spec__) -------------------------------------------------------------------------------- /cerule/model/language_model/gemma/configuration_gemma.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2024 The HuggingFace Inc. team. All rights reserved. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 
6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | """ Gemma model configuration""" 16 | 17 | from transformers.configuration_utils import PretrainedConfig 18 | from transformers.utils import logging 19 | 20 | 21 | logger = logging.get_logger(__name__) 22 | 23 | GEMMA_PRETRAINED_CONFIG_ARCHIVE_MAP = {} 24 | 25 | 26 | class GemmaConfig(PretrainedConfig): 27 | r""" 28 | This is the configuration class to store the configuration of a [`GemmaModel`]. It is used to instantiate an Gemma 29 | model according to the specified arguments, defining the model architecture. Instantiating a configuration with the 30 | defaults will yield a similar configuration to that of the Gemma-7B. 31 | 32 | e.g. [google/gemma-7b](https://huggingface.co/google/gemma-7b) 33 | 34 | Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the 35 | documentation from [`PretrainedConfig`] for more information. 36 | 37 | 38 | Args: 39 | vocab_size (`int`, *optional*, defaults to 256000): 40 | Vocabulary size of the Gemma model. Defines the number of different tokens that can be represented by the 41 | `inputs_ids` passed when calling [`GemmaModel`] 42 | hidden_size (`int`, *optional*, defaults to 3072): 43 | Dimension of the hidden representations. 44 | intermediate_size (`int`, *optional*, defaults to 24576): 45 | Dimension of the MLP representations. 46 | num_hidden_layers (`int`, *optional*, defaults to 28): 47 | Number of hidden layers in the Transformer decoder. 48 | num_attention_heads (`int`, *optional*, defaults to 16): 49 | Number of attention heads for each attention layer in the Transformer decoder. 50 | num_key_value_heads (`int`, *optional*, defaults to 16): 51 | This is the number of key_value heads that should be used to implement Grouped Query Attention. If 52 | `num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA), if 53 | `num_key_value_heads=1 the model will use Multi Query Attention (MQA) otherwise GQA is used. When 54 | converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed 55 | by meanpooling all the original heads within that group. For more details checkout [this 56 | paper](https://arxiv.org/pdf/2305.13245.pdf). If it is not specified, will default to 57 | `num_attention_heads`. 58 | head_dim (`int`, *optional*, defaults to 256): 59 | The attention head dimension. 60 | hidden_act (`str` or `function`, *optional*, defaults to `"gelu"`): 61 | The non-linear activation function (function or string) in the decoder. 62 | max_position_embeddings (`int`, *optional*, defaults to 8192): 63 | The maximum sequence length that this model might ever be used with. 64 | initializer_range (`float`, *optional*, defaults to 0.02): 65 | The standard deviation of the truncated_normal_initializer for initializing all weight matrices. 66 | rms_norm_eps (`float`, *optional*, defaults to 1e-06): 67 | The epsilon used by the rms normalization layers. 
68 | use_cache (`bool`, *optional*, defaults to `True`): 69 | Whether or not the model should return the last key/values attentions (not used by all models). Only 70 | relevant if `config.is_decoder=True`. 71 | pad_token_id (`int`, *optional*, defaults to 0): 72 | Padding token id. 73 | eos_token_id (`int`, *optional*, defaults to 1): 74 | End of stream token id. 75 | bos_token_id (`int`, *optional*, defaults to 2): 76 | Beginning of stream token id. 77 | tie_word_embeddings (`bool`, *optional*, defaults to `True`): 78 | Whether to tie weight embeddings 79 | rope_theta (`float`, *optional*, defaults to 10000.0): 80 | The base period of the RoPE embeddings. 81 | attention_bias (`bool`, defaults to `False`, *optional*, defaults to `False`): 82 | Whether to use a bias in the query, key, value and output projection layers during self-attention. 83 | attention_dropout (`float`, *optional*, defaults to 0.0): 84 | The dropout ratio for the attention probabilities. 85 | 86 | ```python 87 | >>> from transformers import GemmaModel, GemmaConfig 88 | 89 | >>> # Initializing a Gemma gemma-7b style configuration 90 | >>> configuration = GemmaConfig() 91 | 92 | >>> # Initializing a model from the gemma-7b style configuration 93 | >>> model = GemmaModel(configuration) 94 | 95 | >>> # Accessing the model configuration 96 | >>> configuration = model.config 97 | ```""" 98 | 99 | model_type = "gemma" 100 | keys_to_ignore_at_inference = ["past_key_values"] 101 | 102 | def __init__( 103 | self, 104 | vocab_size=256000, 105 | hidden_size=3072, 106 | intermediate_size=24576, 107 | num_hidden_layers=28, 108 | num_attention_heads=16, 109 | num_key_value_heads=16, 110 | head_dim=256, 111 | hidden_act="gelu", 112 | max_position_embeddings=8192, 113 | initializer_range=0.02, 114 | rms_norm_eps=1e-6, 115 | use_cache=True, 116 | pad_token_id=0, 117 | eos_token_id=1, 118 | bos_token_id=2, 119 | tie_word_embeddings=True, 120 | rope_theta=10000.0, 121 | attention_bias=False, 122 | attention_dropout=0.0, 123 | **kwargs, 124 | ): 125 | self.vocab_size = vocab_size 126 | self.max_position_embeddings = max_position_embeddings 127 | self.hidden_size = hidden_size 128 | self.intermediate_size = intermediate_size 129 | self.num_hidden_layers = num_hidden_layers 130 | self.num_attention_heads = num_attention_heads 131 | self.head_dim = head_dim 132 | self.num_key_value_heads = num_key_value_heads 133 | self.hidden_act = hidden_act 134 | self.initializer_range = initializer_range 135 | self.rms_norm_eps = rms_norm_eps 136 | self.use_cache = use_cache 137 | self.rope_theta = rope_theta 138 | self.attention_bias = attention_bias 139 | self.attention_dropout = attention_dropout 140 | 141 | super().__init__( 142 | pad_token_id=pad_token_id, 143 | bos_token_id=bos_token_id, 144 | eos_token_id=eos_token_id, 145 | tie_word_embeddings=tie_word_embeddings, 146 | **kwargs, 147 | ) -------------------------------------------------------------------------------- /cerule/model/multimodal_encoder/builder.py: -------------------------------------------------------------------------------- 1 | import os 2 | from .eva_clip.eva_clip_encoder import EvaClipVisionTower 3 | from .siglip.siglip_encoder import SigLipVisionTower 4 | from .clip.clip_encoder import CLIPVisionTower 5 | 6 | 7 | def build_vision_tower(vision_tower_cfg, **kwargs): 8 | vision_tower = getattr(vision_tower_cfg, 'mm_vision_tower', getattr(vision_tower_cfg, 'vision_tower', None)) 9 | 10 | if 'sig' in vision_tower.lower(): 11 | return SigLipVisionTower(vision_tower, 
vision_tower_cfg=vision_tower_cfg, **kwargs) 12 | 13 | elif 'eva' in vision_tower.lower(): 14 | return EvaClipVisionTower(vision_tower, args=vision_tower_cfg, **kwargs) 15 | 16 | elif 'clip' in vision_tower.lower(): 17 | return CLIPVisionTower(vision_tower, args=vision_tower_cfg, **kwargs) 18 | 19 | else: 20 | raise ValueError(f'Unknown vision tower: {vision_tower}') 21 | -------------------------------------------------------------------------------- /cerule/model/multimodal_encoder/clip/clip_encoder.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | from transformers import CLIPVisionModel, CLIPImageProcessor, CLIPVisionConfig 5 | 6 | 7 | class CLIPVisionTower(nn.Module): 8 | def __init__(self, vision_tower, args, delay_load=False): 9 | super().__init__() 10 | 11 | self.is_loaded = False 12 | 13 | self.vision_tower_name = vision_tower 14 | self.select_layer = -2 15 | 16 | if not delay_load: 17 | self.load_model() 18 | else: 19 | self.cfg_only = CLIPVisionConfig.from_pretrained(self.vision_tower_name) 20 | 21 | def load_model(self): 22 | self.image_processor = CLIPImageProcessor.from_pretrained(self.vision_tower_name) 23 | self.vision_tower = CLIPVisionModel.from_pretrained(self.vision_tower_name) 24 | self.vision_tower.requires_grad_(False) 25 | 26 | self.is_loaded = True 27 | 28 | def feature_select(self, image_forward_outs): 29 | image_features = image_forward_outs.hidden_states[self.select_layer] 30 | 31 | image_features = image_features[:, 1:] 32 | 33 | return image_features 34 | 35 | @torch.no_grad() 36 | def forward(self, images): 37 | if type(images) is list: 38 | image_features = [] 39 | for image in images: 40 | image_forward_out = self.vision_tower(image.to(device=self.device, dtype=self.dtype).unsqueeze(0), 41 | output_hidden_states=True) 42 | image_feature = self.feature_select(image_forward_out).to(image.dtype) 43 | image_features.append(image_feature) 44 | else: 45 | image_forward_outs = self.vision_tower(images.to(device=self.device, dtype=self.dtype), 46 | output_hidden_states=True) 47 | image_features = self.feature_select(image_forward_outs).to(images.dtype) 48 | 49 | return image_features 50 | 51 | @property 52 | def dummy_feature(self): 53 | return torch.zeros(1, self.hidden_size, device=self.device, dtype=self.dtype) 54 | 55 | @property 56 | def dtype(self): 57 | return self.vision_tower.dtype 58 | 59 | @property 60 | def device(self): 61 | return self.vision_tower.device 62 | 63 | @property 64 | def config(self): 65 | if self.is_loaded: 66 | return self.vision_tower.config 67 | else: 68 | return self.cfg_only 69 | 70 | @property 71 | def hidden_size(self): 72 | return self.config.hidden_size 73 | 74 | @property 75 | def num_patches(self): 76 | return (self.config.image_size // self.config.patch_size) ** 2 77 | -------------------------------------------------------------------------------- /cerule/model/multimodal_encoder/eva_clip/eva_clip_encoder.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | from .eva_clip_processors import EvaClipImageTrainProcessor 5 | from .eva_vit import Eva2LargePlusEncoder 6 | 7 | 8 | class EvaClipVisionTower(nn.Module): 9 | def __init__(self, vision_tower, args, delay_load=False): 10 | super().__init__() 11 | 12 | self.is_loaded = False 13 | 14 | self.vision_tower_path = vision_tower 15 | self.config = VisionTowerConfig() 16 | 17 | if not delay_load: 18 | 
self.load_model() 19 | else: 20 | self.cfg_only = self.config 21 | 22 | def load_model(self): 23 | self.image_processor = EvaClipImageTrainProcessor(self.config.image_size) 24 | self.vision_tower = Eva2LargePlusEncoder(self.vision_tower_path) 25 | self.vision_tower.requires_grad_(False) 26 | 27 | self.is_loaded = True 28 | 29 | @torch.no_grad() 30 | def forward(self, images): 31 | if type(images) is list: 32 | image_features = [] 33 | for image in images: 34 | image_feature = self.vision_tower(image.to(device=self.device, dtype=self.dtype).unsqueeze(0)).to( 35 | image.dtype) 36 | image_features.append(image_feature) 37 | else: 38 | image_features = self.vision_tower(images.to(device=self.device, dtype=self.dtype)).to(images.dtype) 39 | 40 | return image_features 41 | 42 | @property 43 | def dtype(self): 44 | return self.vision_tower.dtype 45 | 46 | @property 47 | def device(self): 48 | return self.vision_tower.device 49 | 50 | @property 51 | def hidden_size(self): 52 | return self.config.hidden_size 53 | 54 | @property 55 | def num_patches(self): 56 | return (self.config.image_size // self.config.patch_size) ** 2 57 | 58 | 59 | class VisionTowerConfig(): 60 | def __init__(self): 61 | self.image_size = 336 62 | self.patch_size = 14 63 | self.hidden_size = 1024 64 | -------------------------------------------------------------------------------- /cerule/model/multimodal_encoder/eva_clip/eva_clip_processors.py: -------------------------------------------------------------------------------- 1 | ''' 2 | # Adapted from https://github.com/baaivision/EVA/tree/master/EVA-CLIP 3 | ''' 4 | 5 | from torchvision import transforms 6 | from torchvision.transforms.functional import InterpolationMode 7 | from transformers.image_processing_utils import BatchFeature 8 | from PIL import Image 9 | from transformers.image_transforms import convert_to_rgb 10 | 11 | 12 | class BaseProcessor: 13 | def __init__(self): 14 | self.transform = lambda x: x 15 | return 16 | 17 | def __call__(self, item): 18 | return self.transform(item) 19 | 20 | 21 | class EvaClipImageBaseProcessor(BaseProcessor): 22 | def __init__(self, mean=None, std=None): 23 | self.mean = (0.48145466, 0.4578275, 0.40821073) if mean is None else mean 24 | self.std = (0.26862954, 0.26130258, 0.27577711) if std is None else std 25 | 26 | self.normalize = transforms.Normalize(self.mean, self.std) 27 | 28 | @property 29 | def image_mean(self): 30 | return self.mean 31 | 32 | 33 | class EvaClipImageTrainProcessor(EvaClipImageBaseProcessor): 34 | def __init__(self, image_size=224, mean=None, std=None, min_scale=0.5, max_scale=1.0): 35 | super().__init__(mean=mean, std=std) 36 | 37 | self.transform = transforms.Compose( 38 | [ 39 | convert_to_rgb, 40 | transforms.Resize( 41 | image_size, 42 | interpolation=InterpolationMode.BICUBIC, 43 | ), 44 | transforms.CenterCrop(image_size), 45 | transforms.ToTensor(), 46 | self.normalize, 47 | ] 48 | ) 49 | 50 | self.image_size = image_size 51 | 52 | def preprocess(self, images, return_tensors): 53 | if isinstance(images, Image.Image): 54 | images = [images] 55 | else: 56 | assert isinstance(images, list) 57 | 58 | transformed_images = [self.transform(image).numpy() for image in images] 59 | data = {"pixel_values": transformed_images} 60 | 61 | return BatchFeature(data=data, tensor_type=return_tensors) 62 | 63 | def __call__(self, item): 64 | return self.transform(item) 65 | 66 | @property 67 | def crop_size(self): 68 | return {'height': self.image_size, 'width': self.image_size} 69 | 
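# Illustrative usage sketch (not part of the upstream file): run the EVA-CLIP
# train processor on one of the repository's example images and inspect the
# resulting pixel tensor. The image path assumes execution from the repo root.
if __name__ == "__main__":
    processor = EvaClipImageTrainProcessor(image_size=336)
    image = Image.open("examples/astronaut.png")
    batch = processor.preprocess([image], return_tensors="pt")
    # A single image resized/cropped to 336x336 with 3 channels.
    print(batch["pixel_values"].shape)  # expected: torch.Size([1, 3, 336, 336])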
-------------------------------------------------------------------------------- /cerule/model/multimodal_encoder/siglip/siglip_encoder.py: -------------------------------------------------------------------------------- 1 | ''' 2 | # Adapted from https://huggingface.co/MILVLG/imp-v1-3b/blob/main/vision_encoder.py 3 | ''' 4 | 5 | from typing import Optional, Tuple, Union, Dict 6 | from dataclasses import dataclass 7 | from functools import partial, reduce 8 | from PIL import Image 9 | import torch 10 | import torch.utils.checkpoint 11 | from torch import nn 12 | import os 13 | from transformers.image_processing_utils import BatchFeature, get_size_dict 14 | from transformers.image_transforms import (convert_to_rgb, normalize, rescale, resize, to_channel_dimension_format, ) 15 | from transformers.image_utils import (ChannelDimension, PILImageResampling, to_numpy_array, ) 16 | from transformers.activations import ACT2FN 17 | from transformers.modeling_outputs import BaseModelOutput, BaseModelOutputWithPooling 18 | from transformers.modeling_utils import PreTrainedModel 19 | from transformers import PretrainedConfig 20 | from transformers.utils import ModelOutput 21 | 22 | 23 | class SigLipImageProcessor: 24 | def __init__(self, 25 | image_mean=(0.5, 0.5, 0.5), 26 | image_std=(0.5, 0.5, 0.5), 27 | size=(384, 384), 28 | crop_size: Dict[str, int] = None, 29 | resample=PILImageResampling.BICUBIC, 30 | rescale_factor=1 / 255, 31 | data_format=ChannelDimension.FIRST): 32 | crop_size = crop_size if crop_size is not None else {"height": 384, "width": 384} 33 | crop_size = get_size_dict(crop_size, default_to_square=True, param_name="crop_size") 34 | 35 | self.image_mean = image_mean 36 | self.image_std = image_std 37 | self.size = size 38 | self.resample = resample 39 | self.rescale_factor = rescale_factor 40 | self.data_format = data_format 41 | self.crop_size = crop_size 42 | 43 | def preprocess(self, images, return_tensors): 44 | if isinstance(images, Image.Image): 45 | images = [images] 46 | else: 47 | assert isinstance(images, list) 48 | 49 | transforms = [ 50 | convert_to_rgb, 51 | to_numpy_array, 52 | partial(resize, size=self.size, resample=self.resample, data_format=self.data_format), 53 | partial(rescale, scale=self.rescale_factor, data_format=self.data_format), 54 | partial(normalize, mean=self.image_mean, std=self.image_std, data_format=self.data_format), 55 | partial(to_channel_dimension_format, channel_dim=self.data_format, input_channel_dim=self.data_format), 56 | ] 57 | 58 | images = reduce(lambda x, f: [*map(f, x)], transforms, images) 59 | data = {"pixel_values": images} 60 | 61 | return BatchFeature(data=data, tensor_type=return_tensors) 62 | 63 | 64 | class SigLipVisionConfig(PretrainedConfig): 65 | model_type = "siglip_vision_model" 66 | 67 | def __init__( 68 | self, 69 | hidden_size=1152, 70 | image_mean=(0.5, 0.5, 0.5), 71 | intermediate_size=4304, 72 | num_hidden_layers=27, 73 | num_attention_heads=16, 74 | num_channels=3, 75 | image_size=384, 76 | patch_size=14, 77 | hidden_act="gelu_pytorch_tanh", 78 | layer_norm_eps=1e-6, 79 | attention_dropout=0.0, 80 | **kwargs, 81 | ): 82 | super().__init__(**kwargs) 83 | 84 | self.hidden_size = hidden_size 85 | self.intermediate_size = intermediate_size 86 | self.num_hidden_layers = num_hidden_layers 87 | self.num_attention_heads = num_attention_heads 88 | self.num_channels = num_channels 89 | self.patch_size = patch_size 90 | self.image_size = image_size 91 | self.attention_dropout = attention_dropout 92 | self.layer_norm_eps = 
layer_norm_eps 93 | self.hidden_act = hidden_act 94 | self.image_mean = image_mean 95 | 96 | @classmethod 97 | def from_pretrained(cls, pretrained_model_name_or_path: Union[str, os.PathLike], **kwargs) -> "PretrainedConfig": 98 | cls._set_token_in_kwargs(kwargs) 99 | 100 | config_dict, kwargs = cls.get_config_dict(pretrained_model_name_or_path, **kwargs) 101 | 102 | # get the vision config dict if we are loading from SigLipConfig 103 | if config_dict.get("model_type") == "siglip": 104 | config_dict = config_dict["vision_config"] 105 | 106 | if "model_type" in config_dict and hasattr(cls, "model_type") and config_dict["model_type"] != cls.model_type: 107 | logger.warning( 108 | f"You are using a model of type {config_dict['model_type']} to instantiate a model of type " 109 | f"{cls.model_type}. This is not supported for all configurations of models and can yield errors." 110 | ) 111 | 112 | return cls.from_dict(config_dict, **kwargs) 113 | 114 | 115 | @dataclass 116 | # Copied from transformers.models.clip.modeling_clip.CLIPVisionModelOutput with CLIP->SigLip 117 | class SigLipVisionModelOutput(ModelOutput): 118 | """ 119 | Base class for vision model's outputs that also contains image embeddings of the pooling of the last hidden states. 120 | 121 | Args: 122 | image_embeds (`torch.FloatTensor` of shape `(batch_size, output_dim)` *optional* returned when model is initialized with `with_projection=True`): 123 | The image embeddings obtained by applying the projection layer to the pooler_output. 124 | last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`): 125 | Sequence of hidden-states at the output of the last layer of the model. 126 | hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): 127 | Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, + 128 | one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`. 129 | 130 | Hidden-states of the model at the output of each layer plus the optional initial embedding outputs. 131 | attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): 132 | Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, 133 | sequence_length)`. 134 | 135 | Attentions weights after the attention softmax, used to compute the weighted average in the self-attention 136 | heads. 
137 | """ 138 | 139 | image_embeds: Optional[torch.FloatTensor] = None 140 | last_hidden_state: torch.FloatTensor = None 141 | hidden_states: Optional[Tuple[torch.FloatTensor]] = None 142 | attentions: Optional[Tuple[torch.FloatTensor]] = None 143 | 144 | 145 | class SigLipVisionEmbeddings(nn.Module): 146 | def __init__(self, config: SigLipVisionConfig): 147 | super().__init__() 148 | self.config = config 149 | self.embed_dim = config.hidden_size 150 | self.image_size = config.image_size 151 | self.patch_size = config.patch_size 152 | 153 | self.patch_embedding = nn.Conv2d( 154 | in_channels=config.num_channels, 155 | out_channels=self.embed_dim, 156 | kernel_size=self.patch_size, 157 | stride=self.patch_size, 158 | padding="valid", 159 | ) 160 | 161 | self.num_patches = (self.image_size // self.patch_size) ** 2 162 | self.num_positions = self.num_patches 163 | self.position_embedding = nn.Embedding(self.num_positions, self.embed_dim) 164 | self.register_buffer("position_ids", torch.arange(self.num_positions).expand((1, -1)), persistent=False) 165 | 166 | def forward(self, pixel_values: torch.FloatTensor) -> torch.Tensor: 167 | patch_embeds = self.patch_embedding(pixel_values) # shape = [*, width, grid, grid] 168 | embeddings = patch_embeds.flatten(2).transpose(1, 2) 169 | 170 | embeddings = embeddings + self.position_embedding(self.position_ids) 171 | return embeddings 172 | 173 | 174 | class SigLipAttention(nn.Module): 175 | """Multi-headed attention from 'Attention Is All You Need' paper""" 176 | 177 | # Copied from transformers.models.clip.modeling_clip.CLIPAttention.__init__ 178 | def __init__(self, config): 179 | super().__init__() 180 | self.config = config 181 | self.embed_dim = config.hidden_size 182 | self.num_heads = config.num_attention_heads 183 | self.head_dim = self.embed_dim // self.num_heads 184 | if self.head_dim * self.num_heads != self.embed_dim: 185 | raise ValueError( 186 | f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim} and `num_heads`:" 187 | f" {self.num_heads})." 
188 | ) 189 | self.scale = self.head_dim ** -0.5 190 | self.dropout = config.attention_dropout 191 | 192 | self.k_proj = nn.Linear(self.embed_dim, self.embed_dim) 193 | self.v_proj = nn.Linear(self.embed_dim, self.embed_dim) 194 | self.q_proj = nn.Linear(self.embed_dim, self.embed_dim) 195 | self.out_proj = nn.Linear(self.embed_dim, self.embed_dim) 196 | 197 | def forward( 198 | self, 199 | hidden_states: torch.Tensor, 200 | attention_mask: Optional[torch.Tensor] = None, 201 | output_attentions: Optional[bool] = False, 202 | ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: 203 | """Input shape: Batch x Time x Channel""" 204 | 205 | batch_size, q_len, _ = hidden_states.size() 206 | 207 | query_states = self.q_proj(hidden_states) 208 | key_states = self.k_proj(hidden_states) 209 | value_states = self.v_proj(hidden_states) 210 | 211 | query_states = query_states.view(batch_size, q_len, self.num_heads, self.head_dim).transpose(1, 2) 212 | key_states = key_states.view(batch_size, q_len, self.num_heads, self.head_dim).transpose(1, 2) 213 | value_states = value_states.view(batch_size, q_len, self.num_heads, self.head_dim).transpose(1, 2) 214 | 215 | k_v_seq_len = key_states.shape[-2] 216 | attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) * self.scale 217 | 218 | if attn_weights.size() != (batch_size, self.num_heads, q_len, k_v_seq_len): 219 | raise ValueError( 220 | f"Attention weights should be of size {(batch_size, self.num_heads, q_len, k_v_seq_len)}, but is" 221 | f" {attn_weights.size()}" 222 | ) 223 | 224 | if attention_mask is not None: 225 | if attention_mask.size() != (batch_size, 1, q_len, k_v_seq_len): 226 | raise ValueError( 227 | f"Attention mask should be of size {(batch_size, 1, q_len, k_v_seq_len)}, but is {attention_mask.size()}" 228 | ) 229 | attn_weights = attn_weights + attention_mask 230 | 231 | # upcast attention to fp32 232 | attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype) 233 | attn_weights = nn.functional.dropout(attn_weights, p=self.dropout, training=self.training) 234 | attn_output = torch.matmul(attn_weights, value_states) 235 | 236 | if attn_output.size() != (batch_size, self.num_heads, q_len, self.head_dim): 237 | raise ValueError( 238 | f"`attn_output` should be of size {(batch_size, self.num_heads, q_len, self.head_dim)}, but is" 239 | f" {attn_output.size()}" 240 | ) 241 | 242 | attn_output = attn_output.transpose(1, 2).contiguous() 243 | attn_output = attn_output.reshape(batch_size, q_len, self.embed_dim) 244 | 245 | attn_output = self.out_proj(attn_output) 246 | 247 | return attn_output, attn_weights 248 | 249 | 250 | # Copied from transformers.models.clip.modeling_clip.CLIPMLP with CLIP->SigLip 251 | class SigLipMLP(nn.Module): 252 | def __init__(self, config): 253 | super().__init__() 254 | self.config = config 255 | self.activation_fn = ACT2FN[config.hidden_act] 256 | self.fc1 = nn.Linear(config.hidden_size, config.intermediate_size) 257 | self.fc2 = nn.Linear(config.intermediate_size, config.hidden_size) 258 | 259 | def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: 260 | hidden_states = self.fc1(hidden_states) 261 | hidden_states = self.activation_fn(hidden_states) 262 | hidden_states = self.fc2(hidden_states) 263 | return hidden_states 264 | 265 | 266 | # Copied from transformers.models.clip.modeling_clip.CLIPEncoderLayer with CLIP->SigLip 267 | class SigLipEncoderLayer(nn.Module): 268 | def __init__(self, config: SigLipVisionConfig): 
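        # A standard pre-LayerNorm transformer block: LayerNorm -> self-attention -> residual add,
        # followed by LayerNorm -> MLP -> residual add (see forward() below).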
269 | super().__init__() 270 | self.embed_dim = config.hidden_size 271 | self.self_attn = SigLipAttention(config) 272 | self.layer_norm1 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps) 273 | self.mlp = SigLipMLP(config) 274 | self.layer_norm2 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps) 275 | 276 | # Ignore copy 277 | def forward( 278 | self, 279 | hidden_states: torch.Tensor, 280 | attention_mask: torch.Tensor, 281 | output_attentions: Optional[bool] = False, 282 | ) -> Tuple[torch.FloatTensor]: 283 | """ 284 | Args: 285 | hidden_states (`torch.FloatTensor`): 286 | Input to the layer of shape `(batch, seq_len, embed_dim)`. 287 | attention_mask (`torch.FloatTensor`): 288 | Attention mask of shape `(batch, 1, q_len, k_v_seq_len)` where padding elements are indicated by very large negative values. 289 | output_attentions (`bool`, *optional*, defaults to `False`): 290 | Whether or not to return the attentions tensors of all attention layers. See `attentions` under 291 | returned tensors for more detail. 292 | """ 293 | residual = hidden_states 294 | 295 | hidden_states = self.layer_norm1(hidden_states) 296 | hidden_states, attn_weights = self.self_attn( 297 | hidden_states=hidden_states, 298 | attention_mask=attention_mask, 299 | output_attentions=output_attentions, 300 | ) 301 | hidden_states = residual + hidden_states 302 | 303 | residual = hidden_states 304 | hidden_states = self.layer_norm2(hidden_states) 305 | hidden_states = self.mlp(hidden_states) 306 | hidden_states = residual + hidden_states 307 | 308 | outputs = (hidden_states,) 309 | 310 | if output_attentions: 311 | outputs += (attn_weights,) 312 | 313 | return outputs 314 | 315 | 316 | class SigLipPreTrainedModel(PreTrainedModel): 317 | """ 318 | An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained 319 | models. 320 | """ 321 | 322 | config_class = SigLipVisionConfig 323 | base_model_prefix = "siglip" 324 | supports_gradient_checkpointing = True 325 | 326 | def _init_weights(self, module): 327 | """Initialize the weights""" 328 | pass 329 | 330 | 331 | # Copied from transformers.models.clip.modeling_clip.CLIPEncoder with CLIP->SigLip 332 | class SigLipEncoder(nn.Module): 333 | """ 334 | Transformer encoder consisting of `config.num_hidden_layers` self attention layers. Each layer is a 335 | [`SigLipEncoderLayer`]. 336 | 337 | Args: 338 | config: SigLipVisionConfig 339 | """ 340 | 341 | def __init__(self, config: SigLipVisionConfig): 342 | super().__init__() 343 | self.config = config 344 | self.layers = nn.ModuleList([SigLipEncoderLayer(config) for _ in range(config.num_hidden_layers)]) 345 | self.gradient_checkpointing = False 346 | 347 | # Ignore copy 348 | def forward( 349 | self, 350 | inputs_embeds, 351 | attention_mask: Optional[torch.Tensor] = None, 352 | output_attentions: Optional[bool] = None, 353 | output_hidden_states: Optional[bool] = None, 354 | return_dict: Optional[bool] = None, 355 | ) -> Union[Tuple, BaseModelOutput]: 356 | r""" 357 | Args: 358 | inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`): 359 | Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. 360 | This is useful if you want more control over how to convert `input_ids` indices into associated vectors 361 | than the model's internal embedding lookup matrix. 
362 | attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*): 363 | Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: 364 | 365 | - 1 for tokens that are **not masked**, 366 | - 0 for tokens that are **masked**. 367 | 368 | [What are attention masks?](../glossary#attention-mask) 369 | output_attentions (`bool`, *optional*): 370 | Whether or not to return the attentions tensors of all attention layers. See `attentions` under 371 | returned tensors for more detail. 372 | output_hidden_states (`bool`, *optional*): 373 | Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors 374 | for more detail. 375 | return_dict (`bool`, *optional*): 376 | Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. 377 | """ 378 | output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions 379 | output_hidden_states = ( 380 | output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states 381 | ) 382 | return_dict = return_dict if return_dict is not None else self.config.use_return_dict 383 | 384 | encoder_states = () if output_hidden_states else None 385 | all_attentions = () if output_attentions else None 386 | 387 | hidden_states = inputs_embeds 388 | for encoder_layer in self.layers: 389 | if output_hidden_states: 390 | encoder_states = encoder_states + (hidden_states,) 391 | if self.gradient_checkpointing and self.training: 392 | layer_outputs = self._gradient_checkpointing_func( 393 | encoder_layer.__call__, 394 | hidden_states, 395 | attention_mask, 396 | output_attentions, 397 | ) 398 | else: 399 | layer_outputs = encoder_layer( 400 | hidden_states, 401 | attention_mask, 402 | output_attentions=output_attentions, 403 | ) 404 | 405 | hidden_states = layer_outputs[0] 406 | 407 | if output_attentions: 408 | all_attentions = all_attentions + (layer_outputs[1],) 409 | 410 | if output_hidden_states: 411 | encoder_states = encoder_states + (hidden_states,) 412 | 413 | if not return_dict: 414 | return tuple(v for v in [hidden_states, encoder_states, all_attentions] if v is not None) 415 | return BaseModelOutput( 416 | last_hidden_state=hidden_states, hidden_states=encoder_states, attentions=all_attentions 417 | ) 418 | 419 | 420 | class SigLipVisionTransformer(nn.Module): 421 | def __init__(self, config: SigLipVisionConfig): 422 | super().__init__() 423 | self.config = config 424 | embed_dim = config.hidden_size 425 | 426 | self.embeddings = SigLipVisionEmbeddings(config) 427 | self.encoder = SigLipEncoder(config) 428 | self.post_layernorm = nn.LayerNorm(embed_dim, eps=config.layer_norm_eps) 429 | self.head = SigLipMultiheadAttentionPoolingHead(config) 430 | 431 | def forward( 432 | self, 433 | pixel_values, 434 | output_attentions: Optional[bool] = None, 435 | output_hidden_states: Optional[bool] = None, 436 | return_dict: Optional[bool] = None, 437 | ) -> Union[Tuple, BaseModelOutputWithPooling]: 438 | r""" 439 | Returns: 440 | 441 | """ 442 | output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions 443 | output_hidden_states = ( 444 | output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states 445 | ) 446 | return_dict = return_dict if return_dict is not None else self.config.use_return_dict 447 | 448 | hidden_states = self.embeddings(pixel_values) 449 | 450 | encoder_outputs = self.encoder( 451 | 
inputs_embeds=hidden_states, 452 | output_attentions=output_attentions, 453 | output_hidden_states=output_hidden_states, 454 | return_dict=return_dict, 455 | ) 456 | 457 | last_hidden_state = encoder_outputs[0] 458 | last_hidden_state = self.post_layernorm(last_hidden_state) 459 | 460 | pooled_output = self.head(last_hidden_state) 461 | 462 | if not return_dict: 463 | return (last_hidden_state, pooled_output) + encoder_outputs[1:] 464 | 465 | return BaseModelOutputWithPooling( 466 | last_hidden_state=last_hidden_state, 467 | pooler_output=pooled_output, 468 | hidden_states=encoder_outputs.hidden_states, 469 | attentions=encoder_outputs.attentions, 470 | ) 471 | 472 | 473 | class SigLipMultiheadAttentionPoolingHead(nn.Module): 474 | """Multihead Attention Pooling.""" 475 | 476 | def __init__(self, config: SigLipVisionConfig): 477 | super().__init__() 478 | 479 | self.probe = nn.Parameter(torch.randn(1, 1, config.hidden_size)) 480 | self.attention = torch.nn.MultiheadAttention(config.hidden_size, config.num_attention_heads, batch_first=True) 481 | self.layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) 482 | self.mlp = SigLipMLP(config) 483 | 484 | def forward(self, hidden_state): 485 | batch_size = hidden_state.shape[0] 486 | probe = self.probe.repeat(batch_size, 1, 1) 487 | 488 | hidden_state = self.attention(probe, hidden_state, hidden_state)[0] 489 | 490 | residual = hidden_state 491 | hidden_state = self.layernorm(hidden_state) 492 | hidden_state = residual + self.mlp(hidden_state) 493 | 494 | return hidden_state[:, 0] 495 | 496 | 497 | class SigLipVisionModel(SigLipPreTrainedModel): 498 | config_class = SigLipVisionConfig 499 | main_input_name = "pixel_values" 500 | _no_split_modules = ["SigLipEncoderLayer"] 501 | 502 | def __init__(self, config: SigLipVisionConfig): 503 | super().__init__(config) 504 | 505 | self.vision_model = SigLipVisionTransformer(config) 506 | 507 | # Initialize weights and apply final processing 508 | self.post_init() 509 | 510 | def get_input_embeddings(self) -> nn.Module: 511 | return self.vision_model.embeddings.patch_embedding 512 | 513 | def forward( 514 | self, 515 | pixel_values, 516 | output_attentions: Optional[bool] = None, 517 | output_hidden_states: Optional[bool] = None, 518 | return_dict: Optional[bool] = None, 519 | ) -> Union[Tuple, BaseModelOutputWithPooling]: 520 | r""" 521 | Returns: 522 | 523 | Examples: 524 | 525 | ```python 526 | >>> from PIL import Image 527 | >>> import requests 528 | >>> from transformers import AutoProcessor, SigLipVisionModel 529 | 530 | >>> model = SigLipVisionModel.from_pretrained("google/siglip-base-patch16-224") 531 | >>> processor = AutoProcessor.from_pretrained("google/siglip-base-patch16-224") 532 | 533 | >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg" 534 | >>> image = Image.open(requests.get(url, stream=True).raw) 535 | 536 | >>> inputs = processor(images=image, return_tensors="pt") 537 | 538 | >>> outputs = model(**inputs) 539 | >>> last_hidden_state = outputs.last_hidden_state 540 | >>> pooled_output = outputs.pooler_output # pooled features 541 | ```""" 542 | return_dict = return_dict if return_dict is not None else self.config.use_return_dict 543 | 544 | return self.vision_model( 545 | pixel_values=pixel_values, 546 | output_attentions=output_attentions, 547 | output_hidden_states=output_hidden_states, 548 | return_dict=return_dict, 549 | ) 550 | 551 | 552 | class SigLipVisionTower(nn.Module): 553 | def __init__(self, vision_tower, vision_tower_cfg, 
delay_load=False): 554 | super().__init__() 555 | 556 | self.is_loaded = False 557 | 558 | self.config = SigLipVisionConfig() 559 | 560 | self.vision_tower_name = vision_tower 561 | 562 | self.image_processor = SigLipImageProcessor() 563 | 564 | if not delay_load: 565 | self.load_model() 566 | else: 567 | self.cfg_only = self.config 568 | 569 | def load_model(self): 570 | if self.is_loaded: 571 | return 572 | 573 | self.vision_tower = SigLipVisionModel.from_pretrained(self.vision_tower_name) 574 | 575 | del self.vision_tower.vision_model.encoder.layers[-1:] 576 | self.vision_tower.vision_model.head = nn.Identity() 577 | self.vision_tower.requires_grad_(False) 578 | self.vision_tower.eval() 579 | 580 | self.is_loaded = True 581 | 582 | @torch.no_grad() 583 | def forward(self, images): 584 | if type(images) is list: 585 | image_features = [] 586 | for image in images: 587 | image_forward_out = self.vision_tower(image.to(device=self.device, dtype=self.dtype).unsqueeze(0), 588 | output_hidden_states=True) 589 | image_feature = image_forward_out.hidden_states[-1].to(image.dtype) 590 | assert image_features.shape[-2] == 729 591 | image_features.append(image_feature) 592 | else: 593 | image_forward_outs = self.vision_tower(images.to(device=self.device, dtype=self.dtype), 594 | output_hidden_states=True) 595 | image_features = image_forward_outs.hidden_states[-1].to(images.dtype) 596 | assert image_features.shape[-2] == 729 597 | 598 | return image_features 599 | 600 | @property 601 | def dummy_feature(self): 602 | return torch.zeros(1, self.hidden_size, device=self.device, dtype=self.dtype) 603 | 604 | @property 605 | def dtype(self): 606 | for p in self.vision_tower.parameters(): 607 | return p.dtype 608 | 609 | @property 610 | def device(self): 611 | for p in self.vision_tower.parameters(): 612 | return p.device 613 | 614 | @property 615 | def hidden_size(self): 616 | return self.config.hidden_size 617 | 618 | @property 619 | def num_patches(self): 620 | return (self.config.image_size // self.config.patch_size) ** 2 621 | -------------------------------------------------------------------------------- /cerule/model/multimodal_projector/builder.py: -------------------------------------------------------------------------------- 1 | import re 2 | import math 3 | from torch import nn 4 | from functools import partial 5 | from timm.layers.norm_act import LayerNormAct2d 6 | from torchvision.ops.misc import SqueezeExcitation as SELayer 7 | from torchvision.models.mobilenetv3 import InvertedResidual, InvertedResidualConfig 8 | 9 | 10 | class IdentityMap(nn.Module): 11 | def __init__(self): 12 | super().__init__() 13 | 14 | def forward(self, x, *args, **kwargs): 15 | return x 16 | 17 | @property 18 | def config(self): 19 | return {"mm_projector_type": 'identity'} 20 | 21 | 22 | class Minigpt(nn.Module): 23 | def __init__(self, config=None): 24 | super(Minigpt, self).__init__() 25 | # c*4 is the input size, and c is the output size for the linear layer 26 | inc, ouc = config.mm_hidden_size, config.hidden_size 27 | self.linear = nn.Linear(inc * 4, ouc) 28 | 29 | def forward(self, x): 30 | # x is the input tensor with shape [b, num_tokens, c] 31 | b, num_tokens, c = x.shape 32 | 33 | # Check if num_tokens is divisible by 4 34 | if num_tokens % 4 != 0: 35 | raise ValueError("num_tokens must be divisible by 4") 36 | 37 | # Reshape x to [b, num_tokens/4, c*4] 38 | x = x.view(b, num_tokens // 4, c * 4) 39 | 40 | # Apply the linear transformation 41 | x = self.linear(x) 42 | return x 43 | 44 | 45 | class 
Vanilla(nn.Module): 46 | def __init__(self, config=None): 47 | super(Vanilla, self).__init__() 48 | # c*4 is the input size, and c is the output size for the linear layer 49 | inc, ouc = config.mm_hidden_size, config.hidden_size 50 | self.linear = nn.Linear(inc * 4, ouc) 51 | 52 | def forward(self, x): 53 | b, num_tokens, c = x.shape 54 | 55 | # Check if num_tokens is divisible by 4 56 | if num_tokens % 4 != 0: 57 | raise ValueError("num_tokens must be divisible by 4") 58 | 59 | # First, reshape to [b, num_tokens//4, 4, c] 60 | x = x.view(b, num_tokens // 4, 4, c) 61 | 62 | # Then, permute to interleave the tokens 63 | x = x.permute(0, 1, 3, 2).contiguous() 64 | 65 | # Finally, reshape to [b, num_tokens//4, c*4] to interleave features of 4 tokens 66 | x = x.view(b, num_tokens // 4, c * 4) 67 | 68 | # Apply the linear transformation 69 | x = self.linear(x) 70 | return x 71 | 72 | 73 | class LDPBlock(nn.Module): 74 | # Lightweight Downsample Projector Block 75 | 76 | def __init__(self, config=None): 77 | super().__init__() 78 | 79 | inc, ouc = config.mm_hidden_size, config.hidden_size 80 | layer_norm = partial(LayerNormAct2d, act_layer=None) 81 | se_layer = partial(SELayer, scale_activation=nn.Hardsigmoid) 82 | self.mlp = nn.Sequential( 83 | nn.Identity(), nn.Linear(inc, ouc), nn.GELU(), nn.Linear(ouc, ouc) 84 | ) 85 | self.mb_block = nn.Sequential( 86 | nn.Identity(), 87 | InvertedResidual(InvertedResidualConfig(ouc, 3, ouc, ouc, True, "HS", 1, 1, 1), layer_norm, se_layer), 88 | InvertedResidual(InvertedResidualConfig(ouc, 3, ouc, ouc, True, "HS", 2, 1, 1), layer_norm, se_layer) 89 | ) 90 | 91 | def forward(self, x): 92 | b, num_tokens, c = x.shape 93 | h = int(math.sqrt(num_tokens)) 94 | x = self.mlp(x) 95 | x = x.permute(0, 2, 1).reshape(b, -1, h, h) 96 | x = self.mb_block(x) 97 | x = x.flatten(2).permute(0, 2, 1) 98 | return x 99 | 100 | 101 | class LDPNetProjector(nn.Module): 102 | 103 | def __init__(self, config=None): 104 | super().__init__() 105 | self.model = LDPBlock(config) 106 | 107 | def forward(self, x): 108 | return self.model(x) 109 | 110 | 111 | class SPP(nn.Module): 112 | 113 | def __init__(self, config=None, projector_type='v1'): 114 | super().__init__() 115 | 116 | self.projector_type = projector_type 117 | 118 | inc, ouc = config.mm_hidden_size, config.hidden_size 119 | self.linear_0 = nn.Linear(inc, inc) 120 | 121 | self.linear_1 = nn.Linear(inc, ouc) 122 | 123 | self.pooling = nn.AvgPool2d(kernel_size=2) 124 | 125 | self.linear_2 = nn.Linear(ouc, ouc) 126 | 127 | def forward(self, x): 128 | b, num_tokens, c = x.shape 129 | h = int(math.sqrt(num_tokens)) 130 | if 'v1' in self.projector_type: 131 | x = self.linear_1(x) 132 | x = x.permute(0, 2, 1).reshape(b, -1, h, h) 133 | x = self.pooling(x) 134 | x = x.flatten(2).permute(0, 2, 1) 135 | x = self.linear_2(x) 136 | elif 'v2' in self.projector_type: 137 | x = self.linear_1(x) 138 | x = self.linear_2(x) 139 | x = x.permute(0, 2, 1).reshape(b, -1, h, h) 140 | x = self.pooling(x) 141 | x = x.flatten(2).permute(0, 2, 1) 142 | elif 'v3' in self.projector_type: 143 | x = self.linear_0(x) 144 | x = x.permute(0, 2, 1).reshape(b, -1, h, h) 145 | x = self.pooling(x) 146 | x = x.flatten(2).permute(0, 2, 1) 147 | x = self.linear_1(x) 148 | x = self.linear_2(x) 149 | return x 150 | 151 | 152 | def build_vision_projector(config, delay_load=False, **kwargs): 153 | projector_type = getattr(config, 'mm_projector_type', 'mlp2x_gelu') 154 | 155 | if projector_type == 'linear': 156 | return nn.Linear(config.mm_hidden_size, 
config.hidden_size) 157 | 158 | elif projector_type.startswith('mlp'): 159 | mlp_gelu_match = re.match(r'^mlp(\d+)x_gelu$', projector_type) 160 | if mlp_gelu_match: 161 | mlp_depth = int(mlp_gelu_match.group(1)) 162 | modules = [nn.Linear(config.mm_hidden_size, config.hidden_size)] 163 | for _ in range(1, mlp_depth): 164 | modules.append(nn.GELU()) 165 | modules.append(nn.Linear(config.hidden_size, config.hidden_size)) 166 | return nn.Sequential(*modules) 167 | 168 | elif projector_type.startswith('spp'): 169 | return SPP(config, projector_type) 170 | 171 | elif projector_type == 'ldp': 172 | return LDPNetProjector(config) 173 | 174 | elif projector_type == 'vanilla': 175 | return Vanilla(config) 176 | 177 | elif projector_type == 'minigpt': 178 | return Minigpt(config) 179 | 180 | elif projector_type == 'identity': 181 | return IdentityMap() 182 | 183 | raise ValueError(f'Unknown projector type: {projector_type}') 184 | -------------------------------------------------------------------------------- /cerule/serve/cli.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import torch 3 | import requests 4 | 5 | from PIL import Image 6 | from io import BytesIO 7 | from transformers import TextStreamer 8 | 9 | from cerule.constants import IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN 10 | from cerule.conversation import conv_templates, SeparatorStyle 11 | from cerule.model.builder import load_pretrained_model 12 | from cerule.util.utils import disable_torch_init 13 | from cerule.util.mm_utils import process_images, tokenizer_image_token, get_model_name_from_path, \ 14 | KeywordsStoppingCriteria 15 | 16 | 17 | def load_image(image_file): 18 | if image_file.startswith('http://') or image_file.startswith('https://'): 19 | response = requests.get(image_file) 20 | image = Image.open(BytesIO(response.content)).convert('RGB') 21 | else: 22 | image = Image.open(image_file).convert('RGB') 23 | return image 24 | 25 | 26 | def main(args): 27 | # Model 28 | disable_torch_init() 29 | 30 | model_name = get_model_name_from_path(args.model_path) 31 | tokenizer, model, image_processor, context_len = load_pretrained_model(args.model_path, args.model_base, model_name, 32 | args.model_type, args.load_8bit, 33 | args.load_4bit, device=args.device) 34 | 35 | conv_mode = "gemma_instruct" 36 | 37 | if args.conv_mode is not None and conv_mode != args.conv_mode: 38 | print( 39 | '[WARNING] the auto inferred conversation mode is {}, while `--conv-mode` is {}, using {}'.format(conv_mode, 40 | args.conv_mode, 41 | args.conv_mode)) 42 | else: 43 | args.conv_mode = conv_mode 44 | 45 | conv = conv_templates[args.conv_mode].copy() 46 | roles = conv.roles 47 | 48 | image = load_image(args.image_file) 49 | # Similar operation in model_worker.py 50 | image_tensor = process_images([image], image_processor, model.config) 51 | if type(image_tensor) is list: 52 | image_tensor = [image.to(model.device, dtype=model.dtype) for image in image_tensor] 53 | else: 54 | image_tensor = image_tensor.to(model.device, dtype=model.dtype) 55 | 56 | while True: 57 | try: 58 | inp = input(f"{roles[0]}: ") 59 | except EOFError: 60 | inp = "" 61 | if not inp: 62 | print("exit...") 63 | break 64 | 65 | print(f"{roles[1]}: ", end="") 66 | 67 | if image is not None: 68 | # first message 69 | inp = DEFAULT_IMAGE_TOKEN + '\n' + inp 70 | conv.append_message(conv.roles[0], inp) 71 | image = None 72 | else: 73 | conv.append_message(conv.roles[0], inp) 74 | conv.append_message(conv.roles[1], None) 75 | prompt 
= conv.get_prompt() 76 | 77 | input_ids = tokenizer_image_token(prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors='pt').unsqueeze(0).to( 78 | model.device) 79 | stop_str = conv.sep if conv.sep_style != SeparatorStyle.TWO else conv.sep2 80 | keywords = [stop_str] 81 | stopping_criteria = KeywordsStoppingCriteria(keywords, tokenizer, input_ids) 82 | streamer = TextStreamer(tokenizer, skip_prompt=True, skip_special_tokens=False) 83 | 84 | with torch.inference_mode(): 85 | output_ids = model.generate( 86 | input_ids, 87 | images=image_tensor, 88 | do_sample=True if args.temperature > 0 else False, 89 | temperature=args.temperature, 90 | max_new_tokens=args.max_new_tokens, 91 | streamer=streamer, 92 | use_cache=False, # Keep use_cache=False. use_cache=True raises some weird error 93 | stopping_criteria=[stopping_criteria]) 94 | 95 | outputs = tokenizer.decode(output_ids[0, input_ids.shape[1]:]).strip() 96 | conv.messages[-1][-1] = outputs 97 | 98 | if args.debug: 99 | print("\n", {"prompt": prompt, "outputs": outputs}, "\n") 100 | 101 | 102 | if __name__ == "__main__": 103 | parser = argparse.ArgumentParser() 104 | parser.add_argument("--model-path", type=str, default=None) 105 | parser.add_argument("--model-base", type=str, default=None) 106 | parser.add_argument("--model-type", type=str, default="gemma") 107 | parser.add_argument("--image-file", type=str, required=True) 108 | parser.add_argument("--device", type=str, default="cuda") 109 | parser.add_argument("--conv-mode", type=str, default=None) 110 | parser.add_argument("--temperature", type=float, default=0.2) 111 | parser.add_argument("--max-new-tokens", type=int, default=512) 112 | parser.add_argument("--load-8bit", action="store_true") 113 | parser.add_argument("--load-4bit", action="store_true") 114 | parser.add_argument("--debug", action="store_true") 115 | args = parser.parse_args() 116 | main(args) 117 | -------------------------------------------------------------------------------- /cerule/serve/controller.py: -------------------------------------------------------------------------------- 1 | """ 2 | A controller manages distributed workers. 3 | It sends worker addresses to clients. 
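Workers register themselves via /register_worker and report load through heart beats;
clients call /get_worker_address to obtain a worker serving a given model.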
4 | """ 5 | import argparse 6 | import dataclasses 7 | import threading 8 | import json 9 | import time 10 | import numpy as np 11 | import requests 12 | import uvicorn 13 | 14 | from typing import List 15 | from enum import Enum, auto 16 | from fastapi import FastAPI, Request 17 | from fastapi.responses import StreamingResponse 18 | 19 | from cerule.constants import CONTROLLER_HEART_BEAT_EXPIRATION 20 | from cerule.util.utils import build_logger, server_error_msg 21 | 22 | logger = build_logger("controller", "controller.log") 23 | 24 | 25 | class DispatchMethod(Enum): 26 | LOTTERY = auto() 27 | SHORTEST_QUEUE = auto() 28 | 29 | @classmethod 30 | def from_str(cls, name): 31 | if name == "lottery": 32 | return cls.LOTTERY 33 | elif name == "shortest_queue": 34 | return cls.SHORTEST_QUEUE 35 | else: 36 | raise ValueError(f"Invalid dispatch method") 37 | 38 | 39 | @dataclasses.dataclass 40 | class WorkerInfo: 41 | model_names: List[str] 42 | speed: int 43 | queue_length: int 44 | check_heart_beat: bool 45 | last_heart_beat: str 46 | 47 | 48 | def heart_beat_controller(controller): 49 | while True: 50 | time.sleep(CONTROLLER_HEART_BEAT_EXPIRATION) 51 | controller.remove_stable_workers_by_expiration() 52 | 53 | 54 | class Controller: 55 | def __init__(self, dispatch_method: str): 56 | # Dict[str -> WorkerInfo] 57 | self.worker_info = {} 58 | self.dispatch_method = DispatchMethod.from_str(dispatch_method) 59 | 60 | self.heart_beat_thread = threading.Thread( 61 | target=heart_beat_controller, args=(self,)) 62 | self.heart_beat_thread.start() 63 | 64 | logger.info("Init controller") 65 | 66 | def register_worker(self, worker_name: str, check_heart_beat: bool, 67 | worker_status: dict): 68 | if worker_name not in self.worker_info: 69 | logger.info(f"Register a new worker: {worker_name}") 70 | else: 71 | logger.info(f"Register an existing worker: {worker_name}") 72 | 73 | if not worker_status: 74 | worker_status = self.get_worker_status(worker_name) 75 | if not worker_status: 76 | return False 77 | 78 | self.worker_info[worker_name] = WorkerInfo( 79 | worker_status["model_names"], worker_status["speed"], worker_status["queue_length"], 80 | check_heart_beat, time.time()) 81 | 82 | logger.info(f"Register done: {worker_name}, {worker_status}") 83 | return True 84 | 85 | def get_worker_status(self, worker_name: str): 86 | try: 87 | r = requests.post(worker_name + "/worker_get_status", timeout=5) 88 | except requests.exceptions.RequestException as e: 89 | logger.error(f"Get status fails: {worker_name}, {e}") 90 | return None 91 | 92 | if r.status_code != 200: 93 | logger.error(f"Get status fails: {worker_name}, {r}") 94 | return None 95 | 96 | return r.json() 97 | 98 | def remove_worker(self, worker_name: str): 99 | del self.worker_info[worker_name] 100 | 101 | def refresh_all_workers(self): 102 | old_info = dict(self.worker_info) 103 | self.worker_info = {} 104 | 105 | for w_name, w_info in old_info.items(): 106 | if not self.register_worker(w_name, w_info.check_heart_beat, None): 107 | logger.info(f"Remove stale worker: {w_name}") 108 | 109 | def list_models(self): 110 | model_names = set() 111 | 112 | for w_name, w_info in self.worker_info.items(): 113 | model_names.update(w_info.model_names) 114 | 115 | return list(model_names) 116 | 117 | def get_worker_address(self, model_name: str): 118 | if self.dispatch_method == DispatchMethod.LOTTERY: 119 | worker_names = [] 120 | worker_speeds = [] 121 | for w_name, w_info in self.worker_info.items(): 122 | if model_name in w_info.model_names: 123 | 
worker_names.append(w_name) 124 | worker_speeds.append(w_info.speed) 125 | worker_speeds = np.array(worker_speeds, dtype=np.float32) 126 | norm = np.sum(worker_speeds) 127 | if norm < 1e-4: 128 | return "" 129 | worker_speeds = worker_speeds / norm 130 | 131 | pt = np.random.choice(np.arange(len(worker_names)), p=worker_speeds) 132 | worker_name = worker_names[pt] 133 | return worker_name 134 | 135 | elif self.dispatch_method == DispatchMethod.SHORTEST_QUEUE: 136 | worker_names = [] 137 | worker_qlen = [] 138 | for w_name, w_info in self.worker_info.items(): 139 | if model_name in w_info.model_names: 140 | worker_names.append(w_name) 141 | worker_qlen.append(w_info.queue_length / w_info.speed) 142 | if len(worker_names) == 0: 143 | return "" 144 | min_index = np.argmin(worker_qlen) 145 | w_name = worker_names[min_index] 146 | self.worker_info[w_name].queue_length += 1 147 | logger.info(f"names: {worker_names}, queue_lens: {worker_qlen}, ret: {w_name}") 148 | return w_name 149 | else: 150 | raise ValueError(f"Invalid dispatch method: {self.dispatch_method}") 151 | 152 | def receive_heart_beat(self, worker_name: str, queue_length: int): 153 | if worker_name not in self.worker_info: 154 | logger.info(f"Receive unknown heart beat. {worker_name}") 155 | return False 156 | 157 | self.worker_info[worker_name].queue_length = queue_length 158 | self.worker_info[worker_name].last_heart_beat = time.time() 159 | logger.info(f"Receive heart beat. {worker_name}") 160 | return True 161 | 162 | def remove_stable_workers_by_expiration(self): 163 | expire = time.time() - CONTROLLER_HEART_BEAT_EXPIRATION 164 | to_delete = [] 165 | for worker_name, w_info in self.worker_info.items(): 166 | if w_info.check_heart_beat and w_info.last_heart_beat < expire: 167 | to_delete.append(worker_name) 168 | 169 | for worker_name in to_delete: 170 | self.remove_worker(worker_name) 171 | 172 | def worker_api_generate_stream(self, params): 173 | worker_addr = self.get_worker_address(params["model"]) 174 | if not worker_addr: 175 | logger.info(f"no worker: {params['model']}") 176 | ret = { 177 | "text": server_error_msg, 178 | "error_code": 2, 179 | } 180 | yield json.dumps(ret).encode() + b"\0" 181 | 182 | try: 183 | response = requests.post(worker_addr + "/worker_generate_stream", 184 | json=params, stream=True, timeout=5) 185 | for chunk in response.iter_lines(decode_unicode=False, delimiter=b"\0"): 186 | if chunk: 187 | yield chunk + b"\0" 188 | except requests.exceptions.RequestException as e: 189 | logger.info(f"worker timeout: {worker_addr}") 190 | ret = { 191 | "text": server_error_msg, 192 | "error_code": 3, 193 | } 194 | yield json.dumps(ret).encode() + b"\0" 195 | 196 | # Let the controller act as a worker to achieve hierarchical 197 | # management. This can be used to connect isolated sub networks. 
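    # A hypothetical parent controller would treat this controller like any other
    # worker: it polls /worker_get_status and receives an aggregate such as
    #   {"model_names": ["cerule-v0.1"], "speed": 2, "queue_length": 3}
    # which the method below builds from every registered worker's status.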
198 | def worker_api_get_status(self): 199 | model_names = set() 200 | speed = 0 201 | queue_length = 0 202 | 203 | for w_name in self.worker_info: 204 | worker_status = self.get_worker_status(w_name) 205 | if worker_status is not None: 206 | model_names.update(worker_status["model_names"]) 207 | speed += worker_status["speed"] 208 | queue_length += worker_status["queue_length"] 209 | 210 | return { 211 | "model_names": list(model_names), 212 | "speed": speed, 213 | "queue_length": queue_length, 214 | } 215 | 216 | 217 | app = FastAPI() 218 | 219 | 220 | @app.post("/register_worker") 221 | async def register_worker(request: Request): 222 | data = await request.json() 223 | controller.register_worker( 224 | data["worker_name"], data["check_heart_beat"], 225 | data.get("worker_status", None)) 226 | 227 | 228 | @app.post("/refresh_all_workers") 229 | async def refresh_all_workers(): 230 | models = controller.refresh_all_workers() 231 | 232 | 233 | @app.post("/list_models") 234 | async def list_models(): 235 | models = controller.list_models() 236 | return {"models": models} 237 | 238 | 239 | @app.post("/get_worker_address") 240 | async def get_worker_address(request: Request): 241 | data = await request.json() 242 | addr = controller.get_worker_address(data["model"]) 243 | return {"address": addr} 244 | 245 | 246 | @app.post("/receive_heart_beat") 247 | async def receive_heart_beat(request: Request): 248 | data = await request.json() 249 | exist = controller.receive_heart_beat( 250 | data["worker_name"], data["queue_length"]) 251 | return {"exist": exist} 252 | 253 | 254 | @app.post("/worker_generate_stream") 255 | async def worker_api_generate_stream(request: Request): 256 | params = await request.json() 257 | generator = controller.worker_api_generate_stream(params) 258 | return StreamingResponse(generator) 259 | 260 | 261 | @app.post("/worker_get_status") 262 | async def worker_api_get_status(request: Request): 263 | return controller.worker_api_get_status() 264 | 265 | 266 | if __name__ == "__main__": 267 | parser = argparse.ArgumentParser() 268 | parser.add_argument("--host", type=str, default="localhost") 269 | parser.add_argument("--port", type=int, default=21001) 270 | parser.add_argument("--dispatch-method", type=str, choices=["lottery", "shortest_queue"], default="shortest_queue") 271 | args = parser.parse_args() 272 | logger.info(f"args: {args}") 273 | 274 | controller = Controller(args.dispatch_method) 275 | log_config = uvicorn.config.LOGGING_CONFIG 276 | log_config['handlers']['default']['stream'] = 'ext://sys.stdout' 277 | uvicorn.run(app, host=args.host, port=args.port, log_level="info") 278 | -------------------------------------------------------------------------------- /cerule/serve/examples/example_1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tensoic/Cerule/29ae3aa0d55fd76f8db2e787f9e23b6ceb9d756c/cerule/serve/examples/example_1.png -------------------------------------------------------------------------------- /cerule/serve/examples/example_2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tensoic/Cerule/29ae3aa0d55fd76f8db2e787f9e23b6ceb9d756c/cerule/serve/examples/example_2.png -------------------------------------------------------------------------------- /cerule/serve/gradio_web_server.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import datetime 3 | import 
json 4 | import os 5 | import time 6 | import gradio as gr 7 | import requests 8 | import hashlib 9 | import pypandoc 10 | import base64 11 | 12 | from io import BytesIO 13 | 14 | from cerule.conversation import (default_conversation, conv_templates, SeparatorStyle) 15 | from cerule.constants import LOGDIR 16 | from cerule.util.utils import (build_logger, server_error_msg, violates_moderation, moderation_msg) 17 | 18 | logger = build_logger("gradio_web_server", "gradio_web_server.log") 19 | 20 | headers = {"User-Agent": "cerule Client"} 21 | 22 | no_change_btn = gr.update() 23 | enable_btn = gr.update(interactive=True) 24 | disable_btn = gr.update(interactive=False) 25 | 26 | priority = { 27 | "cerule": "aaaaaaa", 28 | } 29 | 30 | 31 | def get_conv_log_filename(): 32 | t = datetime.datetime.now() 33 | name = os.path.join(LOGDIR, f"{t.year}-{t.month:02d}-{t.day:02d}-conv.json") 34 | return name 35 | 36 | 37 | def get_model_list(): 38 | ret = requests.post(args.controller_url + "/refresh_all_workers") 39 | assert ret.status_code == 200 40 | ret = requests.post(args.controller_url + "/list_models") 41 | models = ret.json()["models"] 42 | models.sort(key=lambda x: priority.get(x, x)) 43 | logger.info(f"Models: {models}") 44 | return models 45 | 46 | 47 | get_window_url_params = """ 48 | function() { 49 | const params = new URLSearchParams(window.location.search); 50 | url_params = Object.fromEntries(params); 51 | console.log(url_params); 52 | return url_params; 53 | } 54 | """ 55 | 56 | 57 | def load_demo(url_params, request: gr.Request): 58 | logger.info(f"load_demo. ip: {request.client.host}. params: {url_params}") 59 | 60 | dropdown_update = gr.update(visible=True) 61 | if "model" in url_params: 62 | model = url_params["model"] 63 | if model in models: 64 | dropdown_update = gr.update( 65 | value=model, visible=True) 66 | 67 | state = default_conversation.copy() 68 | return state, dropdown_update 69 | 70 | 71 | def load_demo_refresh_model_list(request: gr.Request): 72 | logger.info(f"load_demo. ip: {request.client.host}") 73 | models = get_model_list() 74 | state = default_conversation.copy() 75 | dropdown_update = gr.update( 76 | choices=models, 77 | value=models[0] if len(models) > 0 else "" 78 | ) 79 | return state, dropdown_update 80 | 81 | 82 | def vote_last_response(state, vote_type, model_selector, request: gr.Request): 83 | with open(get_conv_log_filename(), "a") as fout: 84 | data = { 85 | "tstamp": round(time.time(), 4), 86 | "type": vote_type, 87 | "model": model_selector, 88 | "state": state.dict(), 89 | "ip": request.client.host, 90 | } 91 | fout.write(json.dumps(data) + "\n") 92 | 93 | 94 | def upvote_last_response(state, model_selector, request: gr.Request): 95 | logger.info(f"upvote. ip: {request.client.host}") 96 | vote_last_response(state, "upvote", model_selector, request) 97 | return ("",) + (disable_btn,) * 3 98 | 99 | 100 | def downvote_last_response(state, model_selector, request: gr.Request): 101 | logger.info(f"downvote. ip: {request.client.host}") 102 | vote_last_response(state, "downvote", model_selector, request) 103 | return ("",) + (disable_btn,) * 3 104 | 105 | 106 | def flag_last_response(state, model_selector, request: gr.Request): 107 | logger.info(f"flag. ip: {request.client.host}") 108 | vote_last_response(state, "flag", model_selector, request) 109 | return ("",) + (disable_btn,) * 3 110 | 111 | 112 | def regenerate(state, image_process_mode, request: gr.Request): 113 | logger.info(f"regenerate. 
ip: {request.client.host}") 114 | state.messages[-1][-1] = None 115 | prev_human_msg = state.messages[-2] 116 | if type(prev_human_msg[1]) in (tuple, list): 117 | prev_human_msg[1] = (*prev_human_msg[1][:2], image_process_mode) 118 | state.skip_next = False 119 | return (state, state.to_gradio_chatbot(), "", None) + (disable_btn,) * 3 120 | 121 | 122 | def clear_history(request: gr.Request): 123 | logger.info(f"clear_history. ip: {request.client.host}") 124 | state = default_conversation.copy() 125 | return (state, state.to_gradio_chatbot(), "", None) + (disable_btn,) * 3 126 | 127 | 128 | def save_conversation(conversation): 129 | print("save_conversation_wrapper is called") 130 | html_content = "" 131 | 132 | for role, message in conversation.messages: 133 | if isinstance(message, str): # only text 134 | html_content += f"

<p><strong>{role}:</strong> {message}</p>

" 135 | elif isinstance(message, tuple): # text+image 136 | text, image_obj, _ = message 137 | 138 | # add text 139 | if text: 140 | html_content += f"

<p><strong>{role}:</strong> {text}</p>

" 141 | 142 | # add image 143 | buffered = BytesIO() 144 | image_obj.save(buffered, format="PNG") 145 | encoded_image = base64.b64encode(buffered.getvalue()).decode() 146 | html_content += f'
' 147 | 148 | html_content += "" 149 | 150 | doc_path = "tmp/conversation.docx" 151 | pypandoc.convert_text(html_content, 'docx', format='html', outputfile=doc_path) 152 | return doc_path 153 | 154 | 155 | def add_text(state, text, image, image_process_mode, request: gr.Request): 156 | logger.info(f"add_text. ip: {request.client.host}. len: {len(text)}") 157 | if len(text) <= 0 and image is None: 158 | state.skip_next = True 159 | return (state, state.to_gradio_chatbot(), "", None) + (no_change_btn,) * 3 160 | if args.moderate: 161 | flagged = violates_moderation(text) 162 | if flagged: 163 | state.skip_next = True 164 | return (state, state.to_gradio_chatbot(), moderation_msg, None) + ( 165 | no_change_btn,) * 3 166 | 167 | text = text[:1536] # Hard cut-off 168 | if image is not None: 169 | text = text[:1200] # Hard cut-off for images 170 | if '' not in text: 171 | # text = '' + text 172 | text = text + '\n' 173 | text = (text, image, image_process_mode) 174 | if len(state.get_images(return_pil=True)) > 0: 175 | state = default_conversation.copy() 176 | logger.info(f"Input Text: {text}") 177 | state.append_message(state.roles[0], text) 178 | state.append_message(state.roles[1], None) 179 | state.skip_next = False 180 | return (state, state.to_gradio_chatbot(), "", None) + (disable_btn,) * 3 181 | 182 | 183 | def http_bot(state, model_selector, temperature, top_p, max_new_tokens, request: gr.Request): 184 | logger.info(f"http_bot. ip: {request.client.host}") 185 | start_tstamp = time.time() 186 | model_name = model_selector 187 | 188 | if state.skip_next: 189 | # This generate call is skipped due to invalid inputs 190 | yield (state, state.to_gradio_chatbot()) + (no_change_btn,) * 3 191 | return 192 | 193 | if len(state.messages) == state.offset + 2: 194 | template_name = "cerule" 195 | new_state = conv_templates[template_name].copy() 196 | new_state.append_message(new_state.roles[0], state.messages[-2][1]) 197 | new_state.append_message(new_state.roles[1], None) 198 | state = new_state 199 | 200 | logger.info(f"Processed Input Text: {state.messages[-2][1]}") 201 | # Query worker address 202 | controller_url = args.controller_url 203 | ret = requests.post(controller_url + "/get_worker_address", 204 | json={"model": model_name}) 205 | worker_addr = ret.json()["address"] 206 | logger.info(f"model_name: {model_name}, worker_addr: {worker_addr}") 207 | 208 | # No available worker 209 | if worker_addr == "": 210 | state.messages[-1][-1] = server_error_msg 211 | yield (state, state.to_gradio_chatbot(), enable_btn, enable_btn, enable_btn) 212 | return 213 | 214 | # Construct prompt 215 | prompt = state.get_prompt() 216 | 217 | all_images = state.get_images(return_pil=True) 218 | all_image_hash = [hashlib.md5(image.tobytes()).hexdigest() for image in all_images] 219 | for image, hash in zip(all_images, all_image_hash): 220 | t = datetime.datetime.now() 221 | filename = os.path.join(LOGDIR, "serve_images", f"{t.year}-{t.month:02d}-{t.day:02d}", f"{hash}.jpg") 222 | if not os.path.isfile(filename): 223 | os.makedirs(os.path.dirname(filename), exist_ok=True) 224 | image.save(filename) 225 | 226 | # Make requests 227 | pload = { 228 | "model": model_name, 229 | "prompt": prompt, 230 | "temperature": float(temperature), 231 | "top_p": float(top_p), 232 | "max_new_tokens": min(int(max_new_tokens), 1536), 233 | "stop": state.sep if state.sep_style in [SeparatorStyle.PLAIN, ] else state.sep2, 234 | "images": f'List of {len(state.get_images())} images: {all_image_hash}', 235 | } 236 | logger.info(f"==== 
request ====\n{pload}") 237 | 238 | pload['images'] = state.get_images() 239 | print('=========> get_images') 240 | state.messages[-1][-1] = "▌" 241 | yield (state, state.to_gradio_chatbot()) + (disable_btn,) * 3 242 | print('=========> state', state.messages[-1][-1]) 243 | 244 | try: 245 | # Stream output 246 | response = requests.post(worker_addr + "/worker_generate_stream", 247 | headers=headers, json=pload, stream=True, timeout=1000) 248 | print("====> response ok") 249 | print("====> response dir", dir(response)) 250 | print("====> response", response) 251 | for chunk in response.iter_lines(decode_unicode=False, delimiter=b"\0"): 252 | if chunk: 253 | data = json.loads(chunk.decode()) 254 | if data["error_code"] == 0: 255 | output = data["text"][len(prompt):].strip() 256 | state.messages[-1][-1] = output + "▌" 257 | yield (state, state.to_gradio_chatbot()) + (disable_btn,) * 3 258 | else: 259 | output = data["text"] + f" (error_code: {data['error_code']})" 260 | state.messages[-1][-1] = output 261 | yield (state, state.to_gradio_chatbot()) + (enable_btn, enable_btn, enable_btn) 262 | return 263 | time.sleep(0.03) 264 | except requests.exceptions.RequestException as e: 265 | state.messages[-1][-1] = server_error_msg 266 | yield (state, state.to_gradio_chatbot()) + (enable_btn, enable_btn, enable_btn) 267 | return 268 | 269 | state.messages[-1][-1] = state.messages[-1][-1][:-1] 270 | yield (state, state.to_gradio_chatbot()) + (enable_btn,) * 3 271 | 272 | finish_tstamp = time.time() 273 | logger.info(f"{output}") 274 | 275 | with open(get_conv_log_filename(), "a") as fout: 276 | data = { 277 | "tstamp": round(finish_tstamp, 4), 278 | "type": "chat", 279 | "model": model_name, 280 | "start": round(start_tstamp, 4), 281 | "finish": round(finish_tstamp, 4), 282 | "state": state.dict(), 283 | "images": all_image_hash, 284 | "ip": request.client.host, 285 | } 286 | fout.write(json.dumps(data) + "\n") 287 | 288 | 289 | title_markdown = (""" 290 | # Cerule! 291 | """) 292 | 293 | tos_markdown = (""" 294 | ### Terms of use 295 | lol 296 | """) 297 | 298 | learn_more_markdown = (""" 299 | ### License 300 | non-commercial use only, Please contact us on X/Linkedin @tensoic if you find any potential violation. 
301 | """) 302 | 303 | block_css = """ 304 | 305 | #buttons button { 306 | min-width: min(120px,100%); 307 | } 308 | 309 | """ 310 | 311 | 312 | def trigger_download(doc_path): 313 | return doc_path 314 | 315 | 316 | def build_demo(embed_mode): 317 | textbox = gr.Textbox(show_label=False, placeholder="Enter text and press ENTER", container=False) 318 | with gr.Blocks(title="cerule", theme=gr.themes.Default(), css=block_css) as demo: 319 | state = gr.State() 320 | 321 | if not embed_mode: 322 | gr.Markdown(title_markdown) 323 | 324 | with gr.Row(): 325 | with gr.Column(scale=4): 326 | with gr.Row(elem_id="model_selector_row"): 327 | model_selector = gr.Dropdown( 328 | choices=models, 329 | value=models[0] if len(models) > 0 else "", 330 | interactive=True, 331 | show_label=False, 332 | container=False, 333 | allow_custom_value=True 334 | ) 335 | 336 | imagebox = gr.Image(type="pil") 337 | image_process_mode = gr.Radio( 338 | ["Crop", "Resize", "Pad", "Default"], 339 | value="Default", 340 | label="Preprocess for non-square image", visible=False) 341 | 342 | cur_dir = os.path.dirname(os.path.abspath(__file__)) 343 | gr.Examples(examples=[ 344 | [f"{cur_dir}/examples/example_1.png", "What is the astronaut holding in his hand?"], 345 | [f"{cur_dir}/examples/example_2.png", "Why is the image funny?"], 346 | ], inputs=[imagebox, textbox]) 347 | 348 | with gr.Accordion("Parameters", open=False) as parameter_row: 349 | temperature = gr.Slider(minimum=0.0, maximum=1.0, value=0.2, step=0.1, interactive=True, 350 | label="Temperature", ) 351 | top_p = gr.Slider(minimum=0.0, maximum=1.0, value=0.7, step=0.1, interactive=True, label="Top P", ) 352 | max_output_tokens = gr.Slider(minimum=0, maximum=1024, value=512, step=64, interactive=True, 353 | label="Max output tokens", ) 354 | 355 | file_output = gr.components.File(label="Download Document", visible=True, 356 | elem_id="file") # , visible=True,elem_id="file-output" 357 | 358 | with gr.Column(scale=8): 359 | chatbot = gr.Chatbot(elem_id="chatbot", label="cerule Bot", height=550) 360 | with gr.Row(): 361 | with gr.Column(scale=8): 362 | textbox.render() 363 | with gr.Column(scale=1, min_width=50): 364 | submit_btn = gr.Button(value="Send", variant="primary") 365 | 366 | with gr.Row(elem_id="buttons") as button_row: 367 | # stop_btn = gr.Button(value="⏹️ Stop Generation", interactive=False) 368 | regenerate_btn = gr.Button(value="🔁 Regenerate", interactive=False) 369 | clear_btn = gr.Button(value="🚮 Clear", interactive=False) 370 | save_conversation_btn = gr.Button(value="🗃️ Save Conversation", interactive=False) 371 | 372 | if not embed_mode: 373 | gr.Markdown(tos_markdown) 374 | gr.Markdown(learn_more_markdown) 375 | url_params = gr.JSON(visible=False) 376 | 377 | # Register listeners 378 | btn_list = [regenerate_btn, clear_btn, save_conversation_btn] 379 | 380 | regenerate_btn.click( 381 | regenerate, 382 | [state, image_process_mode], 383 | [state, chatbot, textbox, imagebox] + btn_list, 384 | queue=False 385 | ).then( 386 | http_bot, 387 | [state, model_selector, temperature, top_p, max_output_tokens], 388 | [state, chatbot] + btn_list 389 | ) 390 | 391 | clear_btn.click( 392 | clear_history, 393 | None, 394 | [state, chatbot, textbox, imagebox] + btn_list, 395 | queue=False 396 | ) 397 | 398 | save_conversation_btn.click( 399 | save_conversation, 400 | inputs=[state], 401 | outputs=file_output 402 | ) 403 | 404 | textbox.submit( 405 | add_text, 406 | [state, textbox, imagebox, image_process_mode], 407 | [state, chatbot, textbox, imagebox] + 
btn_list, 408 | queue=False 409 | ).then( 410 | http_bot, 411 | [state, model_selector, temperature, top_p, max_output_tokens], 412 | [state, chatbot] + btn_list 413 | ) 414 | 415 | submit_btn.click( 416 | add_text, 417 | [state, textbox, imagebox, image_process_mode], 418 | [state, chatbot, textbox, imagebox] + btn_list, 419 | queue=False 420 | ).then( 421 | http_bot, 422 | [state, model_selector, temperature, top_p, max_output_tokens], 423 | [state, chatbot] + btn_list 424 | ) 425 | 426 | if args.model_list_mode == "once": 427 | demo.load( 428 | load_demo, 429 | [url_params], 430 | [state, model_selector], 431 | _js=get_window_url_params, 432 | queue=False 433 | ) 434 | elif args.model_list_mode == "reload": 435 | demo.load( 436 | load_demo_refresh_model_list, 437 | None, 438 | [state, model_selector], 439 | queue=False 440 | ) 441 | else: 442 | raise ValueError(f"Unknown model list mode: {args.model_list_mode}") 443 | 444 | return demo 445 | 446 | 447 | if __name__ == "__main__": 448 | parser = argparse.ArgumentParser() 449 | parser.add_argument("--host", type=str, default="127.0.0.1") 450 | parser.add_argument("--port", type=int) 451 | parser.add_argument("--controller-url", type=str, default="http://localhost:21001") 452 | parser.add_argument("--concurrency-count", type=int, default=10) 453 | parser.add_argument("--model-list-mode", type=str, default="once", 454 | choices=["once", "reload"]) 455 | parser.add_argument("--share", action="store_true") 456 | parser.add_argument("--moderate", action="store_true") 457 | parser.add_argument("--embed", action="store_true") 458 | args = parser.parse_args() 459 | logger.info(f"args: {args}") 460 | 461 | models = get_model_list() 462 | logger.info(args) 463 | demo = build_demo(args.embed) 464 | demo.launch( 465 | server_name=args.host, 466 | server_port=args.port, 467 | share=args.share, 468 | debug=True, 469 | max_threads=10 470 | ) 471 | -------------------------------------------------------------------------------- /cerule/serve/model_worker.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import asyncio 3 | import json 4 | import time 5 | import threading 6 | import uuid 7 | import requests 8 | import torch 9 | import uvicorn 10 | import transformers 11 | 12 | from fastapi import FastAPI, Request, BackgroundTasks 13 | from fastapi.responses import StreamingResponse 14 | from functools import partial 15 | from transformers import TextIteratorStreamer 16 | from threading import Thread 17 | 18 | from cerule.constants import WORKER_HEART_BEAT_INTERVAL 19 | from cerule.util.utils import (build_logger, server_error_msg, pretty_print_semaphore) 20 | from cerule.model.builder import load_pretrained_model 21 | from cerule.util.mm_utils import process_images, load_image_from_base64, tokenizer_image_token, get_model_name_from_path, \ 22 | KeywordsStoppingCriteria 23 | from cerule.constants import IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN 24 | 25 | GB = 1 << 30 26 | 27 | worker_id = str(uuid.uuid4())[:6] 28 | logger = build_logger("model_worker", f"model_worker_{worker_id}.log") 29 | global_counter = 0 30 | 31 | model_semaphore = None 32 | 33 | 34 | def heart_beat_worker(controller): 35 | while True: 36 | time.sleep(WORKER_HEART_BEAT_INTERVAL) 37 | controller.send_heart_beat() 38 | 39 | 40 | class ModelWorker: 41 | def __init__(self, controller_addr, worker_addr, 42 | worker_id, no_register, 43 | model_path, model_base, model_name, model_type, 44 | load_8bit, load_4bit, device): 45 | 
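        # Worker setup: normalize the model path, derive a model name from the
        # path when one is not given, load the tokenizer, model, image processor
        # and context length through load_pretrained_model, and, unless
        # no_register is set, register with the controller and start the
        # background heartbeat thread.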
self.controller_addr = controller_addr 46 | self.worker_addr = worker_addr 47 | self.worker_id = worker_id 48 | if model_path.endswith("/"): 49 | model_path = model_path[:-1] 50 | if model_name is None: 51 | self.model_name = get_model_name_from_path(model_path) 52 | else: 53 | self.model_name = model_name 54 | 55 | self.device = device 56 | logger.info(f"Loading the model {self.model_name} on worker {worker_id} ...") 57 | transformers.logging.set_verbosity_error() 58 | transformers.logging.disable_progress_bar() 59 | self.tokenizer, self.model, self.image_processor, self.context_len = load_pretrained_model( 60 | model_path, model_base, self.model_name, model_type, load_8bit, load_4bit, device=self.device) 61 | self.is_multimodal = True 62 | 63 | if not no_register: 64 | self.register_to_controller() 65 | self.heart_beat_thread = threading.Thread( 66 | target=heart_beat_worker, args=(self,)) 67 | self.heart_beat_thread.start() 68 | 69 | def register_to_controller(self): 70 | logger.info("Register to controller") 71 | 72 | url = self.controller_addr + "/register_worker" 73 | data = { 74 | "worker_name": self.worker_addr, 75 | "check_heart_beat": True, 76 | "worker_status": self.get_status() 77 | } 78 | r = requests.post(url, json=data) 79 | assert r.status_code == 200 80 | 81 | def send_heart_beat(self): 82 | logger.info(f"Send heart beat. Models: {[self.model_name]}. " 83 | f"Semaphore: {pretty_print_semaphore(model_semaphore)}. " 84 | f"global_counter: {global_counter}") 85 | 86 | url = self.controller_addr + "/receive_heart_beat" 87 | 88 | while True: 89 | try: 90 | ret = requests.post(url, json={ 91 | "worker_name": self.worker_addr, 92 | "queue_length": self.get_queue_length()}, timeout=5) 93 | exist = ret.json()["exist"] 94 | break 95 | except requests.exceptions.RequestException as e: 96 | logger.error(f"heart beat error: {e}") 97 | time.sleep(5) 98 | 99 | if not exist: 100 | self.register_to_controller() 101 | 102 | def get_queue_length(self): 103 | if model_semaphore is None: 104 | return 0 105 | else: 106 | return args.limit_model_concurrency - model_semaphore._value + (len( 107 | model_semaphore._waiters) if model_semaphore._waiters is not None else 0) 108 | 109 | def get_status(self): 110 | return { 111 | "model_names": [self.model_name], 112 | "speed": 1, 113 | "queue_length": self.get_queue_length(), 114 | } 115 | 116 | @torch.inference_mode() 117 | def generate_stream(self, params): 118 | tokenizer, model, image_processor = self.tokenizer, self.model, self.image_processor 119 | 120 | prompt = params["prompt"] 121 | ori_prompt = prompt 122 | images = params.get("images", None) 123 | num_image_tokens = 0 124 | if images is not None and len(images) > 0 and self.is_multimodal: 125 | if len(images) > 0: 126 | if len(images) != prompt.count(DEFAULT_IMAGE_TOKEN): 127 | raise ValueError("Number of images does not match number of tokens in prompt") 128 | 129 | images = [load_image_from_base64(image) for image in images] 130 | images = process_images(images, image_processor, model.config) 131 | print(f"----> process_images {images}") 132 | print(f"----> process_images sum {torch.sum(images)}") 133 | if type(images) is list: 134 | images = [image.to(self.model.device, dtype=model.dtype) for image in images] 135 | else: 136 | images = images.to(self.model.device, dtype=model.dtype) 137 | 138 | replace_token = DEFAULT_IMAGE_TOKEN 139 | prompt = prompt.replace(DEFAULT_IMAGE_TOKEN, replace_token) 140 | 141 | num_image_tokens = prompt.count(replace_token) * 
model.get_vision_tower().num_patches 142 | else: 143 | images = None 144 | image_args = {"images": images} 145 | else: 146 | images = None 147 | image_args = {} 148 | 149 | temperature = float(params.get("temperature", 1.0)) 150 | top_p = float(params.get("top_p", 1.0)) 151 | max_context_length = getattr(model.config, 'max_position_embeddings', 2048) 152 | max_new_tokens = min(int(params.get("max_new_tokens", 256)), 1024) 153 | stop_str = params.get("stop", None) 154 | do_sample = True if temperature > 0.001 else False 155 | 156 | input_ids = tokenizer_image_token(prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors='pt').unsqueeze(0).to( 157 | self.device) 158 | keywords = [stop_str] 159 | stopping_criteria = KeywordsStoppingCriteria(keywords, tokenizer, input_ids) 160 | streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True, timeout=15) 161 | 162 | max_new_tokens = min(max_new_tokens, max_context_length - input_ids.shape[-1] - num_image_tokens) 163 | 164 | if max_new_tokens < 1: 165 | yield json.dumps({"text": ori_prompt + "Exceeds max token length. Please start a new conversation, thanks.", 166 | "error_code": 0}).encode() + b"\0" 167 | return 168 | print("max_new_tokens", max_new_tokens) 169 | print("start!") 170 | 171 | thread = Thread(target=model.generate, kwargs=dict( 172 | inputs=input_ids, 173 | do_sample=do_sample, 174 | temperature=temperature, 175 | top_p=top_p, 176 | max_new_tokens=max_new_tokens, 177 | streamer=streamer, 178 | stopping_criteria=[stopping_criteria], 179 | use_cache=True, 180 | **image_args 181 | )) 182 | thread.start() 183 | 184 | generated_text = ori_prompt 185 | for new_text in streamer: 186 | if generated_text and not generated_text.endswith(' '): 187 | generated_text += ' ' 188 | generated_text += new_text 189 | if generated_text.endswith(stop_str): 190 | generated_text = generated_text[:-len(stop_str)] 191 | logger.info(f"new_text: {new_text}") 192 | yield json.dumps({"text": generated_text, "error_code": 0}).encode() + b"\0" 193 | 194 | def generate_stream_gate(self, params): 195 | try: 196 | for x in self.generate_stream(params): 197 | yield x 198 | except ValueError as e: 199 | print("Caught ValueError:", e) 200 | ret = { 201 | "text": server_error_msg, 202 | "error_code": 1, 203 | } 204 | yield json.dumps(ret).encode() + b"\0" 205 | except torch.cuda.CudaError as e: 206 | print("Caught torch.cuda.CudaError:", e) 207 | ret = { 208 | "text": server_error_msg, 209 | "error_code": 1, 210 | } 211 | yield json.dumps(ret).encode() + b"\0" 212 | except Exception as e: 213 | print("Caught Unknown Error", e) 214 | ret = { 215 | "text": server_error_msg, 216 | "error_code": 1, 217 | } 218 | yield json.dumps(ret).encode() + b"\0" 219 | 220 | 221 | app = FastAPI() 222 | 223 | 224 | def release_model_semaphore(fn=None): 225 | model_semaphore.release() 226 | if fn is not None: 227 | fn() 228 | 229 | 230 | @app.post("/worker_generate_stream") 231 | async def generate_stream(request: Request): 232 | global model_semaphore, global_counter 233 | global_counter += 1 234 | params = await request.json() 235 | 236 | if model_semaphore is None: 237 | model_semaphore = asyncio.Semaphore(args.limit_model_concurrency) 238 | await model_semaphore.acquire() 239 | worker.send_heart_beat() 240 | generator = worker.generate_stream_gate(params) 241 | background_tasks = BackgroundTasks() 242 | background_tasks.add_task(partial(release_model_semaphore, fn=worker.send_heart_beat)) 243 | return StreamingResponse(generator, background=background_tasks) 244 
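# Minimal client sketch for the streaming endpoint above. It is illustrative only and
# not used by the worker itself; gradio_web_server.py consumes the stream in the same
# way. The function name, defaults and example payload values here are assumptions.
# "images" must be base64-encoded images whose count matches the image placeholder
# tokens (DEFAULT_IMAGE_TOKEN) in the prompt, and "stop" should be the active
# conversation template's separator string.
def example_stream_client(prompt, images, stop, model_name="cerule",
                          worker_addr="http://localhost:21002"):
    pload = {
        "model": model_name,
        "prompt": prompt,
        "temperature": 0.2,
        "top_p": 0.7,
        "max_new_tokens": 512,
        "stop": stop,
        "images": images,
    }
    resp = requests.post(worker_addr + "/worker_generate_stream",
                         json=pload, stream=True, timeout=1000)
    # Chunks are JSON objects delimited by null bytes; "text" contains the prompt
    # followed by the generation so far, which callers such as the Gradio server
    # strip with data["text"][len(prompt):].
    for chunk in resp.iter_lines(decode_unicode=False, delimiter=b"\0"):
        if chunk:
            data = json.loads(chunk.decode())
            if data["error_code"] != 0:
                raise RuntimeError(data["text"])
            yield data["text"]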
| 245 | 246 | @app.post("/worker_get_status") 247 | async def get_status(request: Request): 248 | return worker.get_status() 249 | 250 | 251 | if __name__ == "__main__": 252 | parser = argparse.ArgumentParser() 253 | parser.add_argument("--host", type=str, default="localhost") 254 | parser.add_argument("--port", type=int, default=21002) 255 | parser.add_argument("--worker-address", type=str, 256 | default="http://localhost:21002") 257 | parser.add_argument("--controller-address", type=str, 258 | default="http://localhost:21001") 259 | parser.add_argument("--model-path", type=str, default=None) 260 | parser.add_argument("--model-base", type=str, default=None) 261 | parser.add_argument("--model-name", type=str) 262 | parser.add_argument("--model-type", type=str, default=None) 263 | parser.add_argument("--device", type=str, default="cuda") 264 | parser.add_argument("--multi-modal", action="store_true", 265 | help="Multimodal mode is automatically detected with model name.") 266 | parser.add_argument("--limit-model-concurrency", type=int, default=5) 267 | parser.add_argument("--stream-interval", type=int, default=1) 268 | parser.add_argument("--no-register", action="store_true") 269 | parser.add_argument("--load-8bit", action="store_true") 270 | parser.add_argument("--load-4bit", action="store_true") 271 | args = parser.parse_args() 272 | logger.info(f"args: {args}") 273 | 274 | if args.multi_modal: 275 | logger.warning("Multimodal mode is automatically detected with model name.") 276 | 277 | worker = ModelWorker(args.controller_address, 278 | args.worker_address, 279 | worker_id, 280 | args.no_register, 281 | args.model_path, 282 | args.model_base, 283 | args.model_name, 284 | args.model_type, 285 | args.load_8bit, 286 | args.load_4bit, 287 | args.device) 288 | 289 | log_config = uvicorn.config.LOGGING_CONFIG 290 | log_config['handlers']['default']['stream'] = 'ext://sys.stdout' 291 | uvicorn.run(app, host=args.host, port=args.port, log_level="info") 292 | -------------------------------------------------------------------------------- /cerule/serve/register_worker.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import requests 3 | 4 | if __name__ == "__main__": 5 | parser = argparse.ArgumentParser() 6 | parser.add_argument("--controller-address", type=str) 7 | parser.add_argument("--worker-name", type=str) 8 | parser.add_argument("--check-heart-beat", action="store_true") 9 | args = parser.parse_args() 10 | 11 | url = args.controller_address + "/register_worker" 12 | data = { 13 | "worker_name": args.worker_name, 14 | "check_heart_beat": args.check_heart_beat, 15 | "worker_status": None, 16 | } 17 | r = requests.post(url, json=data) 18 | assert r.status_code == 200 19 | -------------------------------------------------------------------------------- /cerule/train/cerule_trainer.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | 4 | from torch.utils.data import Sampler 5 | from torch import nn 6 | from transformers import Trainer 7 | from transformers.trainer import is_sagemaker_mp_enabled, get_parameter_names, has_length, ALL_LAYERNORM_LAYERS, logger 8 | 9 | from typing import List, Optional 10 | 11 | 12 | def maybe_zero_3(param, ignore_status=False, name=None): 13 | from deepspeed import zero 14 | from deepspeed.runtime.zero.partition_parameters import ZeroParamStatus 15 | if hasattr(param, "ds_id"): 16 | if param.ds_status == ZeroParamStatus.NOT_AVAILABLE: 17 | 
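            # Under DeepSpeed ZeRO-3 this rank only holds a partition of the
            # parameter; zero.GatheredParameters temporarily materializes the
            # full tensor so it can be detached and copied to CPU below.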
if not ignore_status: 18 | print(name, 'no ignore status') 19 | with zero.GatheredParameters([param]): 20 | param = param.data.detach().cpu().clone() 21 | else: 22 | param = param.detach().cpu().clone() 23 | return param 24 | 25 | 26 | def get_mm_adapter_state_maybe_zero_3(named_params, keys_to_match): 27 | to_return = {k: t for k, t in named_params if any(key_match in k for key_match in keys_to_match)} 28 | to_return = {k: maybe_zero_3(v, ignore_status=True, name=k).cpu() for k, v in to_return.items()} 29 | return to_return 30 | 31 | 32 | def split_to_even_chunks(indices, lengths, num_chunks): 33 | """ 34 | Split a list of indices into `chunks` chunks of roughly equal lengths. 35 | """ 36 | 37 | if len(indices) % num_chunks != 0: 38 | return [indices[i::num_chunks] for i in range(num_chunks)] 39 | 40 | num_indices_per_chunk = len(indices) // num_chunks 41 | 42 | chunks = [[] for _ in range(num_chunks)] 43 | chunks_lengths = [0 for _ in range(num_chunks)] 44 | for index in indices: 45 | shortest_chunk = chunks_lengths.index(min(chunks_lengths)) 46 | chunks[shortest_chunk].append(index) 47 | chunks_lengths[shortest_chunk] += lengths[index] 48 | if len(chunks[shortest_chunk]) == num_indices_per_chunk: 49 | chunks_lengths[shortest_chunk] = float("inf") 50 | 51 | return chunks 52 | 53 | 54 | def get_modality_length_grouped_indices(lengths, batch_size, world_size, generator=None): 55 | # We need to use torch for the random part as a distributed sampler will set the random seed for torch. 56 | assert all(l != 0 for l in lengths), "Should not have zero length." 57 | if all(l > 0 for l in lengths) or all(l < 0 for l in lengths): 58 | # all samples are in the same modality 59 | return get_length_grouped_indices(lengths, batch_size, world_size, generator=generator) 60 | mm_indices, mm_lengths = zip(*[(i, l) for i, l in enumerate(lengths) if l > 0]) 61 | lang_indices, lang_lengths = zip(*[(i, -l) for i, l in enumerate(lengths) if l < 0]) 62 | 63 | mm_shuffle = [mm_indices[i] for i in get_length_grouped_indices(mm_lengths, batch_size, world_size, generator=None)] 64 | lang_shuffle = [lang_indices[i] for i in 65 | get_length_grouped_indices(lang_lengths, batch_size, world_size, generator=None)] 66 | megabatch_size = world_size * batch_size 67 | mm_megabatches = [mm_shuffle[i: i + megabatch_size] for i in range(0, len(mm_shuffle), megabatch_size)] 68 | lang_megabatches = [lang_shuffle[i: i + megabatch_size] for i in range(0, len(lang_shuffle), megabatch_size)] 69 | 70 | last_mm = mm_megabatches[-1] 71 | last_lang = lang_megabatches[-1] 72 | additional_batch = last_mm + last_lang 73 | megabatches = mm_megabatches[:-1] + lang_megabatches[:-1] 74 | megabatch_indices = torch.randperm(len(megabatches), generator=generator) 75 | megabatches = [megabatches[i] for i in megabatch_indices] 76 | 77 | if len(additional_batch) > 0: 78 | megabatches.append(sorted(additional_batch)) 79 | 80 | return [i for megabatch in megabatches for i in megabatch] 81 | 82 | 83 | def get_length_grouped_indices(lengths, batch_size, world_size, generator=None, merge=True): 84 | # We need to use torch for the random part as a distributed sampler will set the random seed for torch. 
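    # Grouping scheme used below: shuffle all indices, slice them into
    # "megabatches" of world_size * batch_size samples, sort each megabatch by
    # length (longest first), then split every megabatch into world_size chunks
    # of roughly equal total length so each rank receives a similarly sized batch.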
85 | indices = torch.randperm(len(lengths), generator=generator) 86 | megabatch_size = world_size * batch_size 87 | megabatches = [indices[i: i + megabatch_size].tolist() for i in range(0, len(lengths), megabatch_size)] 88 | megabatches = [sorted(megabatch, key=lambda i: lengths[i], reverse=True) for megabatch in megabatches] 89 | megabatches = [split_to_even_chunks(megabatch, lengths, world_size) for megabatch in megabatches] 90 | 91 | return [i for megabatch in megabatches for batch in megabatch for i in batch] 92 | 93 | 94 | class LengthGroupedSampler(Sampler): 95 | r""" 96 | Sampler that samples indices in a way that groups together features of the dataset of roughly the same length while 97 | keeping a bit of randomness. 98 | """ 99 | 100 | def __init__( 101 | self, 102 | batch_size: int, 103 | world_size: int, 104 | lengths: Optional[List[int]] = None, 105 | generator=None, 106 | group_by_modality: bool = False, 107 | ): 108 | if lengths is None: 109 | raise ValueError("Lengths must be provided.") 110 | 111 | self.batch_size = batch_size 112 | self.world_size = world_size 113 | self.lengths = lengths 114 | self.generator = generator 115 | self.group_by_modality = group_by_modality 116 | 117 | def __len__(self): 118 | return len(self.lengths) 119 | 120 | def __iter__(self): 121 | if self.group_by_modality: 122 | indices = get_modality_length_grouped_indices(self.lengths, self.batch_size, self.world_size, 123 | generator=self.generator) 124 | else: 125 | indices = get_length_grouped_indices(self.lengths, self.batch_size, self.world_size, 126 | generator=self.generator) 127 | return iter(indices) 128 | 129 | 130 | class CeruleTrainer(Trainer): 131 | 132 | def _get_train_sampler(self) -> Optional[torch.utils.data.Sampler]: 133 | if self.train_dataset is None or not has_length(self.train_dataset): 134 | return None 135 | 136 | if self.args.group_by_modality_length: 137 | lengths = self.train_dataset.modality_lengths 138 | return LengthGroupedSampler( 139 | self.args.train_batch_size, 140 | world_size=self.args.world_size * self.args.gradient_accumulation_steps, 141 | lengths=lengths, 142 | group_by_modality=True, 143 | ) 144 | else: 145 | return super()._get_train_sampler() 146 | 147 | def create_optimizer(self): 148 | """ 149 | Setup the optimizer. 150 | 151 | We provide a reasonable default that works well. If you want to use something else, you can pass a tuple in the 152 | Trainer's init through `optimizers`, or subclass and override this method in a subclass. 
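        In addition to the standard weight-decay split, when `mm_projector_lr`
        is set the multimodal projector parameters are placed in their own
        groups so they can be trained with that separate learning rate.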
153 | """ 154 | if is_sagemaker_mp_enabled(): 155 | return super().create_optimizer() 156 | 157 | opt_model = self.model 158 | 159 | if self.optimizer is None: 160 | decay_parameters = get_parameter_names(opt_model, ALL_LAYERNORM_LAYERS) 161 | decay_parameters = [name for name in decay_parameters if "bias" not in name] 162 | if self.args.mm_projector_lr is not None: 163 | projector_parameters = [name for name, _ in opt_model.named_parameters() if "mm_projector" in name] 164 | optimizer_grouped_parameters = [ 165 | { 166 | "params": [ 167 | p for n, p in opt_model.named_parameters() if 168 | (n in decay_parameters and n not in projector_parameters and p.requires_grad) 169 | ], 170 | "weight_decay": self.args.weight_decay, 171 | }, 172 | { 173 | "params": [ 174 | p for n, p in opt_model.named_parameters() if 175 | (n not in decay_parameters and n not in projector_parameters and p.requires_grad) 176 | ], 177 | "weight_decay": 0.0, 178 | }, 179 | { 180 | "params": [ 181 | p for n, p in opt_model.named_parameters() if 182 | (n in decay_parameters and n in projector_parameters and p.requires_grad) 183 | ], 184 | "weight_decay": self.args.weight_decay, 185 | "lr": self.args.mm_projector_lr, 186 | }, 187 | { 188 | "params": [ 189 | p for n, p in opt_model.named_parameters() if 190 | (n not in decay_parameters and n in projector_parameters and p.requires_grad) 191 | ], 192 | "weight_decay": 0.0, 193 | "lr": self.args.mm_projector_lr, 194 | }, 195 | ] 196 | else: 197 | optimizer_grouped_parameters = [ 198 | { 199 | "params": [ 200 | p for n, p in opt_model.named_parameters() if (n in decay_parameters and p.requires_grad) 201 | ], 202 | "weight_decay": self.args.weight_decay, 203 | }, 204 | { 205 | "params": [ 206 | p for n, p in opt_model.named_parameters() if 207 | (n not in decay_parameters and p.requires_grad) 208 | ], 209 | "weight_decay": 0.0, 210 | }, 211 | ] 212 | 213 | optimizer_cls, optimizer_kwargs = Trainer.get_optimizer_cls_and_kwargs(self.args) 214 | 215 | self.optimizer = optimizer_cls(optimizer_grouped_parameters, **optimizer_kwargs) 216 | if optimizer_cls.__name__ == "Adam8bit": 217 | import bitsandbytes 218 | 219 | manager = bitsandbytes.optim.GlobalOptimManager.get_instance() 220 | 221 | skipped = 0 222 | for module in opt_model.modules(): 223 | if isinstance(module, nn.Embedding): 224 | skipped += sum({p.data_ptr(): p.numel() for p in module.parameters()}.values()) 225 | logger.info(f"skipped {module}: {skipped / 2 ** 20}M params") 226 | manager.register_module_override(module, "weight", {"optim_bits": 32}) 227 | logger.debug(f"bitsandbytes: will optimize {module} in fp32") 228 | logger.info(f"skipped: {skipped / 2 ** 20}M params") 229 | 230 | return self.optimizer 231 | 232 | def _save_checkpoint(self, model, trial, metrics=None): 233 | if getattr(self.args, 'tune_mm_mlp_adapter', False): 234 | from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR 235 | checkpoint_folder = f"{PREFIX_CHECKPOINT_DIR}-{self.state.global_step}" 236 | 237 | run_dir = self._get_output_dir(trial=trial) 238 | output_dir = os.path.join(run_dir, checkpoint_folder) 239 | 240 | # Only save Adapter 241 | keys_to_match = ['mm_projector', 'vision_resampler'] 242 | if getattr(self.args, "use_im_start_end", False): 243 | keys_to_match.extend(['embed_tokens', 'embed_in']) 244 | 245 | weight_to_save = get_mm_adapter_state_maybe_zero_3(self.model.named_parameters(), keys_to_match) 246 | 247 | if self.args.local_rank == 0 or self.args.local_rank == -1: 248 | self.model.config.save_pretrained(output_dir) 249 | 
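                # Adapter-only checkpoint: only the weights matched above (the
                # projector, vision resampler and, if enabled, the embeddings)
                # are written out as mm_projector.bin instead of a full model
                # checkpoint, and only on rank 0 or in single-process runs.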
torch.save(weight_to_save, os.path.join(output_dir, f'mm_projector.bin')) 250 | else: 251 | super(CeruleTrainer, self)._save_checkpoint(model, trial, metrics) 252 | 253 | def _save(self, output_dir: Optional[str] = None, state_dict=None): 254 | if getattr(self.args, 'tune_mm_mlp_adapter', False): 255 | pass 256 | else: 257 | super(CeruleTrainer, self)._save(output_dir, state_dict) 258 | -------------------------------------------------------------------------------- /cerule/train/train.py: -------------------------------------------------------------------------------- 1 | import os 2 | from dataclasses import dataclass, field 3 | import logging 4 | import pathlib 5 | from typing import Optional, List, Sequence 6 | 7 | import torch 8 | 9 | import transformers 10 | 11 | from cerule.train.cerule_trainer import CeruleTrainer 12 | 13 | from cerule.constants import IGNORE_INDEX, IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN 14 | 15 | from cerule import conversation as conversation_lib 16 | from cerule.model import * 17 | from cerule.util.data_utils import make_supervised_data_module, DataArguments, preprocess_gemma, preprocess_multimodal 18 | from cerule.util.mm_utils import tokenizer_image_token 19 | 20 | 21 | local_rank = None 22 | 23 | 24 | def rank0_print(*args): 25 | if local_rank == 0: 26 | print(*args) 27 | 28 | 29 | @dataclass 30 | class ModelArguments: 31 | model_name_or_path: Optional[str] = field(default=None) 32 | model_type: Optional[str] = field(default=None) 33 | version: Optional[str] = field(default=None) 34 | freeze_backbone: bool = field(default=False) 35 | tune_mm_mlp_adapter: bool = field(default=False) 36 | vision_tower: Optional[str] = field(default=None) 37 | pretrain_mm_mlp_adapter: Optional[str] = field(default=None) 38 | mm_projector_type: Optional[str] = field(default='mlp2x_gelu') 39 | mm_vision_select_layer: Optional[int] = field(default=-1) 40 | mm_use_im_start_end: bool = field(default=False) 41 | mm_use_im_patch_token: bool = field(default=True) 42 | 43 | 44 | @dataclass 45 | class TrainingArguments(transformers.TrainingArguments): 46 | cache_dir: Optional[str] = field(default=None) 47 | optim: str = field(default="adamw_torch") 48 | remove_unused_columns: bool = field(default=False) 49 | freeze_mm_mlp_adapter: bool = field(default=False) 50 | mpt_attn_impl: Optional[str] = field(default="triton") 51 | model_max_length: int = field( 52 | default=512, 53 | metadata={ 54 | "help": 55 | "Maximum sequence length. Sequences will be right padded (and possibly truncated)." 56 | }, 57 | ) 58 | double_quant: bool = field( 59 | default=True, 60 | metadata={"help": "Compress the quantization statistics through double quantization."} 61 | ) 62 | quant_type: str = field( 63 | default="nf4", 64 | metadata={"help": "Quantization data type to use. 
Should be one of `fp4` or `nf4`."} 65 | ) 66 | bits: int = field( 67 | default=16, 68 | metadata={"help": "How many bits to use."} 69 | ) 70 | lora_enable: bool = False 71 | lora_r: int = 64 72 | lora_alpha: int = 16 73 | lora_dropout: float = 0.05 74 | lora_weight_path: str = "" 75 | lora_bias: str = "none" 76 | mm_projector_lr: Optional[float] = None 77 | group_by_modality_length: bool = field(default=False) 78 | 79 | 80 | def maybe_zero_3(param, ignore_status=False, name=None): 81 | from deepspeed import zero 82 | from deepspeed.runtime.zero.partition_parameters import ZeroParamStatus 83 | if hasattr(param, "ds_id"): 84 | if param.ds_status == ZeroParamStatus.NOT_AVAILABLE: 85 | if not ignore_status: 86 | logging.warning(f"{name}: param.ds_status != ZeroParamStatus.NOT_AVAILABLE: {param.ds_status}") 87 | with zero.GatheredParameters([param]): 88 | param = param.data.detach().cpu().clone() 89 | else: 90 | param = param.detach().cpu().clone() 91 | return param 92 | 93 | 94 | # Borrowed from peft.util.get_peft_model_state_dict 95 | def get_peft_state_maybe_zero_3(named_params, bias): 96 | if bias == "none": 97 | to_return = {k: t for k, t in named_params if "lora_" in k} 98 | elif bias == "all": 99 | to_return = {k: t for k, t in named_params if "lora_" in k or "bias" in k} 100 | elif bias == "lora_only": 101 | to_return = {} 102 | maybe_lora_bias = {} 103 | lora_bias_names = set() 104 | for k, t in named_params: 105 | if "lora_" in k: 106 | to_return[k] = t 107 | bias_name = k.split("lora_")[0] + "bias" 108 | lora_bias_names.add(bias_name) 109 | elif "bias" in k: 110 | maybe_lora_bias[k] = t 111 | for k, t in maybe_lora_bias: 112 | if bias_name in lora_bias_names: 113 | to_return[bias_name] = t 114 | else: 115 | raise NotImplementedError 116 | to_return = {k: maybe_zero_3(v, ignore_status=True) for k, v in to_return.items()} 117 | return to_return 118 | 119 | 120 | def get_peft_state_non_lora_maybe_zero_3(named_params, require_grad_only=True): 121 | to_return = {k: t for k, t in named_params if "lora_" not in k} 122 | if require_grad_only: 123 | to_return = {k: t for k, t in to_return.items() if t.requires_grad} 124 | to_return = {k: maybe_zero_3(v, ignore_status=True).cpu() for k, v in to_return.items()} 125 | return to_return 126 | 127 | 128 | def get_mm_adapter_state_maybe_zero_3(named_params, keys_to_match): 129 | to_return = {k: t for k, t in named_params if any(key_match in k for key_match in keys_to_match)} 130 | to_return = {k: maybe_zero_3(v, ignore_status=True).cpu() for k, v in to_return.items()} 131 | return to_return 132 | 133 | 134 | def find_all_linear_names(model): 135 | cls = torch.nn.Linear 136 | lora_module_names = set() 137 | multimodal_keywords = ['mm_projector', 'vision_tower', 'vision_resampler'] 138 | for name, module in model.named_modules(): 139 | if any(mm_keyword in name for mm_keyword in multimodal_keywords): 140 | continue 141 | if isinstance(module, cls): 142 | names = name.split('.') 143 | lora_module_names.add(names[0] if len(names) == 1 else names[-1]) 144 | 145 | if 'lm_head' in lora_module_names: # needed for 16-bit 146 | lora_module_names.remove('lm_head') 147 | return list(lora_module_names) 148 | 149 | 150 | def safe_save_model_for_hf_trainer(trainer: transformers.Trainer, 151 | output_dir: str): 152 | """Collects the state dict and dump to disk.""" 153 | 154 | if getattr(trainer.args, "tune_mm_mlp_adapter", False): 155 | # Only save Adapter 156 | keys_to_match = ['mm_projector'] 157 | if getattr(trainer.args, "use_im_start_end", False): 158 | 
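            # With image start/end tokens enabled, the embedding weights are
            # included in the saved adapter state as well.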
keys_to_match.extend(['embed_tokens', 'embed_in']) 159 | 160 | weight_to_save = get_mm_adapter_state_maybe_zero_3(trainer.model.named_parameters(), keys_to_match) 161 | trainer.model.config.save_pretrained(output_dir) 162 | 163 | current_folder = output_dir.split('/')[-1] 164 | parent_folder = os.path.dirname(output_dir) 165 | if trainer.args.local_rank == 0 or trainer.args.local_rank == -1: 166 | if current_folder.startswith('checkpoint-'): 167 | mm_projector_folder = os.path.join(parent_folder, "mm_projector") 168 | os.makedirs(mm_projector_folder, exist_ok=True) 169 | torch.save(weight_to_save, os.path.join(mm_projector_folder, f'{current_folder}.bin')) 170 | else: 171 | torch.save(weight_to_save, os.path.join(output_dir, f'mm_projector.bin')) 172 | return 173 | 174 | if trainer.deepspeed: 175 | torch.cuda.synchronize() 176 | trainer.save_model(output_dir) 177 | return 178 | 179 | state_dict = trainer.model.state_dict() 180 | if trainer.args.should_save: 181 | cpu_state_dict = { 182 | key: value.cpu() 183 | for key, value in state_dict.items() 184 | } 185 | del state_dict 186 | trainer._save(output_dir, state_dict=cpu_state_dict) # noqa 187 | 188 | 189 | def train(): 190 | global local_rank 191 | 192 | parser = transformers.HfArgumentParser((ModelArguments, DataArguments, TrainingArguments)) 193 | model_args, data_args, training_args = parser.parse_args_into_dataclasses() 194 | local_rank = training_args.local_rank 195 | compute_dtype = (torch.float16 if training_args.fp16 else (torch.bfloat16 if training_args.bf16 else torch.float32)) 196 | 197 | bnb_model_from_pretrained_args = {} 198 | if training_args.bits in [4, 8]: 199 | from transformers import BitsAndBytesConfig 200 | bnb_model_from_pretrained_args.update(dict( 201 | device_map={"": training_args.device}, 202 | load_in_4bit=training_args.bits == 4, 203 | load_in_8bit=training_args.bits == 8, 204 | quantization_config=BitsAndBytesConfig( 205 | load_in_4bit=training_args.bits == 4, 206 | load_in_8bit=training_args.bits == 8, 207 | llm_int8_skip_modules=["mm_projector"], 208 | llm_int8_threshold=6.0, 209 | llm_int8_has_fp16_weight=False, 210 | bnb_4bit_compute_dtype=compute_dtype, 211 | bnb_4bit_use_double_quant=training_args.double_quant, 212 | bnb_4bit_quant_type=training_args.quant_type # {'fp4', 'nf4'} 213 | ) 214 | )) 215 | 216 | assert model_args.vision_tower is not None 217 | if model_args.model_type == 'phi-1.5' or model_args.model_type == 'phi-2': 218 | tokenizer = transformers.AutoTokenizer.from_pretrained( 219 | model_args.model_name_or_path, 220 | cache_dir=training_args.cache_dir, 221 | model_max_length=training_args.model_max_length, 222 | padding_side="right", 223 | use_fast=False, 224 | ) 225 | elif model_args.model_type == 'stablelm-2': 226 | tokenizer = transformers.AutoTokenizer.from_pretrained( 227 | model_args.model_name_or_path, 228 | cache_dir=training_args.cache_dir, 229 | model_max_length=training_args.model_max_length, 230 | padding_side="right", 231 | use_fast=False, 232 | trust_remote_code=True 233 | ) 234 | elif model_args.model_type == 'gemma': 235 | tokenizer = transformers.AutoTokenizer.from_pretrained( 236 | model_args.model_name_or_path, 237 | cache_dir=training_args.cache_dir, 238 | model_max_length=training_args.model_max_length, 239 | padding_side="right", 240 | use_fast=False, 241 | trust_remote_code=True 242 | ) 243 | 244 | if model_args.model_type == 'phi-1.5' or model_args.model_type == 'phi-2': 245 | model = CerulePhiForCausalLM.from_pretrained( 246 | model_args.model_name_or_path, 247 | 
cache_dir=training_args.cache_dir, 248 | bos_token_id=tokenizer.bos_token_id, 249 | eos_token_id=tokenizer.eos_token_id, 250 | **bnb_model_from_pretrained_args 251 | ) 252 | elif model_args.model_type == 'stablelm-2': 253 | model = CeruleStableLMForCausalLM.from_pretrained( 254 | model_args.model_name_or_path, 255 | cache_dir=training_args.cache_dir, 256 | bos_token_id=tokenizer.bos_token_id, 257 | eos_token_id=tokenizer.eos_token_id, 258 | **bnb_model_from_pretrained_args 259 | ) 260 | elif model_args.model_type == 'gemma': 261 | model = CeruleGemmaForCausalLM.from_pretrained( 262 | model_args.model_name_or_path, 263 | cache_dir=training_args.cache_dir, 264 | attn_implementation=attn_implementation, 265 | torch_dtype=(torch.bfloat16 if training_args.bf16 else None), 266 | **bnb_model_from_pretrained_args 267 | ) 268 | else: 269 | raise ValueError(f"Unknown Model Type {model_args.model_type}") 270 | 271 | model.config.use_cache = False 272 | 273 | if model_args.freeze_backbone: 274 | model.model.requires_grad_(False) 275 | 276 | if training_args.bits in [4, 8]: 277 | from peft import prepare_model_for_kbit_training 278 | model.config.torch_dtype = ( 279 | torch.float32 if training_args.fp16 else (torch.bfloat16 if training_args.bf16 else torch.float32)) 280 | model = prepare_model_for_kbit_training(model, use_gradient_checkpointing=training_args.gradient_checkpointing) 281 | 282 | if training_args.gradient_checkpointing: 283 | if hasattr(model, "enable_input_require_grads"): 284 | model.enable_input_require_grads() 285 | else: 286 | def make_inputs_require_grad(module, input, output): 287 | output.requires_grad_(True) 288 | 289 | model.get_input_embeddings().register_forward_hook(make_inputs_require_grad) 290 | 291 | if training_args.lora_enable: 292 | from peft import LoraConfig, get_peft_model 293 | lora_config = LoraConfig( 294 | r=training_args.lora_r, 295 | lora_alpha=training_args.lora_alpha, 296 | target_modules=find_all_linear_names(model), 297 | lora_dropout=training_args.lora_dropout, 298 | bias=training_args.lora_bias, 299 | task_type="CAUSAL_LM", 300 | ) 301 | if training_args.bits == 16: 302 | if training_args.bf16: 303 | model.to(torch.bfloat16) 304 | if training_args.fp16: 305 | model.to(torch.float16) 306 | rank0_print("Adding LoRA adapters...") 307 | model = get_peft_model(model, lora_config) 308 | 309 | if model_args.version in conversation_lib.conv_templates: 310 | conversation_lib.default_conversation = conversation_lib.conv_templates[model_args.version] 311 | else: 312 | conversation_lib.default_conversation = conversation_lib.conv_templates["default"] 313 | 314 | model.get_model().initialize_vision_modules(model_args=model_args) 315 | 316 | vision_tower = model.get_vision_tower() 317 | vision_tower.to(dtype=torch.bfloat16 if training_args.bf16 else torch.float16, device=training_args.device) 318 | 319 | data_args.image_processor = vision_tower.image_processor 320 | 321 | model.config.image_aspect_ratio = data_args.image_aspect_ratio 322 | model.config.tokenizer_padding_side = tokenizer.padding_side 323 | model.config.tokenizer_model_max_length = tokenizer.model_max_length 324 | 325 | model.config.tune_mm_mlp_adapter = training_args.tune_mm_mlp_adapter = model_args.tune_mm_mlp_adapter 326 | if model_args.tune_mm_mlp_adapter: 327 | model.requires_grad_(False) 328 | for p in model.get_model().mm_projector.parameters(): 329 | p.requires_grad = True 330 | 331 | model.config.freeze_mm_mlp_adapter = training_args.freeze_mm_mlp_adapter 332 | if 
training_args.freeze_mm_mlp_adapter: 333 | for p in model.get_model().mm_projector.parameters(): 334 | p.requires_grad = False 335 | 336 | if training_args.bits in [4, 8]: 337 | model.get_model().mm_projector.to(dtype=compute_dtype, device=training_args.device) 338 | 339 | model.config.mm_projector_lr = training_args.mm_projector_lr 340 | 341 | if training_args.bits in [4, 8]: 342 | from peft.tuners.lora import LoraLayer 343 | for name, module in model.named_modules(): 344 | if isinstance(module, LoraLayer): 345 | if training_args.bf16: 346 | module = module.to(torch.bfloat16) 347 | if 'norm' in name: 348 | module = module.to(torch.float32) 349 | if 'lm_head' in name or 'embed_tokens' in name: 350 | if hasattr(module, 'weight'): 351 | if training_args.bf16 and module.weight.dtype == torch.float32: 352 | module = module.to(torch.bfloat16) 353 | 354 | data_module = make_supervised_data_module(tokenizer=tokenizer, 355 | data_args=data_args) 356 | trainer = CeruleTrainer(model=model, 357 | tokenizer=tokenizer, 358 | args=training_args, 359 | **data_module) 360 | 361 | if list(pathlib.Path(training_args.output_dir).glob("checkpoint-*")): 362 | trainer.train(resume_from_checkpoint=True) 363 | else: 364 | trainer.train() 365 | trainer.save_state() 366 | 367 | model.config.use_cache = True 368 | 369 | if training_args.lora_enable: 370 | state_dict = get_peft_state_maybe_zero_3( 371 | model.named_parameters(), training_args.lora_bias 372 | ) 373 | non_lora_state_dict = get_peft_state_non_lora_maybe_zero_3( 374 | model.named_parameters() 375 | ) 376 | if training_args.local_rank == 0 or training_args.local_rank == -1: 377 | model.config.save_pretrained(training_args.output_dir) 378 | model.save_pretrained(training_args.output_dir, state_dict=state_dict) 379 | torch.save(non_lora_state_dict, os.path.join(training_args.output_dir, 'non_lora_trainables.bin')) 380 | else: 381 | safe_save_model_for_hf_trainer(trainer=trainer, 382 | output_dir=training_args.output_dir) 383 | 384 | 385 | if __name__ == "__main__": 386 | train(attn_implementation="flash_attention_2") 387 | -------------------------------------------------------------------------------- /cerule/util/data_utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | import copy 3 | from dataclasses import dataclass, field 4 | import json 5 | from typing import Dict, Sequence, Optional, List 6 | 7 | import torch 8 | 9 | import transformers 10 | 11 | from cerule.constants import IGNORE_INDEX, DEFAULT_IMAGE_TOKEN, IGNORE_INDEX, IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN 12 | from torch.utils.data import Dataset 13 | 14 | from cerule import conversation as conversation_lib 15 | 16 | from cerule.util.mm_utils import tokenizer_image_token 17 | 18 | from PIL import Image 19 | 20 | 21 | @dataclass 22 | class DataArguments: 23 | data_path: str = field(default=None, metadata={"help": "Path to the training data."}) 24 | lazy_preprocess: bool = False 25 | is_multimodal: bool = True 26 | image_folder: Optional[str] = field(default=None) 27 | image_aspect_ratio: str = field(default=None) 28 | 29 | 30 | def preprocess_multimodal( 31 | sources: Sequence[str], 32 | data_args: DataArguments 33 | ) -> Dict: 34 | is_multimodal = data_args.is_multimodal 35 | if not is_multimodal: 36 | return sources 37 | 38 | for source in sources: 39 | for sentence in source: 40 | if DEFAULT_IMAGE_TOKEN in sentence['value']: 41 | sentence['value'] = 
sentence['value'].replace(DEFAULT_IMAGE_TOKEN, '').strip() 42 | sentence['value'] = DEFAULT_IMAGE_TOKEN + '\n' + sentence['value'] 43 | sentence['value'] = sentence['value'].strip() 44 | 45 | replace_token = DEFAULT_IMAGE_TOKEN 46 | 47 | sentence["value"] = sentence["value"].replace(DEFAULT_IMAGE_TOKEN, replace_token) 48 | 49 | return sources 50 | 51 | def preprocess_plain( 52 | sources: Sequence[str], 53 | tokenizer: transformers.PreTrainedTokenizer, 54 | ) -> Dict: 55 | # add end signal and concatenate together 56 | conversations = [] 57 | for source in sources: 58 | assert len(source) == 2 59 | assert DEFAULT_IMAGE_TOKEN in source[0]['value'] 60 | source[0]['value'] = DEFAULT_IMAGE_TOKEN 61 | conversation = source[0]['value'] + source[1]['value'] + conversation_lib.default_conversation.sep 62 | conversations.append(conversation) 63 | # tokenize conversations 64 | input_ids = [tokenizer_image_token(prompt, tokenizer, return_tensors='pt') for prompt in conversations] 65 | targets = copy.deepcopy(input_ids) 66 | for target, source in zip(targets, sources): 67 | tokenized_len = len(tokenizer_image_token(source[0]['value'], tokenizer)) 68 | target[:tokenized_len] = IGNORE_INDEX 69 | 70 | return dict(input_ids=input_ids, labels=targets) 71 | 72 | def preprocess_gemma( 73 | sources: List[List[Dict[str, str]]], 74 | tokenizer: transformers.PreTrainedTokenizer, 75 | has_image: bool = False 76 | ) -> Dict: 77 | conv: conversation_lib.Conversation = conversation_lib.default_conversation.copy() 78 | roles: Dict[str, str] = {"human": conv.roles[0], "gpt": conv.roles[1]} 79 | 80 | # Apply prompt templates 81 | conversations: List[str] = [] 82 | for i, source in enumerate(sources): 83 | if roles[source[0]["from"]] != conv.roles[0]: 84 | # Skip the first one if it is not from human 85 | source: List[Dict[str, str]] = source[1:] 86 | 87 | conv.messages = [] 88 | for j, sentence in enumerate(source): 89 | role: str = roles[sentence["from"]] 90 | assert role == conv.roles[j % 2], f"{i}" 91 | conv.append_message(role, sentence["value"]) 92 | conversations.append(conv.get_prompt()) 93 | 94 | # Tokenize conversations 95 | if has_image: 96 | input_ids: torch.Tensor = torch.stack([tokenizer_image_token(prompt, tokenizer, return_tensors='pt') for prompt in conversations], dim=0) 97 | else: 98 | input_ids: torch.Tensor = tokenizer( 99 | conversations, 100 | return_tensors="pt", 101 | padding="longest", 102 | max_length=tokenizer.model_max_length, 103 | truncation=True, 104 | ).input_ids 105 | 106 | targets: torch.Tensor = input_ids.clone() 107 | assert conv.sep_style == conversation_lib.SeparatorStyle.GEMMA 108 | 109 | # Mask target 110 | sep: str = conv.sep + conv.roles[1] 111 | for conversation, target in zip(conversations, targets): 112 | total_len: int = int(target.ne(tokenizer.pad_token_id).sum()) 113 | 114 | rounds: List[str] = conversation.split(conv.sep) 115 | re_rounds = [] 116 | for conv_idx in range(0, len(rounds), 2): 117 | re_rounds.append(conv.sep.join(rounds[conv_idx:conv_idx+2])) 118 | 119 | cur_len = 1 # Ignore 120 | target[:cur_len] = IGNORE_INDEX 121 | for i, rou in enumerate(re_rounds): 122 | if rou == "": 123 | break 124 | 125 | parts = rou.split(sep) 126 | if len(parts) != 2: 127 | break 128 | parts[0] += sep # Re-append sep because split on this 129 | # Now "".join(parts)==rou 130 | 131 | if has_image: 132 | round_len = len(tokenizer_image_token(rou, tokenizer)) - 1 # Ignore 133 | instruction_len = len(tokenizer_image_token(parts[0], tokenizer)) - 1 # Ignore 134 | else: 135 | round_len = 
len(tokenizer(rou).input_ids) - 1 # Ignore 136 | instruction_len = len(tokenizer(parts[0]).input_ids) - 1 # Ignore 137 | 138 | round_len += 2 # sep: \n takes 2 tokens 139 | target[cur_len : cur_len + instruction_len] = IGNORE_INDEX 140 | cur_len += round_len 141 | 142 | target[cur_len:] = IGNORE_INDEX 143 | 144 | if cur_len < tokenizer.model_max_length: 145 | if cur_len != total_len: 146 | target[:] = IGNORE_INDEX 147 | print( 148 | f"warning: tokenization mismatch: {cur_len} vs. {total_len}." 149 | f" (ignored)" 150 | ) 151 | 152 | return dict( 153 | input_ids=input_ids, 154 | labels=targets, 155 | ) 156 | 157 | def preprocess( 158 | sources: Sequence[str], 159 | tokenizer: transformers.PreTrainedTokenizer, 160 | has_image: bool = False 161 | ) -> Dict: 162 | if conversation_lib.default_conversation.sep_style == conversation_lib.SeparatorStyle.PLAIN: 163 | return preprocess_plain(sources, tokenizer) 164 | 165 | if conversation_lib.default_conversation.sep_style == conversation_lib.SeparatorStyle.GEMMA: 166 | return preprocess_gemma(sources, tokenizer, has_image=has_image) 167 | 168 | if conversation_lib.default_conversation.version == "gemma": 169 | return preprocess_gemma(sources, tokenizer, has_image=has_image) 170 | 171 | 172 | class LazySupervisedDataset(Dataset): 173 | """Dataset for supervised fine-tuning.""" 174 | 175 | def __init__(self, data_path: str, 176 | tokenizer: transformers.PreTrainedTokenizer, 177 | data_args: DataArguments): 178 | super(LazySupervisedDataset, self).__init__() 179 | list_data_dict = json.load(open(data_path, "r")) 180 | 181 | print("Formatting inputs...Skip in lazy mode") 182 | self.tokenizer = tokenizer 183 | self.list_data_dict = list_data_dict 184 | self.data_args = data_args 185 | 186 | def __len__(self): 187 | return len(self.list_data_dict) 188 | 189 | @property 190 | def lengths(self): 191 | length_list = [] 192 | for sample in self.list_data_dict: 193 | img_tokens = 128 if 'image' in sample else 0 194 | length_list.append(sum(len(conv['value'].split()) for conv in sample['conversations']) + img_tokens) 195 | return length_list 196 | 197 | @property 198 | def modality_lengths(self): 199 | length_list = [] 200 | for sample in self.list_data_dict: 201 | cur_len = sum(len(conv['value'].split()) for conv in sample['conversations']) 202 | cur_len = cur_len if 'image' in sample else -cur_len 203 | length_list.append(cur_len) 204 | return length_list 205 | 206 | def __getitem__(self, i) -> Dict[str, torch.Tensor]: 207 | sources = self.list_data_dict[i] 208 | if isinstance(i, int): 209 | sources = [sources] 210 | assert len(sources) == 1, "Don't know why it is wrapped to a list" # FIXME 211 | if 'image' in sources[0]: 212 | image_file = self.list_data_dict[i]['image'] 213 | image_folder = self.data_args.image_folder 214 | processor = self.data_args.image_processor 215 | image = Image.open(os.path.join(image_folder, image_file)).convert('RGB') 216 | if self.data_args.image_aspect_ratio == 'pad': 217 | def expand2square(pil_img, background_color): 218 | width, height = pil_img.size 219 | if width == height: 220 | return pil_img 221 | elif width > height: 222 | result = Image.new(pil_img.mode, (width, width), background_color) 223 | result.paste(pil_img, (0, (width - height) // 2)) 224 | return result 225 | else: 226 | result = Image.new(pil_img.mode, (height, height), background_color) 227 | result.paste(pil_img, ((height - width) // 2, 0)) 228 | return result 229 | 230 | image = expand2square(image, tuple(int(x * 255) for x in processor.image_mean)) 231 | 
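                # The image is now square, padded with the processor's per-channel
                # mean as the fill color; the preprocess call below converts it to
                # the model's pixel_values tensor.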
image = processor.preprocess(image, return_tensors='pt')['pixel_values'][0] 232 | else: 233 | image = processor.preprocess(image, return_tensors='pt')['pixel_values'][0] 234 | sources = preprocess_multimodal( 235 | copy.deepcopy([e["conversations"] for e in sources]), self.data_args) 236 | else: 237 | sources = copy.deepcopy([e["conversations"] for e in sources]) 238 | data_dict = preprocess( 239 | sources, 240 | self.tokenizer, 241 | has_image=('image' in self.list_data_dict[i])) 242 | if isinstance(i, int): 243 | data_dict = dict(input_ids=data_dict["input_ids"][0], 244 | labels=data_dict["labels"][0]) 245 | 246 | # image exist in the data 247 | if 'image' in self.list_data_dict[i]: 248 | data_dict['image'] = image 249 | elif self.data_args.is_multimodal: 250 | # image does not exist in the data, but the model is multimodal 251 | crop_size = self.data_args.image_processor.crop_size 252 | data_dict['image'] = torch.zeros(3, crop_size['height'], crop_size['width']) 253 | return data_dict 254 | 255 | 256 | @dataclass 257 | class DataCollatorForSupervisedDataset(object): 258 | """Collate examples for supervised fine-tuning.""" 259 | 260 | tokenizer: transformers.PreTrainedTokenizer 261 | 262 | def __call__(self, instances: Sequence[Dict]) -> Dict[str, torch.Tensor]: 263 | input_ids, labels = tuple([instance[key] for instance in instances] 264 | for key in ("input_ids", "labels")) 265 | 266 | if self.tokenizer.pad_token_id == self.tokenizer.eos_token_id: 267 | for input_id in input_ids: 268 | input_id[input_id == self.tokenizer.eos_token_id] = -300 269 | 270 | input_ids = torch.nn.utils.rnn.pad_sequence( 271 | input_ids, 272 | batch_first=True, 273 | padding_value=self.tokenizer.pad_token_id) 274 | 275 | labels = torch.nn.utils.rnn.pad_sequence( 276 | labels, 277 | batch_first=True, 278 | padding_value=IGNORE_INDEX) 279 | 280 | input_ids = input_ids[:, :self.tokenizer.model_max_length] 281 | 282 | attention_mask = input_ids.ne(self.tokenizer.pad_token_id) 283 | 284 | labels = labels[:, :self.tokenizer.model_max_length] 285 | 286 | if self.tokenizer.pad_token_id == self.tokenizer.eos_token_id: 287 | for input_id in input_ids: 288 | input_id[input_id == -300] = self.tokenizer.eos_token_id 289 | 290 | batch = dict( 291 | input_ids=input_ids, 292 | labels=labels, 293 | attention_mask=attention_mask, 294 | ) 295 | 296 | if 'image' in instances[0]: 297 | images = [instance['image'] for instance in instances] 298 | if all(x is not None and x.shape == images[0].shape for x in images): 299 | batch['images'] = torch.stack(images) 300 | else: 301 | batch['images'] = images 302 | 303 | return batch 304 | 305 | 306 | def make_supervised_data_module(tokenizer: transformers.PreTrainedTokenizer, 307 | data_args) -> Dict: 308 | """Make dataset and collator for supervised fine-tuning.""" 309 | train_dataset = LazySupervisedDataset(tokenizer=tokenizer, 310 | data_path=data_args.data_path, 311 | data_args=data_args) 312 | data_collator = DataCollatorForSupervisedDataset(tokenizer=tokenizer) 313 | return dict(train_dataset=train_dataset, 314 | eval_dataset=None, 315 | data_collator=data_collator) 316 | -------------------------------------------------------------------------------- /cerule/util/mm_utils.py: -------------------------------------------------------------------------------- 1 | import base64 2 | import torch 3 | 4 | from PIL import Image 5 | from io import BytesIO 6 | from transformers import StoppingCriteria 7 | 8 | from cerule.constants import IMAGE_TOKEN_INDEX 9 | 10 | 11 | def 
load_image_from_base64(image): 12 | return Image.open(BytesIO(base64.b64decode(image))) 13 | 14 | 15 | def expand2square(pil_img, background_color): 16 | width, height = pil_img.size 17 | if width == height: 18 | return pil_img 19 | elif width > height: 20 | result = Image.new(pil_img.mode, (width, width), background_color) 21 | result.paste(pil_img, (0, (width - height) // 2)) 22 | return result 23 | else: 24 | result = Image.new(pil_img.mode, (height, height), background_color) 25 | result.paste(pil_img, ((height - width) // 2, 0)) 26 | return result 27 | 28 | 29 | def process_images(images, image_processor, model_cfg): 30 | image_aspect_ratio = getattr(model_cfg, "image_aspect_ratio", None) 31 | new_images = [] 32 | if image_aspect_ratio == 'pad': 33 | for image in images: 34 | image = expand2square(image, tuple(int(x * 255) for x in image_processor.image_mean)) 35 | image = image_processor.preprocess(image, return_tensors='pt')['pixel_values'][0] 36 | new_images.append(image) 37 | else: 38 | return image_processor(images, return_tensors='pt')['pixel_values'] 39 | if all(x.shape == new_images[0].shape for x in new_images): 40 | new_images = torch.stack(new_images, dim=0) 41 | return new_images 42 | 43 | 44 | def tokenizer_image_token(prompt, tokenizer, image_token_index=IMAGE_TOKEN_INDEX, return_tensors=None): 45 | prompt_chunks = [tokenizer(chunk).input_ids for chunk in prompt.split('')] 46 | 47 | def insert_separator(X, sep): 48 | return [ele for sublist in zip(X, [sep] * len(X)) for ele in sublist][:-1] 49 | 50 | input_ids = [] 51 | offset = 0 52 | if len(prompt_chunks) > 0 and len(prompt_chunks[0]) > 0 and prompt_chunks[0][0] == tokenizer.bos_token_id: 53 | offset = 1 54 | input_ids.append(prompt_chunks[0][0]) 55 | 56 | for x in insert_separator(prompt_chunks, [image_token_index] * (offset + 1)): 57 | input_ids.extend(x[offset:]) 58 | 59 | if return_tensors is not None: 60 | if return_tensors == 'pt': 61 | return torch.tensor(input_ids, dtype=torch.long) 62 | raise ValueError(f'Unsupported tensor type: {return_tensors}') 63 | return input_ids 64 | 65 | 66 | def get_model_name_from_path(model_path): 67 | model_path = model_path.strip("/") 68 | model_paths = model_path.split("/") 69 | if model_paths[-1].startswith('checkpoint-'): 70 | return model_paths[-2] + "_" + model_paths[-1] 71 | else: 72 | return model_paths[-1] 73 | 74 | 75 | class KeywordsStoppingCriteria(StoppingCriteria): 76 | def __init__(self, keywords, tokenizer, input_ids): 77 | self.keywords = keywords 78 | self.keyword_ids = [] 79 | self.max_keyword_len = 0 80 | for keyword in keywords: 81 | cur_keyword_ids = tokenizer(keyword).input_ids 82 | if len(cur_keyword_ids) > 1 and cur_keyword_ids[0] == tokenizer.bos_token_id: 83 | cur_keyword_ids = cur_keyword_ids[1:] 84 | if len(cur_keyword_ids) > self.max_keyword_len: 85 | self.max_keyword_len = len(cur_keyword_ids) 86 | self.keyword_ids.append(torch.tensor(cur_keyword_ids)) 87 | self.tokenizer = tokenizer 88 | self.start_len = input_ids.shape[1] 89 | 90 | def call_for_batch(self, output_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool: 91 | offset = min(output_ids.shape[1] - self.start_len, self.max_keyword_len) 92 | self.keyword_ids = [keyword_id.to(output_ids.device) for keyword_id in self.keyword_ids] 93 | for keyword_id in self.keyword_ids: 94 | truncated_output_ids = output_ids[0, -keyword_id.shape[0]:] 95 | if torch.equal(truncated_output_ids, keyword_id): 96 | return True 97 | outputs = self.tokenizer.batch_decode(output_ids[:, -offset:], 
skip_special_tokens=True)[0] 98 | for keyword in self.keywords: 99 | if keyword in outputs: 100 | return True 101 | return False 102 | 103 | def __call__(self, output_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool: 104 | outputs = [] 105 | for i in range(output_ids.shape[0]): 106 | outputs.append(self.call_for_batch(output_ids[i].unsqueeze(0), scores)) 107 | return all(outputs) 108 | -------------------------------------------------------------------------------- /cerule/util/utils.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import logging.handlers 3 | import os 4 | import sys 5 | import requests  # used by violates_moderation() below 6 | from cerule.constants import LOGDIR 7 | 8 | server_error_msg = "**NETWORK ERROR DUE TO HIGH TRAFFIC. PLEASE REGENERATE OR REFRESH THIS PAGE.**" 9 | moderation_msg = "YOUR INPUT VIOLATES OUR CONTENT MODERATION GUIDELINES. PLEASE TRY AGAIN." 10 | 11 | handler = None 12 | 13 | 14 | def disable_torch_init(): 15 | """ 16 | Disable the redundant torch default initialization to accelerate model creation. 17 | """ 18 | import torch 19 | setattr(torch.nn.Linear, "reset_parameters", lambda self: None) 20 | setattr(torch.nn.LayerNorm, "reset_parameters", lambda self: None) 21 | 22 | 23 | def build_logger(logger_name, logger_filename): 24 | global handler 25 | 26 | formatter = logging.Formatter( 27 | fmt="%(asctime)s | %(levelname)s | %(name)s | %(message)s", 28 | datefmt="%Y-%m-%d %H:%M:%S", 29 | ) 30 | 31 | # Set the format of root handlers 32 | if not logging.getLogger().handlers: 33 | logging.basicConfig(level=logging.INFO) 34 | logging.getLogger().handlers[0].setFormatter(formatter) 35 | 36 | # Redirect stdout and stderr to loggers 37 | stdout_logger = logging.getLogger("stdout") 38 | stdout_logger.setLevel(logging.INFO) 39 | sl = StreamToLogger(stdout_logger, logging.INFO) 40 | sys.stdout = sl 41 | 42 | stderr_logger = logging.getLogger("stderr") 43 | stderr_logger.setLevel(logging.ERROR) 44 | sl = StreamToLogger(stderr_logger, logging.ERROR) 45 | sys.stderr = sl 46 | 47 | # Get logger 48 | logger = logging.getLogger(logger_name) 49 | logger.setLevel(logging.INFO) 50 | 51 | # Add a file handler for all loggers 52 | if handler is None: 53 | os.makedirs(LOGDIR, exist_ok=True) 54 | filename = os.path.join(LOGDIR, logger_filename) 55 | handler = logging.handlers.TimedRotatingFileHandler( 56 | filename, when='D', utc=True, encoding='UTF-8') 57 | handler.setFormatter(formatter) 58 | 59 | for name, item in logging.root.manager.loggerDict.items(): 60 | if isinstance(item, logging.Logger): 61 | item.addHandler(handler) 62 | 63 | return logger 64 | 65 | 66 | class StreamToLogger(object): 67 | """ 68 | Fake file-like stream object that redirects writes to a logger instance. 69 | """ 70 | 71 | def __init__(self, logger, log_level=logging.INFO): 72 | self.terminal = sys.stdout 73 | self.logger = logger 74 | self.log_level = log_level 75 | self.linebuf = '' 76 | 77 | def __getattr__(self, attr): 78 | return getattr(self.terminal, attr) 79 | 80 | def write(self, buf): 81 | temp_linebuf = self.linebuf + buf 82 | self.linebuf = '' 83 | for line in temp_linebuf.splitlines(True): 84 | # From the io.TextIOWrapper docs: 85 | # On output, if newline is None, any '\n' characters written 86 | # are translated to the system default line separator. 87 | # By default sys.stdout.write() expects '\n' newlines and then 88 | # translates them so this is still cross platform.
89 | if line[-1] == '\n': 90 | self.logger.log(self.log_level, line.rstrip()) 91 | else: 92 | self.linebuf += line 93 | 94 | def flush(self): 95 | if self.linebuf != '': 96 | self.logger.log(self.log_level, self.linebuf.rstrip()) 97 | self.linebuf = '' 98 | 99 | 100 | def violates_moderation(text): 101 | """ 102 | Check whether the text violates OpenAI moderation API. 103 | """ 104 | url = "https://api.openai.com/v1/moderations" 105 | headers = {"Content-Type": "application/json", 106 | "Authorization": "Bearer " + os.environ["OPENAI_API_KEY"]} 107 | text = text.replace("\n", "") 108 | data = "{" + '"input": ' + f'"{text}"' + "}" 109 | data = data.encode("utf-8") 110 | try: 111 | ret = requests.post(url, headers=headers, data=data, timeout=5) 112 | flagged = ret.json()["results"][0]["flagged"] 113 | except requests.exceptions.RequestException as e: 114 | flagged = False 115 | except KeyError as e: 116 | flagged = False 117 | 118 | return flagged 119 | 120 | 121 | def pretty_print_semaphore(semaphore): 122 | if semaphore is None: 123 | return "None" 124 | return f"Semaphore(value={semaphore._value}, locked={semaphore.locked()})" 125 | -------------------------------------------------------------------------------- /examples/YHyRn8r.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tensoic/Cerule/29ae3aa0d55fd76f8db2e787f9e23b6ceb9d756c/examples/YHyRn8r.png -------------------------------------------------------------------------------- /examples/astronaut.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tensoic/Cerule/29ae3aa0d55fd76f8db2e787f9e23b6ceb9d756c/examples/astronaut.png -------------------------------------------------------------------------------- /examples/bridge.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tensoic/Cerule/29ae3aa0d55fd76f8db2e787f9e23b6ceb9d756c/examples/bridge.png -------------------------------------------------------------------------------- /examples/design.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tensoic/Cerule/29ae3aa0d55fd76f8db2e787f9e23b6ceb9d756c/examples/design.jpg -------------------------------------------------------------------------------- /examples/extreme_ironing.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tensoic/Cerule/29ae3aa0d55fd76f8db2e787f9e23b6ceb9d756c/examples/extreme_ironing.jpg -------------------------------------------------------------------------------- /examples/google.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tensoic/Cerule/29ae3aa0d55fd76f8db2e787f9e23b6ceb9d756c/examples/google.png -------------------------------------------------------------------------------- /examples/graph.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tensoic/Cerule/29ae3aa0d55fd76f8db2e787f9e23b6ceb9d756c/examples/graph.jpg -------------------------------------------------------------------------------- /examples/graph1.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tensoic/Cerule/29ae3aa0d55fd76f8db2e787f9e23b6ceb9d756c/examples/graph1.jpg 
-------------------------------------------------------------------------------- /examples/image.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tensoic/Cerule/29ae3aa0d55fd76f8db2e787f9e23b6ceb9d756c/examples/image.png -------------------------------------------------------------------------------- /examples/mario.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tensoic/Cerule/29ae3aa0d55fd76f8db2e787f9e23b6ceb9d756c/examples/mario.png -------------------------------------------------------------------------------- /examples/sting.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tensoic/Cerule/29ae3aa0d55fd76f8db2e787f9e23b6ceb9d756c/examples/sting.png -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [build-system] 2 | requires = ["setuptools>=61.0"] 3 | build-backend = "setuptools.build_meta" 4 | 5 | [project] 6 | name = "cerule" 7 | version = "1.0" 8 | description = "Cerule - Tiny Mighty Vision Models" 9 | readme = "README.md" 10 | requires-python = ">=3.8" 11 | classifiers = [ 12 | "Programming Language :: Python :: 3", 13 | "License :: OSI Approved :: Apache Software License", 14 | ] 15 | dependencies = [ 16 | 'accelerate', 'apex', 'bitsandbytes', 'datasets', 'deepspeed', 'einops', 'einops-exts', 17 | 'fastapi', 'flash_attn', 'gradio', 'gradio_client', 'httpx', 'markdown2', 'numpy', 'openpyxl', 18 | 'peft', 'protobuf', 'pydantic', 'pypandoc', 'requests', 'scikit-learn', 'sentencepiece', 'shortuuid', 19 | 'timm', 'tiktoken', 'tokenizers', 'torch', 'torchvision', 'transformers', 'uvicorn', 'xformers' 20 | ] 21 | 22 | 23 | [project.urls] 24 | "Homepage" = "https://github.com/tensoic/Cerule" 25 | "Discussion" = "https://github.com/tensoic/Cerule/issues" 26 | 27 | [tool.setuptools.packages.find] 28 | exclude = ["assets*", "benchmark*", "docs", "dist*", "playground*", "scripts*", "tests*"] 29 | 30 | [tool.wheel] 31 | exclude = ["assets*", "benchmark*", "docs", "dist*", "playground*", "scripts*", "tests*"] 32 | -------------------------------------------------------------------------------- /script/deepspeed/scripts/zero3_offload.json: -------------------------------------------------------------------------------- 1 | { 2 | "fp16": { 3 | "enabled": "auto", 4 | "loss_scale": 0, 5 | "loss_scale_window": 1000, 6 | "initial_scale_power": 16, 7 | "hysteresis": 2, 8 | "min_loss_scale": 1 9 | }, 10 | "bf16": { 11 | "enabled": "auto" 12 | }, 13 | "optimizer": { 14 | "type": "AdamW", 15 | "params": { 16 | "lr": "auto", 17 | "betas": "auto", 18 | "eps": "auto", 19 | "weight_decay": "auto" 20 | } 21 | }, 22 | "scheduler": { 23 | "type": "WarmupLR", 24 | "params": { 25 | "warmup_min_lr": "auto", 26 | "warmup_max_lr": "auto", 27 | "warmup_num_steps": "auto" 28 | } 29 | }, 30 | "zero_optimization": { 31 | "stage": 3, 32 | "offload_optimizer": { 33 | "device": "cpu", 34 | "pin_memory": true 35 | }, 36 | "offload_param": { 37 | "device": "cpu", 38 | "pin_memory": true 39 | }, 40 | "overlap_comm": true, 41 | "contiguous_gradients": true, 42 | "sub_group_size": 1e9, 43 | "reduce_bucket_size": "auto", 44 | "stage3_prefetch_bucket_size": "auto", 45 | "stage3_param_persistence_threshold": "auto", 46 | "stage3_max_live_parameters": 1e9, 47 | "stage3_max_reuse_distance": 
1e9, 48 | "gather_16bit_weights_on_model_save": true 49 | }, 50 | "gradient_accumulation_steps": "auto", 51 | "gradient_clipping": "auto", 52 | "train_batch_size": "auto", 53 | "train_micro_batch_size_per_gpu": "auto", 54 | "steps_per_print": 1e5, 55 | "wall_clock_breakdown": false 56 | } 57 | -------------------------------------------------------------------------------- /script/deepspeed/zero2.json: -------------------------------------------------------------------------------- 1 | { 2 | "fp16": { 3 | "enabled": "auto", 4 | "loss_scale": 0, 5 | "loss_scale_window": 1000, 6 | "initial_scale_power": 16, 7 | "hysteresis": 2, 8 | "min_loss_scale": 1 9 | }, 10 | "bf16": { 11 | "enabled": "auto" 12 | }, 13 | "train_micro_batch_size_per_gpu": "auto", 14 | "train_batch_size": "auto", 15 | "gradient_accumulation_steps": "auto", 16 | "zero_optimization": { 17 | "stage": 2, 18 | "overlap_comm": true, 19 | "contiguous_gradients": true, 20 | "sub_group_size": 1e9, 21 | "reduce_bucket_size": "auto" 22 | } 23 | } -------------------------------------------------------------------------------- /script/deepspeed/zero3.json: -------------------------------------------------------------------------------- 1 | { 2 | "fp16": { 3 | "enabled": "auto", 4 | "loss_scale": 0, 5 | "loss_scale_window": 1000, 6 | "initial_scale_power": 16, 7 | "hysteresis": 2, 8 | "min_loss_scale": 1 9 | }, 10 | "bf16": { 11 | "enabled": "auto" 12 | }, 13 | "train_micro_batch_size_per_gpu": "auto", 14 | "train_batch_size": "auto", 15 | "gradient_accumulation_steps": "auto", 16 | "zero_optimization": { 17 | "stage": 3, 18 | "overlap_comm": true, 19 | "contiguous_gradients": true, 20 | "sub_group_size": 1e9, 21 | "reduce_bucket_size": "auto", 22 | "stage3_prefetch_bucket_size": "auto", 23 | "stage3_param_persistence_threshold": "auto", 24 | "stage3_max_live_parameters": 1e9, 25 | "stage3_max_reuse_distance": 1e9, 26 | "stage3_gather_16bit_weights_on_model_save": true 27 | } 28 | } -------------------------------------------------------------------------------- /script/merge_lora_weights.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | from cerule.model.builder import load_pretrained_model 4 | from cerule.util.mm_utils import get_model_name_from_path 5 | 6 | 7 | def merge_lora(args): 8 | model_path = os.path.expanduser(args.model_path) 9 | model_name = get_model_name_from_path(model_path) 10 | tokenizer, model, image_processor, context_len = load_pretrained_model(model_path, args.model_base, model_name, 11 | args.model_type) 12 | 13 | model.save_pretrained(args.save_model_path) 14 | tokenizer.save_pretrained(args.save_model_path) 15 | 16 | 17 | if __name__ == "__main__": 18 | parser = argparse.ArgumentParser() 19 | parser.add_argument("--model-path", type=str, required=True) 20 | parser.add_argument("--model-base", type=str, required=True) 21 | parser.add_argument("--model-type", type=str, required=True) 22 | parser.add_argument("--save-model-path", type=str, required=True) 23 | 24 | args = parser.parse_args() 25 | 26 | merge_lora(args) 27 | -------------------------------------------------------------------------------- /script/train/finetune_full.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | MODEL_TYPE=gemma 3 | 4 | PRETRAIN_DIR=cerule-$MODEL_TYPE-pretrain 5 | OUTPUT_DIR=cerule-$MODEL_TYPE 6 | mkdir -p ./checkpoints-$MODEL_TYPE/$OUTPUT_DIR 7 | 8 | deepspeed cerule/train/train.py \ 9 | --deepspeed 
./script/deepspeed/zero3.json \ 10 | --model_name_or_path google/gemma-2b \ 11 | --model_type $MODEL_TYPE \ 12 | --version gemma_instruct \ 13 | --data_path /finetune/bunny_695k.json \ 14 | --image_folder /finetune/images \ 15 | --vision_tower google/siglip-so400m-patch14-384 \ 16 | --pretrain_mm_mlp_adapter path/to/mm_projector.bin \ 17 | --mm_projector_type mlp2x_gelu \ 18 | --image_aspect_ratio pad \ 19 | --mm_vision_select_layer -2 \ 20 | --group_by_modality_length True \ 21 | --mm_use_im_start_end False \ 22 | --mm_use_im_patch_token False \ 23 | --bf16 True \ 24 | --output_dir ./checkpoints-$MODEL_TYPE/$OUTPUT_DIR \ 25 | --num_train_epochs 1 \ 26 | --per_device_train_batch_size 8 \ 27 | --per_device_eval_batch_size 4 \ 28 | --gradient_accumulation_steps 2 \ 29 | --evaluation_strategy "no" \ 30 | --save_strategy "steps" \ 31 | --save_steps 500 \ 32 | --save_total_limit 1 \ 33 | --learning_rate 2e-5 \ 34 | --weight_decay 0. \ 35 | --warmup_ratio 0.03 \ 36 | --lr_scheduler_type "cosine" \ 37 | --logging_steps 1 \ 38 | --tf32 True \ 39 | --model_max_length 2048 \ 40 | --gradient_checkpointing True \ 41 | --dataloader_num_workers 4 \ 42 | --lazy_preprocess True \ 43 | --report_to wandb | tee 2>&1 ./checkpoints-$MODEL_TYPE/$OUTPUT_DIR/log.txt 44 | -------------------------------------------------------------------------------- /script/train/finetune_lora.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | MODEL_TYPE=gemma 3 | 4 | PRETRAIN_DIR=cerule-$MODEL_TYPE-pretrain 5 | OUTPUT_DIR=cerule-lora-$MODEL_TYPE 6 | mkdir -p ./checkpoints-$MODEL_TYPE/$OUTPUT_DIR 7 | 8 | deepspeed cerule/train/train.py \ 9 | --lora_enable True --lora_r 128 --lora_alpha 256 --mm_projector_lr 2e-5 \ 10 | --deepspeed ./script/deepspeed/zero3.json \ 11 | --model_name_or_path /path/to/base_llm_model \ 12 | --model_type $MODEL_TYPE \ 13 | --version gemma_instruct \ 14 | --data_path /finetune/bunny_695k.json \ 15 | --image_folder /finetune/images \ 16 | --vision_tower google/siglip-so400m-patch14-384 \ 17 | --pretrain_mm_mlp_adapter path/to/mm_projector.bin \ 18 | --mm_projector_type mlp2x_gelu \ 19 | --image_aspect_ratio pad \ 20 | --group_by_modality_length False \ 21 | --bf16 True \ 22 | --output_dir ./checkpoints-$MODEL_TYPE/$OUTPUT_DIR \ 23 | --num_train_epochs 1 \ 24 | --per_device_train_batch_size 8 \ 25 | --per_device_eval_batch_size 4 \ 26 | --gradient_accumulation_steps 2 \ 27 | --evaluation_strategy "no" \ 28 | --save_strategy "steps" \ 29 | --save_steps 500 \ 30 | --save_total_limit 1 \ 31 | --learning_rate 2e-4 \ 32 | --weight_decay 0. 
\ 33 | --warmup_ratio 0.03 \ 34 | --lr_scheduler_type "cosine" \ 35 | --logging_steps 1 \ 36 | --tf32 True \ 37 | --model_max_length 2048 \ 38 | --gradient_checkpointing True \ 39 | --dataloader_num_workers 4 \ 40 | --lazy_preprocess True \ 41 | --report_to none | tee 2>&1 ./checkpoints-$MODEL_TYPE/$OUTPUT_DIR/log.txt -------------------------------------------------------------------------------- /script/train/pretrain.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | MODEL_TYPE=gemma 4 | OUTPUT_DIR=cerule-$MODEL_TYPE-pretrain 5 | 6 | mkdir -p ./checkpoints-pretrain/$OUTPUT_DIR 7 | 8 | deepspeed cerule/train/train.py \ 9 | --deepspeed ./script/deepspeed/zero2.json \ 10 | --model_name_or_path google/gemma-2b \ 11 | --model_type $MODEL_TYPE \ 12 | --version plain \ 13 | --data_path /pretrain/blip_laion_cc_sbu_558k.json \ 14 | --image_folder /pretrain/images \ 15 | --vision_tower google/siglip-so400m-patch14-384 \ 16 | --mm_projector_type mlp2x_gelu \ 17 | --tune_mm_mlp_adapter True \ 18 | --image_aspect_ratio square \ 19 | --bf16 True \ 20 | --output_dir ./checkpoints-pretrain/$OUTPUT_DIR \ 21 | --num_train_epochs 1 \ 22 | --per_device_train_batch_size 8 \ 23 | --per_device_eval_batch_size 4 \ 24 | --gradient_accumulation_steps 4 \ 25 | --evaluation_strategy "no" \ 26 | --save_strategy "steps" \ 27 | --save_steps 24000 \ 28 | --save_total_limit 1 \ 29 | --learning_rate 5e-4 \ 30 | --weight_decay 0. \ 31 | --warmup_ratio 0.03 \ 32 | --lr_scheduler_type "cosine" \ 33 | --logging_steps 1 \ 34 | --tf32 True \ 35 | --model_max_length 2048 \ 36 | --gradient_checkpointing True \ 37 | --dataloader_num_workers 4 \ 38 | --lazy_preprocess True \ 39 | --report_to wandb | tee 2>&1 ./checkpoints-pretrain/$OUTPUT_DIR/log.txt 40 | --------------------------------------------------------------------------------
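Usage sketch for the scripts above (assumptions are noted inline; directory names follow the variables set inside the scripts, and the merged-model output path is an arbitrary placeholder): pretrain.sh tunes only the multimodal projector (--tune_mm_mlp_adapter True) and writes its checkpoint under ./checkpoints-pretrain/cerule-gemma-pretrain; the fine-tuning scripts then load that projector via --pretrain_mm_mlp_adapter (shown there as path/to/mm_projector.bin), and a LoRA run is followed by script/merge_lora_weights.py to fold the adapters back into the base model.

# Illustrative sketch, not a script shipped with the repo.
# Assumes MODEL_TYPE=gemma (the default in the scripts) and that you have already
# edited the placeholder paths inside the .sh files (data_path, image_folder,
# model_name_or_path, pretrain_mm_mlp_adapter).
bash script/train/pretrain.sh
# point --pretrain_mm_mlp_adapter in finetune_lora.sh at the projector produced above,
# e.g. ./checkpoints-pretrain/cerule-gemma-pretrain/mm_projector.bin (typical filename)
bash script/train/finetune_lora.sh
python script/merge_lora_weights.py \
    --model-path ./checkpoints-gemma/cerule-lora-gemma \
    --model-base google/gemma-2b \
    --model-type gemma \
    --save-model-path ./cerule-gemma-merged   # placeholder output directory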