├── DATA.md ├── LICENSE ├── README.md ├── cog.yaml ├── inference └── infer.py ├── ola ├── arguments.py ├── constants.py ├── conversation.py ├── datasets │ ├── __init__.py │ └── preprocess.py ├── mm_utils.py ├── model │ ├── __init__.py │ ├── builder.py │ ├── language_model │ │ └── ola_qwen.py │ ├── multimodal_encoder │ │ ├── builder.py │ │ └── oryx_vit.py │ ├── multimodal_projector │ │ ├── builder.py │ │ └── pooler_projector.py │ ├── multimodal_resampler │ │ ├── __pycache__ │ │ │ ├── builder.cpython-310.pyc │ │ │ ├── builder.cpython-38.pyc │ │ │ ├── perceiver.cpython-310.pyc │ │ │ └── perceiver.cpython-38.pyc │ │ └── builder.py │ ├── ola_arch.py │ ├── speech_encoder │ │ ├── beats │ │ │ ├── BEATs.py │ │ │ ├── Tokenizers.py │ │ │ ├── __init__.py │ │ │ ├── __pycache__ │ │ │ │ ├── BEATs.cpython-310.pyc │ │ │ │ ├── BEATs.cpython-38.pyc │ │ │ │ ├── __init__.cpython-310.pyc │ │ │ │ ├── __init__.cpython-38.pyc │ │ │ │ ├── backbone.cpython-310.pyc │ │ │ │ ├── backbone.cpython-38.pyc │ │ │ │ ├── kaldi.cpython-310.pyc │ │ │ │ ├── kaldi.cpython-38.pyc │ │ │ │ ├── modules.cpython-310.pyc │ │ │ │ └── modules.cpython-38.pyc │ │ │ ├── backbone.py │ │ │ ├── kaldi.py │ │ │ ├── modules.py │ │ │ └── quantizer.py │ │ ├── builder.py │ │ └── speech_encoder.py │ └── speech_projector │ │ ├── builder.py │ │ └── speech_projector.py ├── serve │ ├── __init__.py │ ├── controller.py │ ├── gradio_web_server.py │ └── model_worker.py ├── train │ ├── ola_trainer.py │ └── train.py └── utils.py ├── pyproject.toml ├── scripts ├── .DS_Store ├── finetune_ola.sh ├── finetune_ola_image.sh ├── finetune_ola_video.sh ├── zero2.json └── zero3.json └── tools ├── convert_mp4_wav.py └── create_patch.py /DATA.md: -------------------------------------------------------------------------------- 1 | # Data 2 | 3 | ### Data Format 4 | 5 | We follow the data format below, which is similar to LLaVA. 
You can directly use the original file path or pack the multi-modal files into patches following [create_patch.py](https://github.com/Ola-Omni/Ola/blob/main/tools/create_patch.py). A patch is a binary file containing image or video files stored contiguously in byte format, which may accelerate reading speed in some cases. 6 | 7 | 8 | - Image Data: 9 | 10 | ``` 11 | [ 12 | { 13 | 'id': ID of the data 14 | 'image': ***.png (path to the image file or positions in patches) 15 | 'conversations': [{"from": "human", "value": "\n"}, {"from": "gpt", "value": ""}] 16 | } 17 | ] 18 | ``` 19 | 20 | The format for an image patch is: 21 | 22 | ``` 23 | { 24 | "patch": "patch_00000", 25 | "start_num": 846113989, 26 | "size": 27141 27 | } 28 | ``` 29 | 30 | - Video Frame Data: 31 | 32 | ``` 33 | [ 34 | { 35 | 'id': ID of the data 36 | 'video': ***.mp4 (path to the video file or positions in patches) 37 | 'conversations': [{"from": "human", "value": "\n"}, {"from": "gpt", "value": ""}] 38 | } 39 | ] 40 | ``` 41 | 42 | The format for a video patch is: 43 | 44 | ``` 45 | { 46 | "patch": "patch_000000", 47 | "size": [ 5605, 8902, 7917, 5562, 9249, 8785, 8379, 10389, 10505, 10337, 8481, 8164, 5562, 8844, 10565, 8035, 7768, 8969, 5643, 10478, 7632, 10980, 9986, 3602, 2848, 7591, 10766, 7813, 5605, 9840, 9664, 5605, 7726, 4828, 8006, 5562, 9711, 7903, 9542, 10626, 8827, 11268, 11115, 1832, 11354, 9222, 3965, 10426, 10427, 7311, 9726, 7655, 10025, 5350, 10098, 10470, 4877, 10273, 9730, 10150, 5604, 7203, 9881, 2246, 11114, 3790, 5567, 10490, 4072, 1701], 48 | "start_num": 26608266 49 | } 50 | ``` 51 | 52 | - Video + Audio Data: 53 | 54 | ``` 55 | [ 56 | { 57 | 'id': ID of the data 58 | 'video': ***.mp4 (path to the video file or positions in patches) 59 | 'audio': ***.wav (path to the audio file) 60 | 'conversations': [{"from": "human", "value": "\n"}, {"from": "gpt", "value": ""}] 61 | } 62 | ] 63 | ``` 64 | 65 | - Image + Audio Data: 66 | 67 | ``` 68 | [ 69 | { 70 | 'id': ID of the data 71 | 'audio_q': 
***.wav (path to the audio file) 72 | 'image': ***.png (path to the image file or positions in patches) 73 | 'conversations': [{"from": "human", "value": ""\nUser's question in speech: ""}, {"from": "gpt", "value": ""}] 74 | } 75 | ] 76 | ``` 77 | 78 | - Audio Data: 79 | 80 | ``` 81 | [ 82 | { 83 | 'id': ID of the data 84 | 'audio': ***.wav (path to the audio file) 85 | 'conversations': [{"from": "human", "value": "\n"}, {"from": "gpt", "value": ""}] 86 | } 87 | ] 88 | ``` 89 | 90 | ### Instruction for Ola Data 91 | 92 | **You can simply mix up the separated training jsons for joint training with image/video/audio data.** 93 | 94 | #### **Ola-Video-1.9M** 95 | 96 | 1. Download [Ola-video-1.9M.json](https://huggingface.co/datasets/THUdyh/Ola-Data/blob/main/video_data/video-data.json) from huggingface. 97 | 98 | 2. Download all the [video patches](https://huggingface.co/datasets/THUdyh/Ola-Data/tree/main/video_data) from huggingface. 99 | 100 | 3. Check and modify the video patch path in the json to the true path in your machine. 101 | 102 | #### **Ola-Audio-1.1M** 103 | 104 | 1. Download [Ola_audio_1169k.json](https://huggingface.co/datasets/THUdyh/Ola-Data/blob/main/Ola_audio_1169k.json) from huggingface. 105 | 106 | 2. Download [wav tar file](https://huggingface.co/datasets/THUdyh/Ola-Data/tree/main/ola_audio) from huggingface and unzip all the files. 107 | 108 | 3. Check the file structure: 109 | 110 | ``` 111 | │ola_audio/ 112 | ├── Ola_audio_1169k.json 113 | ├── AudioCaps/ 114 | ├── Clotho/ 115 | ├── GigaSpeech/ 116 | ├── LibriSpeech/ 117 | ├── MillionSongDatasetSpotify/ 118 | ├── MusicCaps/ 119 | ├── WavCaps/ 120 | ``` 121 | 122 | 4. Check and modify the audio file path in the json to the true path in your machine. 123 | 124 | #### **Ola-Cross-Modality-298k** 125 | 126 | 1. 
Download [Ola_cross_modality_finevideo_175k.json](https://huggingface.co/datasets/THUdyh/Ola-Data/blob/main/Ola_cross_modality_finevideo_175k.json) and [Ola_cross_modality_llava_123k.json](https://huggingface.co/datasets/THUdyh/Ola-Data/blob/main/Ola_cross_modality_llava_123k.json) from huggingface. 127 | 128 | 2. Download [FineVideo](https://huggingface.co/datasets/HuggingFaceFV/finevideo/tree/main) from huggingface. 129 | 130 | 3. Download [LLaVA-Video-178k](https://huggingface.co/datasets/lmms-lab/LLaVA-Video-178K/tree/main) from huggingface. 131 | 132 | 4. Extract pure video from FineVideo and LLaVA-Video-178k. 133 | 134 | 5. Transfer and save the wav file of the videos using [convert_mp4_wav.py](https://github.com/Ola-Omni/Ola/blob/main/tools/convert_mp4_wav.py). 135 | 136 | 6. Check the file structure: 137 | 138 | ``` 139 | │ola_cross_modality_298k/ 140 | ├── Ola_cross_modality_finevideo_175k.json 141 | ├── Ola_cross_modality_llava_123k.json 142 | ├── finevideo_audios/ 143 | │ ├── lltmlYR56dI.wav 144 | │ ├── ...... 145 | ├── finevideo_videos/ 146 | │ ├── lltmlYR56dI.mp4 147 | │ ├── ...... 148 | ├── llava_audios/ 149 | │ ├── academic_source 150 | │ ├── ActivityNet-QA 151 | │ ├── liwei_youtube_videos 152 | │ ├── NextQA 153 | │ ├── perception_test 154 | ├── llava_videos/ 155 | │ ├── academic_source 156 | │ ├── ActivityNet-QA 157 | │ ├── liwei_youtube_videos 158 | │ ├── NextQA 159 | │ ├── perception_test 160 | ``` 161 | 162 | 7. Check and modify the video and audio path in the json to the true path in your machine. 163 | 164 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 
8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. 
For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. 
Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. 
You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 
122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. 
In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. 
We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 202 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 |

2 | 967023137dff29e65b21544e7620e0f7.webp 3 |

4 |
5 | 6 | ## Ola: Pushing the Frontiers of Omni-Modal Language Model 7 | 8 |

9 | Zuyan Liu*,1,2  10 | Yuhao Dong*,2,3  11 | Jiahui Wang1
12 | Ziwei Liu3  13 | Winston Hu2  14 | Jiwen Lu1,✉  15 | Yongming Rao2,1,✉  16 |

17 | 18 | 19 |

1Tsinghua University   2Tencent Hunyuan Research  3S-Lab, NTU 

20 | 21 |

* Equal Contribution  ✉ Corresponding Author

22 | 23 | [![Ola](https://img.shields.io/badge/Rank_1-OpenCompass(<15B)-blue)](https://rank.opencompass.org.cn/leaderboard-multimodal/?m=REALTIME) [![Ola](https://img.shields.io/badge/Rank_8-VideoMME-red)](https://video-mme.github.io/home_page.html#leaderboard) 24 | 25 | --- 26 | 27 | **Project Page:** [![Ola](https://img.shields.io/badge/Ola-project_page-orange)](https://ola-omni.github.io) 28 | 29 | **Weights in Huggingface:** [![hf_checkpoint](https://img.shields.io/badge/🤗-Ola_7b-green)](https://huggingface.co/THUdyh/Ola-7b) [![hf_checkpoint](https://img.shields.io/badge/🤗-Ola_Image-green)](https://huggingface.co/THUdyh/Ola-Image) [![hf_checkpoint](https://img.shields.io/badge/🤗-Ola_Video-green)](https://huggingface.co/THUdyh/Ola-Video) 30 | 31 | **arXiv Paper:** [![arxiv](https://img.shields.io/badge/Arxiv-2502.04328-b31b1b.svg?logo=arXiv)](https://arxiv.org/abs/2502.04328) 32 | 33 | **Demo by Gradio:** [![demo](https://img.shields.io/badge/Ola-Demo-yellow)](https://huggingface.co/spaces/THUdyh/Ola) 34 | 35 | **Training Data:** [![data](https://img.shields.io/badge/Ola-Data-purple)](https://huggingface.co/datasets/THUdyh/Ola-Data) 36 | 37 | **中文解读**: [![chinese](https://img.shields.io/badge/Ola-机器之心-cyan)](https://mp.weixin.qq.com/s/N4bjcHOejJudtxTFZVAXmg) 38 | 39 | Contact: Leave an issue or contact liuzuyan19@gmail.com . We are on call to respond. 40 | 41 | ## 📢 News 42 | 43 | - 🔥[28/2/2025] We release the intermediate models, Ola-Image and Ola-Video. Try building your own omni-modal models! 44 | 45 | - 🚀[19/2/2025] We release the huggingface demo of Ola. Try the advanced omni-modal model on your own! 46 | 47 | - 🔥[18/2/2025] The training data and training script for Ola-7b are released! 48 | 49 | - 🎉[07/2/2025] Ola is released! Check our [project page](https://ola-omni.github.io), [model weights](https://huggingface.co/THUdyh/Ola-7b), and [arXiv paper](https://arxiv.org/pdf/2502.04328) for the strong omni-modal understanding model! 
50 | 51 | - 🔥[06/2/2025] [Ola-7b](https://huggingface.co/THUdyh/Ola-7b) achieves **Rank #1** on the OpenCompass Multi-modal Leaderboard among all the models under 15B parameters with average score of **72.6**. Check the impressive results [here](https://rank.opencompass.org.cn/leaderboard-multimodal/?m=REALTIME)! 52 | 53 | ## 🚀Coming Soon 54 | 55 | - [x] Evaluation code on omni-modal benchmarks 56 | - [x] Gradio Demo 57 | - [x] Training Data (Video, Audio, Cross-Modality) 58 | 59 | ## 🌟 Introduction 60 | 61 | ### Roads to Ola 62 | 63 |

64 | road.png 65 |

66 |
67 | 68 | **Ola** is an Omni-modal language model that achieves competitive performance across image, video, and audio understanding compared to specialized counterparts. We conduct a comprehensive exploration of architectural design, data curation, and training strategies essential for building a robust omni-modal model. 69 | 70 |

71 | teaser.png 72 |

73 |
74 | 75 | ### Architecture 76 | 77 |

78 | method.png 79 |

80 |
81 | 82 | Ola supports omni-modal inputs including text, image, video, and audio, capable of processing the inputs simultaneously with competitive performance on understanding tasks for all these modalities. Meanwhile, Ola supports user-friendly real-time streaming decoding for texts and speeches thanks to the text detokenizer and the speech decoder. 83 | 84 | ### Training Strategies 85 | 86 |

87 | training.png 88 |

89 |
90 | 91 | We visualize the relationships among modalities in the left part. Speech acts as the connection between language and audio knowledge, while video constructs the bridge with highly relevant visual and audio information. Therefore, we design the progressive alignment training strategy from primary to periphery. Furthermore, we design the cross-modality video-audio data to better capture the relationships among modalities. 92 | 93 | ### Performance 94 | 95 |

96 | results.png 97 |

98 |
99 | 100 | Ola achieves competitive performance across major multi-modal benchmarks when compared to state-of-the-art specialist-modal LLMs. 101 | 102 | ## Installation 103 | 104 | 105 | #### 1. Clone this repository: 106 | ```bash 107 | git clone https://github.com/Ola-Omni/Ola 108 | cd Ola 109 | ``` 110 | 111 | #### 2. Install the required packages: 112 | ```bash 113 | conda create -n ola python=3.10 -y 114 | conda activate ola 115 | pip install --upgrade pip 116 | pip install -e . 117 | ``` 118 | #### 3. Install additional packages for training: 119 | 120 | ```bash 121 | pip install -e ".[train]" 122 | pip install flash-attn --no-build-isolation 123 | ``` 124 | 125 | ## Model Zoo 126 | 127 | We provide our checkpoints at [Huggingface](https://huggingface.co/collections/THUdyh/ola-67b8220eb93406ec87aeec37) 128 | 129 | | Model | Link | Size | Modal | 130 | |:---:|:---:|:---:|:---:| 131 | |Ola-7b | [Huggingface](https://huggingface.co/THUdyh/Ola-7b) | 7B | Text, Image, Video, Audio | 132 | |Ola-Image | [Huggingface](https://huggingface.co/THUdyh/Ola-Image) | 7B | Text, Image | 133 | |Ola-Video | [Huggingface](https://huggingface.co/THUdyh/Ola-Video) | 7B | Text, Image, Video | 134 | 135 | 136 | ## Quick Start 137 | 138 | 1. Download `Ola-7b` from [Huggingface](https://huggingface.co/THUdyh/Ola-7b) or skip this step to use the online weights directly. 139 | 140 | 2. Download the audio encoders from [Huggingface](https://huggingface.co/THUdyh/Ola_speech_encoders/tree/main) and put the weights `large-v3.pt` and `BEATs_iter3_plus_AS2M_finetuned_on_AS2M_cpt2.pt` under the repo directory `path/to/Ola/` 141 | 142 | 3. 
Run `inference/infer.py` 143 | 144 | - Text & Image Understanding 145 | 146 | ``` 147 | python3 inference/infer.py --image_path *.png,jpg --text user_instruction 148 | ``` 149 | 150 | - Text & Video Understanding 151 | 152 | ``` 153 | python3 inference/infer.py --video_path *.mp4 --text user_instruction 154 | ``` 155 | 156 | - Text & Audio Understanding 157 | 158 | ``` 159 | python3 inference/infer.py --audio_path *.wav,mp3 --text user_instruction 160 | ``` 161 | 162 | - Audio & Image Understanding 163 | 164 | ``` 165 | python3 inference/infer.py --image_path *.png,jpg --audio_path *.wav,mp3 166 | ``` 167 | 168 | ## Evaluation 169 | 170 | You can evaluate the Ola model with [VLMEvalKit](https://github.com/open-compass/VLMEvalKit) and [lmms-eval](https://github.com/EvolvingLMMs-Lab/lmms-eval). 171 | 172 | ## Training 173 | 174 | ### Data Preparation 175 | 176 | Please refer to [DATA.md](https://github.com/Ola-Omni/Ola/blob/main/DATA.md) for instructions on customized finetuning or using the provided datasets. 177 | 178 | ### Start Training 179 | 180 | Please follow the scripts below to start training. Make sure you have created the correct datasets for fine-tuning. 181 | 182 | 1. Finetuning Ola-7b Model: 183 | 184 | ``` 185 | bash ./scripts/finetune_ola.sh 186 | ``` 187 | 188 | 2. Finetuning Ola-Image Model (Ola Stage1 or Stage2) 189 | 190 | ``` 191 | bash ./scripts/finetune_ola_image.sh 192 | ``` 193 | 194 | 3. 
Finetuning Ola-Video Model (Ola Stage3): 195 | 196 | ``` 197 | bash ./scripts/finetune_ola_video.sh 198 | ``` 199 | 200 | ## Citation 201 | 202 | If you find it useful for your research and applications, please cite our paper using this BibTeX: 203 | ```bibtex 204 | @article{liu2025ola, 205 | title={Ola: Pushing the Frontiers of Omni-Modal Language Model}, 206 | author={Liu, Zuyan and Dong, Yuhao and Wang, Jiahui and Liu, Ziwei and Hu, Winston and Lu, Jiwen and Rao, Yongming}, 207 | journal={arXiv preprint arXiv:2502.04328}, 208 | year={2025} 209 | } 210 | ``` 211 | 212 | ## Acknowledgement 213 | 214 | - Our codebase is built on [LLaVA](https://github.com/LLaVA-VL/LLaVA-NeXT) 215 | 216 | - Thanks to the [VLMEvalKit](https://github.com/open-compass/VLMEvalKit) and [lmms-eval](https://github.com/EvolvingLMMs-Lab/lmms-eval) teams for the evaluation system! 217 | -------------------------------------------------------------------------------- /cog.yaml: -------------------------------------------------------------------------------- 1 | # Configuration for Cog ⚙️ 2 | # Reference: https://cog.run/yaml 3 | 4 | build: 5 | # set to true if your model requires a GPU 6 | gpu: true 7 | 8 | # a list of ubuntu apt packages to install 9 | system_packages: 10 | - "libgl1-mesa-glx" 11 | - "libglib2.0-0" 12 | 13 | # python version in the form '3.11' or '3.11.4' 14 | python_version: "3.10" 15 | 16 | # a list of packages in the format <package-name>==<version> 17 | python_packages: 18 | - "torch==2.1.2" 19 | - "torchvision==0.16.2" 20 | - "torchaudio==2.1.2" 21 | - "transformers==4.43.4" 22 | - "tokenizers==0.19.1" 23 | - "sentencepiece==0.1.99" 24 | - "shortuuid" 25 | - "accelerate==0.33.0" 26 | - "peft==0.11.1" 27 | - "bitsandbytes==0.43.1" 28 | - "pydantic<2" 29 | - "markdown2[all]" 30 | - "numpy" 31 | - "scikit-learn==1.2.2" 32 | - "gradio_client==1.3.0" 33 | - "requests" 34 | - "httpx==0.27.2" 35 | - "uvicorn" 36 | - "fastapi" 37 | - "soundfile" 38 | - "einops==0.6.1" 39 | - "einops-exts==0.0.4" 40 | - 
"timm==0.6.13" 41 | - "openai-whisper" 42 | - "setuptools==59.5.0" 43 | - "omegaconf==2.0.6" 44 | run: 45 | - git clone https://github.com/pytorch/fairseq && cd fairseq && pip install -e . --no-build-isolation 46 | - pip install flash-attn --no-build-isolation 47 | - curl -o /usr/local/bin/pget -L "https://github.com/replicate/pget/releases/download/v0.8.2/pget_linux_x86_64" && chmod +x /usr/local/bin/pget 48 | 49 | # predict.py defines how predictions are run on your model 50 | predict: "predict.py:Predictor" 51 | -------------------------------------------------------------------------------- /inference/infer.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | os.environ['LOWRES_RESIZE'] = '384x32' 4 | os.environ['HIGHRES_BASE'] = '0x32' 5 | os.environ['VIDEO_RESIZE'] = "0x64" 6 | os.environ['VIDEO_MAXRES'] = "480" 7 | os.environ['VIDEO_MINRES'] = "288" 8 | os.environ['MAXRES'] = '1536' 9 | os.environ['MINRES'] = '0' 10 | os.environ['FORCE_NO_DOWNSAMPLE'] = '1' 11 | os.environ['LOAD_VISION_EARLY'] = '1' 12 | os.environ['PAD2STRIDE'] = '1' 13 | 14 | import gradio as gr 15 | import torch 16 | import re 17 | from decord import VideoReader, cpu 18 | from PIL import Image 19 | import numpy as np 20 | import transformers 21 | import moviepy.editor as mp 22 | from typing import Dict, Optional, Sequence, List 23 | import librosa 24 | import whisper 25 | from ola.conversation import conv_templates, SeparatorStyle 26 | from ola.model.builder import load_pretrained_model 27 | from ola.datasets.preprocess import tokenizer_image_token, tokenizer_speech_image_token, tokenizer_speech_question_image_token, tokenizer_speech_token 28 | from ola.mm_utils import KeywordsStoppingCriteria, process_anyres_video, process_anyres_highres_image 29 | from ola.constants import IGNORE_INDEX, DEFAULT_IMAGE_TOKEN, IMAGE_TOKEN_INDEX, DEFAULT_SPEECH_TOKEN, SPEECH_TOKEN_INDEX 30 | import argparse 31 | 32 | parser = argparse.ArgumentParser() 33 
| parser.add_argument('--model_path', type=str, default='THUdyh/Ola-7b') 34 | parser.add_argument('--text', type=str, default=None) 35 | parser.add_argument('--audio_path', type=str, default=None) 36 | parser.add_argument('--image_path', type=str, default=None) 37 | parser.add_argument('--video_path', type=str, default=None) 38 | args = parser.parse_args() 39 | 40 | model_path = args.model_path 41 | tokenizer, model, image_processor, _ = load_pretrained_model(model_path, None) 42 | model = model.to('cuda').eval() 43 | model = model.bfloat16() 44 | 45 | USE_SPEECH=False 46 | cur_dir = os.path.dirname(os.path.abspath(__file__)) 47 | 48 | def load_audio(audio_file_name): 49 | speech_wav, samplerate = librosa.load(audio_file_name, sr=16000) 50 | if len(speech_wav.shape) > 1: 51 | speech_wav = speech_wav[:, 0] 52 | speech_wav = speech_wav.astype(np.float32) 53 | CHUNK_LIM = 480000 54 | SAMPLE_RATE = 16000 55 | speechs = [] 56 | speech_wavs = [] 57 | 58 | if len(speech_wav) <= CHUNK_LIM: 59 | speech = whisper.pad_or_trim(speech_wav) 60 | speech_wav = whisper.pad_or_trim(speech_wav) 61 | speechs.append(speech) 62 | speech_wavs.append(torch.from_numpy(speech_wav).unsqueeze(0)) 63 | else: 64 | for i in range(0, len(speech_wav), CHUNK_LIM): 65 | chunk = speech_wav[i : i + CHUNK_LIM] 66 | if len(chunk) < CHUNK_LIM: 67 | chunk = whisper.pad_or_trim(chunk) 68 | speechs.append(chunk) 69 | speech_wavs.append(torch.from_numpy(chunk).unsqueeze(0)) 70 | mels = [] 71 | for chunk in speechs: 72 | chunk = whisper.log_mel_spectrogram(chunk, n_mels=128).permute(1, 0).unsqueeze(0) 73 | mels.append(chunk) 74 | 75 | mels = torch.cat(mels, dim=0) 76 | speech_wavs = torch.cat(speech_wavs, dim=0) 77 | if mels.shape[0] > 25: 78 | mels = mels[:25] 79 | speech_wavs = speech_wavs[:25] 80 | 81 | speech_length = torch.LongTensor([mels.shape[1]] * mels.shape[0]) 82 | speech_chunks = torch.LongTensor([mels.shape[0]]) 83 | return mels, speech_length, speech_chunks, speech_wavs 84 | 85 | def 
def extract_audio(videos_file_path):
    """Open ``videos_file_path`` with moviepy and return the clip's ``audio`` attribute."""
    my_clip = mp.VideoFileClip(videos_file_path)
    return my_clip.audio


image_path = args.image_path
audio_path = args.audio_path
video_path = args.video_path
text = args.text

# Validate the input combination with explicit errors: `assert` is stripped
# under `python -O`, and the old if/elif chain left `modality` unbound
# (NameError further down) when only --text was supplied.
if video_path is not None and image_path is not None:
    raise ValueError("--video_path and --image_path cannot be used together")
if not (text or audio_path or image_path or video_path):
    raise ValueError(
        "provide at least one of --text, --audio_path, --image_path or --video_path"
    )

if video_path is not None:
    modality = "video"
    visual = video_path
elif image_path is not None:
    modality = "image"
    visual = image_path
else:
    # Audio-only, text-only and audio + text requests carry no visual input.
    modality = "text"

# An explicit audio file always wins; otherwise a video contributes its own
# audio track. Everything else runs without speech.
USE_SPEECH = bool(audio_path) or modality == "video"

# ---------------------------------------------------------------- visual input
if modality == "video":
    vr = VideoReader(visual, ctx=cpu(0))
    total_frame_num = len(vr)
    # 64 frame indices spread uniformly over the whole clip.
    frame_idx = np.linspace(0, total_frame_num - 1, 64, dtype=int).tolist()
    spare_frames = vr.get_batch(frame_idx).asnumpy()
    video = [Image.fromarray(frame) for frame in spare_frames]
elif modality == "image":
    image = [Image.open(visual)]
    image_sizes = [image[0].size]
else:
    # Zero placeholders so that generate() still receives image arguments.
    images = [torch.zeros(1, 3, 224, 224).to(dtype=torch.bfloat16, device='cuda', non_blocking=True)]
    images_highres = [torch.zeros(1, 3, 224, 224).to(dtype=torch.bfloat16, device='cuda', non_blocking=True)]
    image_sizes = [(224, 224)]

# ---------------------------------------------------------------- speech input
if USE_SPEECH:
    if audio_path:
        speech_source = audio_path
    else:
        # No separate audio file: dump the video's audio track to a wav file
        # in the working directory and load that instead.
        audio = extract_audio(visual)
        audio.write_audiofile("./video_audio.wav")
        speech_source = './video_audio.wav'
    speech, speech_length, speech_chunk, speech_wav = load_audio(speech_source)
    speechs = [speech.bfloat16().to('cuda')]
    speech_lengths = [speech_length.to('cuda')]
    speech_chunks = [speech_chunk.to('cuda')]
    speech_wavs = [speech_wav.to('cuda')]
    if audio_path:
        print('load audio')
else:
    # Zero placeholders with the same layout load_audio() produces for one chunk.
    speechs = [torch.zeros(1, 3000, 128).bfloat16().to('cuda')]
    speech_lengths = [torch.LongTensor([3000]).to('cuda')]
    speech_wavs = [torch.zeros([1, 480000]).to('cuda')]
    speech_chunks = [torch.LongTensor([1]).to('cuda')]

# ---------------------------------------------------- prompt and tokenization
conv_mode = "qwen_1_5"
qs = text if text else ''

# The prompt layout and the tokenizer helper that understands its placeholders
# are chosen together so the two can never drift apart.
if USE_SPEECH and audio_path and image_path:  # image + speech instruction
    qs = DEFAULT_IMAGE_TOKEN + "\n" + "User's question in speech: " + DEFAULT_SPEECH_TOKEN + '\n'
    tokenize_prompt, placeholder_index = tokenizer_speech_question_image_token, IMAGE_TOKEN_INDEX
elif USE_SPEECH and video_path:  # video + audio
    qs = DEFAULT_SPEECH_TOKEN + DEFAULT_IMAGE_TOKEN + "\n" + qs
    tokenize_prompt, placeholder_index = tokenizer_speech_image_token, IMAGE_TOKEN_INDEX
elif USE_SPEECH and audio_path:  # audio + text
    qs = DEFAULT_SPEECH_TOKEN + "\n" + qs
    tokenize_prompt, placeholder_index = tokenizer_speech_token, SPEECH_TOKEN_INDEX
else:
    if image_path or video_path:  # image / video
        qs = DEFAULT_IMAGE_TOKEN + "\n" + qs
    tokenize_prompt, placeholder_index = tokenizer_image_token, IMAGE_TOKEN_INDEX

conv = conv_templates[conv_mode].copy()
conv.append_message(conv.roles[0], qs)
conv.append_message(conv.roles[1], None)
prompt = conv.get_prompt()
input_ids = tokenize_prompt(prompt, tokenizer, placeholder_index, return_tensors="pt").unsqueeze(0).to('cuda')

# ------------------------------------------------------- visual preprocessing
if modality == "video":
    image_processor.do_resize = False
    image_processor.do_center_crop = False
    video_processed = []
    for idx, frame in enumerate(video):
        # NOTE(review): `idx` is the position inside the 64 sampled frames,
        # while `frame_idx` holds indices into the original video, so this
        # membership test drops most sampled frames for long clips. The
        # filter is kept exactly as before; confirm whether every sampled
        # frame was meant to be used.
        if idx not in frame_idx:
            continue
        frame = process_anyres_video(frame, image_processor)
        video_processed.append(frame.unsqueeze(0))

    video_frames = torch.cat(video_processed, dim=0).bfloat16().to("cuda")
    # The same tensor is passed as both the low- and high-resolution stream.
    visual_kwargs = dict(
        images=video_frames,
        images_highres=video_frames,
        modalities="video",
    )
elif modality == "image":
    image_processor.do_resize = False
    image_processor.do_center_crop = False
    image_tensor, image_highres_tensor = [], []
    for picture in image:
        image_tensor_, image_highres_tensor_ = process_anyres_highres_image(picture, image_processor)
        image_tensor.append(image_tensor_)
        image_highres_tensor.append(image_highres_tensor_)
    # Stack into one batch tensor only when every entry has the same shape.
    if all(x.shape == image_tensor[0].shape for x in image_tensor):
        image_tensor = torch.stack(image_tensor, dim=0)
    if all(x.shape == image_highres_tensor[0].shape for x in image_highres_tensor):
        image_highres_tensor = torch.stack(image_highres_tensor, dim=0)
    if isinstance(image_tensor, list):
        image_tensor = [_image.bfloat16().to("cuda") for _image in image_tensor]
    else:
        image_tensor = image_tensor.bfloat16().to("cuda")
    if isinstance(image_highres_tensor, list):
        image_highres_tensor = [_image.bfloat16().to("cuda") for _image in image_highres_tensor]
    else:
        image_highres_tensor = image_highres_tensor.bfloat16().to("cuda")
    visual_kwargs = dict(
        images=image_tensor,
        images_highres=image_highres_tensor,
        image_sizes=image_sizes,
        modalities=['image'],
    )
else:
    visual_kwargs = dict(
        images=images,
        images_highres=images_highres,
        image_sizes=image_sizes,
        modalities=['text'],
    )

# ------------------------------------------------------------------ generation
pad_token_ids = 151643

attention_masks = input_ids.ne(pad_token_ids).long().to('cuda')
stop_str = conv.sep if conv.sep_style != SeparatorStyle.TWO else conv.sep2
keywords = [stop_str]
stopping_criteria = KeywordsStoppingCriteria(keywords, tokenizer, input_ids)

gen_kwargs = {
    "max_new_tokens": 1024,
    "temperature": 0.2,
    "top_p": None,
    "num_beams": 1,
}

with torch.inference_mode():
    # One call for all modalities; only the visual keyword arguments differ.
    output_ids = model.generate(
        inputs=input_ids,
        speech=speechs,
        speech_lengths=speech_lengths,
        speech_chunks=speech_chunks,
        speech_wav=speech_wavs,
        attention_mask=attention_masks,
        use_cache=True,
        stopping_criteria=[stopping_criteria],
        do_sample=gen_kwargs["temperature"] > 0,
        temperature=gen_kwargs["temperature"],
        top_p=gen_kwargs["top_p"],
        num_beams=gen_kwargs["num_beams"],
        max_new_tokens=gen_kwargs["max_new_tokens"],
        **visual_kwargs,
    )

outputs = tokenizer.batch_decode(output_ids, skip_special_tokens=True)[0]
outputs = outputs.strip()
if outputs.endswith(stop_str):
    outputs = outputs[:-len(stop_str)]
outputs = outputs.strip()

print(outputs)
@dataclass
class TrainingArguments(transformers.TrainingArguments):
    """Project-specific additions on top of ``transformers.TrainingArguments``.

    Only the fields declared here are new; everything else is inherited.
    How each field is consumed is decided by training code outside this
    class, so the per-field notes below restate the declared defaults and
    help texts only.
    """
    cache_dir: Optional[str] = field(default=None)
    # Overrides the inherited default optimizer name.
    optim: str = field(default="adamw_torch")
    freeze_speech_projector: bool = field(default=False)
    model_max_length: int = field(
        default=512,
        metadata={
            "help":
            "Maximum sequence length. Sequences will be right padded (and possibly truncated)."
        },
    )
    # Quantization options (see the help strings).
    double_quant: bool = field(
        default=True,
        metadata={"help": "Compress the quantization statistics through double quantization."}
    )
    quant_type: str = field(
        default="nf4",
        metadata={"help": "Quantization data type to use. Should be one of `fp4` or `nf4`."}
    )
    bits: int = field(
        default=16,
        metadata={"help": "How many bits to use."}
    )
    # NOTE(review): presumably LoRA hyper-parameters read by the training
    # script; their consumers are not visible here, confirm before relying
    # on the names.
    lora_enable: bool = False
    lora_r: int = 64
    lora_alpha: int = 16
    lora_dropout: float = 0.05
    lora_weight_path: str = ""
    lora_bias: str = "none"
    # NOTE(review): looks like an optional separate learning rate for the
    # speech projector; None presumably means "use the global one", confirm.
    speech_projector_lr: Optional[float] = None
    group_by_modality_length: bool = field(default=False)
class SeparatorStyle(Enum):
    """Different separator style."""
    TWO = auto()
    PLAIN = auto()
    CHATML = auto()
    LLAMA_2 = auto()
    LLAMA_3 = auto()
    QWEN2 = auto()


@dataclasses.dataclass
class Conversation:
    """A class that keeps all conversation history.

    ``messages`` is a list of ``[role, message]`` pairs. A message is a
    string, ``None`` (an open turn the model is expected to complete) or,
    for most styles, a tuple whose first element is the text.
    """
    system: str
    roles: List[str]
    messages: List[List[str]]
    offset: int
    sep_style: SeparatorStyle = SeparatorStyle.PLAIN
    sep: str = "###"
    sep2: str = None
    version: str = "Unknown"

    tokenizer_id: str = ""
    tokenizer: Any = None
    # Stop criteria (the default one is EOS token)
    stop_str: Union[str, List[str]] = None
    # Stops generation if meeting any token in this list
    stop_token_ids: List[int] = None

    skip_next: bool = False

    def get_prompt(self):
        """Render the stored messages as one prompt string for ``sep_style``.

        Returns:
            The prompt text.

        Raises:
            ValueError: For an unknown separator style, or for a tuple
                message under the CHATML style.
        """
        messages = self.messages

        if self.sep_style == SeparatorStyle.TWO:
            seps = [self.sep, self.sep2]
            ret = self.system + seps[0]
            for i, (role, message) in enumerate(messages):
                if message:
                    if type(message) is tuple:
                        message = message[0]
                    # Turns alternate between the two separators.
                    ret += role + ": " + message + seps[i % 2]
                else:
                    ret += role + ":"
        elif self.sep_style == SeparatorStyle.LLAMA_3:
            wrap_sys = lambda msg: f"<|start_header_id|>system<|end_header_id|>\n\n{msg}<|eot_id|>" if len(msg) > 0 else msg
            ret = "<|begin_of_text|>" + wrap_sys(self.system)
            for i, (role, message) in enumerate(messages):
                if message:
                    if type(message) is tuple:
                        message = message[0]
                    ret += f"<|start_header_id|>{role}<|end_header_id|>\n\n"
                    ret += message.strip() + self.sep2
                else:
                    ret += f"<|start_header_id|>{role}<|end_header_id|>\n\n"
            return ret
        elif self.sep_style == SeparatorStyle.LLAMA_2:
            # NOTE(review): the "<>" markers look truncated (Llama-2 style
            # templates normally wrap the system prompt in SYS tags);
            # reproduced as found, confirm against the upstream file.
            wrap_sys = lambda msg: f"<>\n{msg}\n<>\n\n" if len(msg) > 0 else msg
            wrap_inst = lambda msg: f"[INST] {msg} [/INST]"
            ret = ""

            for i, (role, message) in enumerate(messages):
                if i == 0:
                    assert message, "first message should not be none"
                    assert role == self.roles[0], "first message should come from user"
                if message:
                    if type(message) is tuple:
                        message, _, _ = message
                    if i == 0:
                        message = wrap_sys(self.system) + message
                    if i % 2 == 0:
                        message = wrap_inst(message)
                        ret += self.sep + message
                    else:
                        ret += " " + message + " " + self.sep2
                else:
                    ret += ""
            ret = ret.lstrip(self.sep)
        elif self.sep_style == SeparatorStyle.PLAIN:
            seps = [self.sep, self.sep2]
            ret = self.system
            for i, (role, message) in enumerate(messages):
                if message:
                    if type(message) is tuple:
                        message, _, _ = message
                    ret += message + seps[i % 2]
                else:
                    ret += ""

        elif self.sep_style == SeparatorStyle.CHATML:
            ret = "" if self.system == "" else self.system + self.sep + "\n"
            for role, message in messages:
                if message:
                    if type(message) is tuple:
                        # The statements that used to follow this raise were
                        # unreachable and have been removed.
                        raise ValueError("Tuple not supported in CHATML")
                    ret += role + "\n" + message + self.sep + "\n"
                else:
                    ret += role + "\n"
            return ret
        elif self.sep_style == SeparatorStyle.QWEN2:
            start = '<|im_start|>'
            end = '<|im_end|>\n'
            ret = start + 'system\n' + self.system + end
            for i, (role, message) in enumerate(messages):
                if message:
                    if type(message) is tuple:
                        message, _, _ = message

                    if message.endswith('<|endoftext|>'):
                        # Move the end-of-text marker behind the closing tag.
                        message = message.replace('<|endoftext|>', '')
                        ret += start + role + "\n" + message + end + '<|endoftext|>'
                    else:
                        assert not '<|endoftext|>' in message, f"Invalid message: {message}"
                        ret += start + role + "\n" + message + end
                else:
                    ret += start + role + "\n"
        else:
            raise ValueError(f"Invalid style: {self.sep_style}")

        return ret

    def append_message(self, role, message):
        """Append a ``[role, message]`` pair to the history."""
        self.messages.append([role, message])

    def to_gradio_chatbot(self):
        """Return the history after ``offset`` as ``[user_text, reply]`` pairs.

        Even positions open a new pair, odd positions fill in its reply.
        A tuple user message must have exactly two items; only the first
        (the text) is kept.
        """
        ret = []
        for i, (role, msg) in enumerate(self.messages[self.offset:]):
            if i % 2 == 0:
                if type(msg) is tuple:
                    msg, speech = msg
                ret.append([msg, None])
            else:
                ret[-1][-1] = msg
        return ret

    def copy(self):
        """Return a new ``Conversation`` with a fresh list of message pairs.

        Only the fields passed below are carried over; the remaining fields
        fall back to their dataclass defaults.
        """
        return Conversation(
            system=self.system,
            roles=self.roles,
            messages=[[x, y] for x, y in self.messages],
            offset=self.offset,
            sep_style=self.sep_style,
            sep=self.sep,
            sep2=self.sep2,
            version=self.version)

    def dict(self):
        """Return a plain-dict snapshot of the conversation.

        Tuple messages are reduced to their first element. The previous
        implementation called ``self.get_images()``, which this class does
        not define, so every call raised ``AttributeError``.
        """
        return {
            "system": self.system,
            "roles": self.roles,
            "messages": [[x, y[0] if type(y) is tuple else y] for x, y in self.messages],
            "offset": self.offset,
            "sep": self.sep,
            "sep2": self.sep2,
        }
" "The assistant gives helpful, detailed, and polite answers to the user's questions.", 180 | roles=("USER", "ASSISTANT"), 181 | version="v1", 182 | messages=[], 183 | offset=0, 184 | sep_style=SeparatorStyle.TWO, 185 | sep=" ", 186 | sep2="", 187 | ) 188 | 189 | conv_llama_2 = Conversation( 190 | system="You are a helpful language and speech assistant. " "You are able to understand the speech content that the user provides, " "and assist the user with a variety of tasks using natural language.", 191 | roles=("USER", "ASSISTANT"), 192 | version="llama_v2", 193 | messages=[], 194 | offset=0, 195 | sep_style=SeparatorStyle.LLAMA_2, 196 | sep="", 197 | sep2="", 198 | ) 199 | 200 | conv_llama_3 = Conversation( 201 | system="You are a helpful language and speech assistant. " "You are able to understand the speech content that the user provides, " "and assist the user with a variety of tasks using natural language.", 202 | roles=("user", "assistant"), 203 | version="llama_v3", 204 | messages=[], 205 | offset=0, 206 | sep_style=SeparatorStyle.LLAMA_3, 207 | sep="", 208 | sep2="<|eot_id|>" 209 | ) 210 | 211 | 212 | conv_qwen_v1 = Conversation( 213 | system="You are a helpful assistant.", 214 | roles=("user", "assistant"), 215 | version="v1", 216 | messages=(), 217 | offset=0, 218 | sep_style=SeparatorStyle.QWEN2, 219 | ) 220 | 221 | conv_plain = Conversation( 222 | system="", 223 | roles=("", ""), 224 | messages=( 225 | ), 226 | offset=0, 227 | sep_style=SeparatorStyle.PLAIN, 228 | sep="", 229 | ) 230 | 231 | conv_qwen = Conversation( 232 | system="""<|im_start|>system 233 | You are a helpful assistant.""", 234 | roles=("<|im_start|>user", "<|im_start|>assistant"), 235 | version="qwen", 236 | messages=[], 237 | offset=0, 238 | sep_style=SeparatorStyle.CHATML, 239 | sep="<|im_end|>", 240 | ) 241 | 242 | default_conversation = conv_llama_3 243 | conv_templates = { 244 | "v1": conv_vicuna_v1, 245 | "plain": conv_plain, 246 | "llama_2": conv_llama_2, 247 | "llama_3": 
conv_llama_3, 248 | 'v1_qwen2': conv_qwen_v1, 249 | "qwen_1_5": conv_qwen, 250 | } 251 | 252 | 253 | if __name__ == "__main__": 254 | print(default_conversation.get_prompt()) 255 | -------------------------------------------------------------------------------- /ola/datasets/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Ola-Omni/Ola/666df3b345a252b218b25b84b2ec4c0e0b631b83/ola/datasets/__init__.py -------------------------------------------------------------------------------- /ola/datasets/preprocess.py: -------------------------------------------------------------------------------- 1 | import copy 2 | import torch 3 | import transformers 4 | import tokenizers 5 | 6 | from typing import Dict, Sequence 7 | 8 | from ola.constants import IGNORE_INDEX, DEFAULT_SPEECH_TOKEN, IMAGE_TOKEN_INDEX 9 | from ola import conversation as conversation_lib 10 | from ola.model import * 11 | from ola.arguments import DataArguments 12 | from ola.constants import SPEECH_TOKEN_INDEX 13 | 14 | from packaging import version 15 | 16 | IS_TOKENIZER_GREATER_THAN_0_14 = version.parse(tokenizers.__version__) >= version.parse('0.14') 17 | 18 | 19 | def tokenizer_speech_token(prompt, tokenizer, speech_token_index=SPEECH_TOKEN_INDEX, return_tensors=None): 20 | prompt_chunks = [tokenizer(chunk).input_ids for chunk in prompt.split('')] 21 | 22 | def insert_separator(X, sep): 23 | return [ele for sublist in zip(X, [sep]*len(X)) for ele in sublist][:-1] 24 | 25 | input_ids = [] 26 | offset = 0 27 | if len(prompt_chunks) > 0 and len(prompt_chunks[0]) > 0 and prompt_chunks[0][0] == tokenizer.bos_token_id: 28 | offset = 1 29 | input_ids.append(prompt_chunks[0][0]) 30 | 31 | for x in insert_separator(prompt_chunks, [speech_token_index] * (offset + 1)): 32 | input_ids.extend(x[offset:]) 33 | 34 | if return_tensors is not None: 35 | if return_tensors == 'pt': 36 | return torch.tensor(input_ids, dtype=torch.long) 37 | raise 
def tokenizer_speech_question_image_token(prompt, tokenizer, image_token_index=IMAGE_TOKEN_INDEX, speech_token_idx=SPEECH_TOKEN_INDEX, return_tensors=None):
    """Tokenize a prompt that asks a spoken question about an image.

    The prompt is split on the "image + spoken question" marker, each piece
    is tokenized on its own, and between consecutive pieces the sequence
    ``image placeholder, "\\n", "User's question in speech: ",
    speech placeholder, "\\n"`` is inserted as token ids.

    The previous implementation built that separator as a list mixing ints
    with token-id *lists* and wrapped it in another list, so a nested,
    ragged list ended up inside ``input_ids`` and ``torch.tensor`` failed.
    The separator is now flattened before it is inserted.

    Args:
        prompt: Full prompt text.
        tokenizer: Callable returning an object with ``input_ids``; its
            ``bos_token_id`` is used to detect a leading BOS token.
        image_token_index: Id inserted for the image placeholder.
        speech_token_idx: Id inserted for the speech placeholder.
        return_tensors: ``None`` for a Python list, ``'pt'`` for a
            ``torch.long`` tensor.

    Returns:
        The token ids as a list or 1-D tensor.

    Raises:
        ValueError: If ``return_tensors`` is neither ``None`` nor ``'pt'``.
    """
    # NOTE(review): as written the marker contains no placeholder tokens,
    # yet the caller in inference/infer.py builds the prompt as
    # DEFAULT_IMAGE_TOKEN + "\n" + "User's question in speech: " +
    # DEFAULT_SPEECH_TOKEN + "\n". The literal presumably lost its two
    # placeholder strings; it is reproduced unchanged here, confirm against
    # the upstream file.
    marker = "\nUser's question in speech: \n"
    prompt_chunks = [tokenizer(chunk).input_ids for chunk in prompt.split(marker)]

    input_ids = []
    offset = 0
    if len(prompt_chunks) > 0 and len(prompt_chunks[0]) > 0 and prompt_chunks[0][0] == tokenizer.bos_token_id:
        # The tokenizer prefixes BOS: keep it once and drop it from every
        # individually tokenized piece below.
        offset = 1
        input_ids.append(prompt_chunks[0][0])

    nl_tokens = tokenizer("\n").input_ids[offset:]
    question_tokens = tokenizer("User's question in speech: ").input_ids[offset:]
    special_tokens = [image_token_index, *nl_tokens, *question_tokens, speech_token_idx, *nl_tokens]

    for position, chunk in enumerate(prompt_chunks):
        if position > 0:
            input_ids.extend(special_tokens)
        input_ids.extend(chunk[offset:])

    if return_tensors is not None:
        if return_tensors == 'pt':
            return torch.tensor(input_ids, dtype=torch.long)
        raise ValueError(f'Unsupported tensor type: {return_tensors}')
    return input_ids
= tokenizer( 134 | conversations, 135 | return_tensors="pt", 136 | padding="longest", 137 | max_length=tokenizer.model_max_length, 138 | truncation=True, 139 | ).input_ids 140 | 141 | targets = input_ids.clone() 142 | 143 | assert conv.sep_style == conversation_lib.SeparatorStyle.TWO 144 | 145 | # Mask targets 146 | sep = conv.sep + conv.roles[1] + ": " 147 | for conversation, target in zip(conversations, targets): 148 | total_len = int(target.ne(tokenizer.pad_token_id).sum()) 149 | 150 | rounds = conversation.split(conv.sep2) 151 | cur_len = 1 152 | target[:cur_len] = IGNORE_INDEX 153 | for i, rou in enumerate(rounds): 154 | if rou == "": 155 | break 156 | 157 | parts = rou.split(sep) 158 | if len(parts) != 2: 159 | break 160 | parts[0] += sep 161 | 162 | if has_speech: 163 | round_len = len(tokenizer_speech_token(rou, tokenizer)) 164 | instruction_len = len(tokenizer_speech_token(parts[0], tokenizer)) - 2 165 | else: 166 | round_len = len(tokenizer(rou).input_ids) 167 | instruction_len = len(tokenizer(parts[0]).input_ids) - 2 168 | 169 | # FIXME: tokenizer bug 170 | if i != 0 and not tokenizer.legacy and IS_TOKENIZER_GREATER_THAN_0_14: 171 | round_len -= 1 172 | instruction_len -= 1 173 | 174 | target[cur_len : cur_len + instruction_len] = IGNORE_INDEX 175 | 176 | cur_len += round_len 177 | target[cur_len:] = IGNORE_INDEX 178 | 179 | if cur_len < tokenizer.model_max_length: 180 | if cur_len != total_len: 181 | target[:] = IGNORE_INDEX 182 | print( 183 | f"WARNING: tokenization mismatch: {cur_len} vs. {total_len}." 
def preprocess_plain(
    sources: Sequence[str],
    tokenizer: transformers.PreTrainedTokenizer,
) -> Dict:
    """Build training tensors for the plain (separator-only) format.

    Every source must hold exactly two turns. The first turn's ``value``
    is overwritten in place with the bare speech placeholder, the second
    turn's ``value`` and the default conversation's separator are appended,
    and the result is tokenized with ``tokenizer_speech_token``. In the
    labels, the positions covered by the first turn are set to
    ``IGNORE_INDEX``.

    Returns:
        A dict with ``input_ids`` and ``labels``, each a list with one
        tensor per source.
    """
    end_signal = conversation_lib.default_conversation.sep

    # add end signal and concatenate together
    conversations = []
    for source in sources:
        assert len(source) == 2
        question, answer = source[0], source[1]
        assert DEFAULT_SPEECH_TOKEN in question['value']
        # NOTE: mutates the caller's data, exactly as before.
        question['value'] = DEFAULT_SPEECH_TOKEN
        conversations.append(question['value'] + answer['value'] + end_signal)

    # tokenize conversations
    input_ids = [
        tokenizer_speech_token(text, tokenizer, return_tensors='pt')
        for text in conversations
    ]
    targets = copy.deepcopy(input_ids)
    for labels, source in zip(targets, sources):
        prefix_len = len(tokenizer_speech_token(source[0]['value'], tokenizer))
        labels[:prefix_len] = IGNORE_INDEX

    return dict(input_ids=input_ids, labels=targets)
# Resolution settings are read once, at import time, from environment
# variables. Values written as "<base>x<patch>" are split into an integer
# base size and an integer patch size.

if 'VIDEO_RESIZE' in os.environ:
    # highresxpatch
    VIDEO_RESIZE = os.environ['VIDEO_RESIZE']
    video_base, video_ps = VIDEO_RESIZE.split('x')
    video_base = int(video_base)
    video_ps = int(video_ps)
    print(f"VIDEO_RESIZE is set as {VIDEO_RESIZE}, {video_base}, {video_ps}")
else:
    # This branch used to assign HIGHRES_BASE, leaving VIDEO_RESIZE undefined
    # so process_anyres_video() died with NameError instead of reaching its
    # own "VIDEO_RESIZE is not set" error.
    VIDEO_RESIZE = None

if 'HIGHRES_BASE' in os.environ:
    # highresxpatch
    HIGHRES_BASE = os.environ['HIGHRES_BASE']
    highres_base, highres_ps = HIGHRES_BASE.split('x')
    highres_base = int(highres_base)
    highres_ps = int(highres_ps)
    print(f"HIGHRES_BASE is set as {HIGHRES_BASE}, {highres_base}, {highres_ps}")
else:
    HIGHRES_BASE = None

if 'MAXRES' in os.environ:
    MAXRES = int(os.environ['MAXRES'])
    print(f"MAXRES is set as {MAXRES}")
else:
    MAXRES = 1536

if 'MINRES' in os.environ:
    MINRES = int(os.environ['MINRES'])
    print(f"MINRES is set as {MINRES}")
else:
    MINRES = 0

if 'VIDEO_MAXRES' in os.environ:
    VIDEO_MAXRES = int(os.environ['VIDEO_MAXRES'])
    print(f"VIDEO_MAXRES is set as {VIDEO_MAXRES}")
else:
    VIDEO_MAXRES = 1536

if 'VIDEO_MINRES' in os.environ:
    VIDEO_MINRES = int(os.environ['VIDEO_MINRES'])
    print(f"VIDEO_MINRES is set as {VIDEO_MINRES}")
else:
    # This branch used to assign MINRES = 0, which both left VIDEO_MINRES
    # undefined and silently discarded a MINRES configured above.
    VIDEO_MINRES = 0

if 'PAD2STRIDE' in os.environ:
    # Presence of the variable is enough; its value is not inspected.
    PAD2STRIDE = True
    print("PAD2STRIDE is set")
else:
    PAD2STRIDE = False

if 'LOWRES_RESIZE' in os.environ:
    LOWRES_RESIZE = os.environ['LOWRES_RESIZE']
    print(f"LOWRES_RESIZE is set as {LOWRES_RESIZE}")
    if 'x' in LOWRES_RESIZE:
        # "<size>x<patch>" becomes a (size, patch) tuple ...
        size, ps = LOWRES_RESIZE.split('x')
        size = int(size)
        ps = int(ps)
        LOWRES_RESIZE = (size, ps)
    else:
        # ... a bare number stays a single integer.
        LOWRES_RESIZE = int(LOWRES_RESIZE)
else:
    LOWRES_RESIZE = None
def resize_images(image, patch_size=14, base_size=896):
    """Resize (or pad) an image so both sides are multiples of ``patch_size``.

    With a non-zero ``base_size`` the image is rescaled so that its area is
    roughly ``base_size ** 2``. With ``base_size == 0`` it is rescaled only
    when its area lies outside ``[MINRES ** 2, MAXRES ** 2]``; otherwise it
    is padded up to the next patch multiple when ``PAD2STRIDE`` is set, or
    resized down to the previous one when it is not. Resized sides are
    never smaller than one patch.

    Args:
        image: Object exposing ``size`` and ``resize`` (a PIL image in the
            callers visible in this module).
        patch_size: Side length both output dimensions must be divisible by.
        base_size: Target side length of the equivalent square, or 0 to
            keep the input size whenever it is within the allowed range.

    Returns:
        The resized or padded image.
    """
    side_a, side_b = image.size
    area = side_a * side_b

    if base_size != 0:
        scale = math.sqrt(base_size * base_size / area)
    elif area > MAXRES * MAXRES:
        scale = math.sqrt(MAXRES * MAXRES / area)
    elif area < MINRES * MINRES:
        scale = math.sqrt(MINRES * MINRES / area)
    else:
        scale = None

    if scale is None and PAD2STRIDE:
        # Keep the pixels untouched and pad each side up to a patch multiple.
        padded_a = side_a if side_a % patch_size == 0 else (side_a // patch_size + 1) * patch_size
        padded_b = side_b if side_b % patch_size == 0 else (side_b // patch_size + 1) * patch_size
        return pad_image(image, (padded_a, padded_b), value=127)

    if scale is None:
        scale = 1.0

    # Snap down to a patch multiple, but never below a single patch.
    snapped_a = max(int(side_a * scale / patch_size) * patch_size, patch_size)
    snapped_b = max(int(side_b * scale / patch_size) * patch_size, patch_size)
    return image.resize((snapped_a, snapped_b))
def process_anyres_highres_image(image, processor):
    """
    Build the low-resolution patch stack and the high-resolution tensor for one image.

    Args:
        image: Input image (must support ``.resize``; PIL image presumably).
        processor: Either a single image processor, or a tuple
            ``(processor, processor2)``. When a tuple is given, ``processor2``
            produces the low-resolution patches and ``processor`` produces the
            high-resolution tensor.

    Returns:
        tuple: ``(patches, highres)`` where ``patches`` is the low-resolution
        views stacked along dim 0 and ``highres`` is the preprocessed full
        image with a leading dimension of 1.
    """
    processor2 = None
    if isinstance(processor, tuple):
        processor, processor2 = processor[0], processor[1]

    if HIGHRES_BASE is not None:
        image = resize_images(image, patch_size=highres_ps, base_size=highres_base)

    if processor2 is not None:
        # The low-res view is consumed by processor2 only, so both sides must
        # come from processor2 (the original mixed in processor's edge length).
        edge = processor2.size['shortest_edge']
        image_original_resize = image.resize((edge, edge))
        # The same view is deliberately listed twice (two-patch layout).
        image_patches = [image_original_resize] + [image_original_resize]
        image_patches = [processor2.preprocess(image_patch, return_tensors='pt')['pixel_values'][0]
                         for image_patch in image_patches]
    else:
        if LOWRES_RESIZE is not None:
            if isinstance(LOWRES_RESIZE, int):
                image_original_resize = resize_images(image, patch_size=14, base_size=LOWRES_RESIZE)
            else:
                # LOWRES_RESIZE was given as "<size>x<patch>" and parsed to (size, patch).
                image_original_resize = resize_images(image, patch_size=LOWRES_RESIZE[1], base_size=LOWRES_RESIZE[0])
        else:
            image_original_resize = image.resize((336, 336))
        image_patches = [image_original_resize]
        image_patches = [processor.preprocess(image_patch, return_tensors='pt')['pixel_values'][0]
                         for image_patch in image_patches]

    image_padded = processor.preprocess(image, return_tensors='pt')['pixel_values'][0]
    return torch.stack(image_patches, dim=0), image_padded.unsqueeze(0)
class KeywordsStoppingCriteria(StoppingCriteria):
    """
    Stop generation once any of the given keywords appears at the end of the output.

    A keyword matches either token-wise (the output ends with the keyword's
    token ids) or textually (the keyword occurs in the decoded tail).

    Args:
        keywords: Iterable of stop strings.
        tokenizer: Tokenizer used to encode the keywords and decode the output.
        input_ids: Prompt ids; only ``input_ids.shape[1]`` is used, as the
            length at which generated tokens are assumed to start.
    """

    def __init__(self, keywords, tokenizer, input_ids):
        self.keywords = keywords
        self.keyword_ids = []
        for keyword in keywords:
            cur_keyword_ids = tokenizer(keyword).input_ids
            # Drop a leading BOS token the tokenizer may have prepended.
            if len(cur_keyword_ids) > 1 and cur_keyword_ids[0] == tokenizer.bos_token_id:
                cur_keyword_ids = cur_keyword_ids[1:]
            self.keyword_ids.append(torch.tensor(cur_keyword_ids))
        self.tokenizer = tokenizer
        self.start_len = input_ids.shape[1]

    def __call__(self, output_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
        """Return True when generation should stop for the (single) sequence."""
        assert output_ids.shape[0] == 1, "Only support batch size 1 (yet)"  # TODO
        # NOTE(review): when generation is driven by inputs_embeds, output_ids may
        # hold only new tokens, making this offset non-positive - confirm intent.
        offset = min(output_ids.shape[1] - self.start_len, 3)
        self.keyword_ids = [keyword_id.to(output_ids.device) for keyword_id in self.keyword_ids]
        for keyword_id in self.keyword_ids:
            tail = output_ids[0, -keyword_id.shape[0]:]
            # torch.equal yields a plain bool and is False on shape mismatch;
            # `tail == keyword_id` raised for any multi-token keyword.
            if torch.equal(tail, keyword_id):
                return True
        outputs = self.tokenizer.batch_decode(output_ids[:, -offset:], skip_special_tokens=True)[0]
        return any(keyword in outputs for keyword in self.keywords)
def load_pretrained_model(model_path, model_base, is_lora=False, s2s=False, load_8bit=False, load_4bit=False, device="cuda", use_flash_attn=False, **kwargs):
    """
    Load an Ola checkpoint with its tokenizer, speech encoder and vision tower.

    Args:
        model_path: Directory of the checkpoint (or of the LoRA / projector weights).
        model_base: Base model directory; required when ``is_lora`` is True and
            used as the weight source when given without LoRA.
        is_lora: Load ``model_base`` and merge the LoRA weights from ``model_path``.
        s2s: Accepted for interface compatibility; not used by this function.
        load_8bit / load_4bit: Request quantized loading; otherwise bfloat16 is used.
        device: Target device for the model, speech encoder and vision tower.
            ``"auto"`` places the vision tower on ``"cuda:0"``.
        use_flash_attn: Request the ``flash_attention_2`` attention implementation.
        **kwargs: Forwarded to ``from_pretrained``.

    Returns:
        tuple: ``(tokenizer, model, image_processor, context_len)``.
    """
    if load_8bit:
        kwargs['load_in_8bit'] = True
    elif load_4bit:
        kwargs['load_in_4bit'] = True
        kwargs['quantization_config'] = BitsAndBytesConfig(
            load_in_4bit=True,
            bnb_4bit_compute_dtype=torch.float16,
            bnb_4bit_use_double_quant=True,
            bnb_4bit_quant_type='nf4'
        )
    else:
        kwargs['torch_dtype'] = torch.bfloat16

    if use_flash_attn:
        kwargs['attn_implementation'] = 'flash_attention_2'

    model_cls = OlaQwenForCausalLM

    # Load Ola model
    if is_lora:
        assert model_base is not None, "model_base is required for LoRA models."
        from ola.model.language_model.ola_qwen import OlaConfigQwen
        lora_cfg_pretrained = OlaConfigQwen.from_pretrained(model_path)
        tokenizer = AutoTokenizer.from_pretrained(model_base, use_fast=False)
        print('Loading Ola from base model...')
        model = model_cls.from_pretrained(model_base, low_cpu_mem_usage=False, config=lora_cfg_pretrained, **kwargs)
        print('Loading additional Ola weights...')
        non_lora_path = os.path.join(model_path, 'non_lora_trainables.bin')
        if os.path.exists(non_lora_path):
            # SECURITY: torch.load unpickles the file; only load trusted checkpoints.
            non_lora_trainables = torch.load(non_lora_path, map_location='cpu')
            # Strip the PEFT wrapper prefixes so keys match the bare model.
            non_lora_trainables = {(k[11:] if k.startswith('base_model.') else k): v for k, v in non_lora_trainables.items()}
            if any(k.startswith('model.model.') for k in non_lora_trainables):
                non_lora_trainables = {(k[6:] if k.startswith('model.') else k): v for k, v in non_lora_trainables.items()}
            model.load_state_dict(non_lora_trainables, strict=False)

        from peft import PeftModel
        print('Loading LoRA weights...')
        model = PeftModel.from_pretrained(model, model_path)
        print('Merging LoRA weights...')
        model = model.merge_and_unload()
        print('Model is loaded...')
        # NOTE(review): unlike the other branches, the merged model is not moved
        # to `device` here - confirm whether callers rely on a device_map kwarg.
    elif model_base is not None:
        print('Loading Ola from base model...')
        tokenizer = AutoTokenizer.from_pretrained(model_base, use_fast=False)
        cfg_pretrained = AutoConfig.from_pretrained(model_path)
        model = model_cls.from_pretrained(model_base, low_cpu_mem_usage=False, config=cfg_pretrained, **kwargs)

        # SECURITY: torch.load unpickles the file; only load trusted checkpoints.
        speech_projector_weights = torch.load(os.path.join(model_path, 'speech_projector.bin'), map_location='cpu')
        speech_projector_weights = {k: v.to(torch.float16) for k, v in speech_projector_weights.items()}
        model.load_state_dict(speech_projector_weights, strict=False)
        model = model.to(device=device)
    else:
        tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=False)
        model = model_cls.from_pretrained(
            model_path,
            low_cpu_mem_usage=False,
            **kwargs
        )
        model = model.to(device=device)

    # The speech encoder is rebuilt from the config rather than taken from the checkpoint.
    model.get_model().speech_encoder = build_speech_encoder(model.config)
    model.get_model().speech_encoder.to(device=device, dtype=torch.float16)

    model.resize_token_embeddings(len(tokenizer))
    vision_tower = model.get_vision_tower()
    print("Loading vision tower...")
    if not vision_tower.is_loaded:
        vision_tower.load_model(device_map=device)
    # Honour the requested device (previously hard-coded to "cuda");
    # "auto" is not a concrete device, so fall back to the first GPU.
    vision_device = device if device != "auto" else "cuda:0"
    vision_tower.to(device=vision_device, dtype=torch.bfloat16)
    image_processor = vision_tower.image_processor
    print("Loading vision tower succeeded.")

    if hasattr(model.config, "max_sequence_length"):
        context_len = model.config.max_sequence_length
    else:
        context_len = 16384

    return tokenizer, model, image_processor, context_len
class OlaQwenForCausalLM(Qwen2ForCausalLM, OlaMetaForCausalLM):
    """Qwen2 causal LM whose inputs may interleave text with speech and vision features."""

    config_class = OlaConfigQwen

    def __init__(self, config):
        # Deliberately skips Qwen2ForCausalLM.__init__ so the backbone can be
        # replaced by OlaQwenModel below.
        super(Qwen2ForCausalLM, self).__init__(config)

        config.rope_scaling = None
        self.model = OlaQwenModel(config)
        self.vocab_size = config.vocab_size
        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)

        # Initialize weights and apply final processing
        self.post_init()

    def get_model(self):
        """Return the underlying ``OlaQwenModel`` backbone."""
        return self.model

    def forward(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        speech: Optional[torch.FloatTensor] = None,
        speech_lengths: Optional[torch.LongTensor] = None,
        speech_chunks: Optional[torch.LongTensor] = None,
        speech_wav: Optional[torch.FloatTensor] = None,
        images: Optional[torch.FloatTensor] = None,
        images_highres: Optional[List[torch.FloatTensor]] = None,
        image_sizes: Optional[List[List[int]]] = None,
        modalities: Optional[List[str]] = None,
        return_dict: Optional[bool] = None,
        cache_position: Optional[torch.LongTensor] = None,
    ) -> Union[Tuple, CausalLMOutputWithPast]:
        """
        Run the model; multimodal inputs are embedded first when ``inputs_embeds`` is None.

        ``modalities`` defaults to ``["image"]`` when omitted. With ``labels``
        the memory-saving loss path is used, otherwise the stock Qwen2 forward.
        NOTE(review): ``cache_position`` is accepted but not forwarded.
        """
        if modalities is None:
            # Avoid a shared mutable default in the signature.
            modalities = ["image"]

        if inputs_embeds is None:
            (
                input_ids,
                position_ids,
                attention_mask,
                past_key_values,
                inputs_embeds,
                labels
            ) = self.prepare_inputs_labels_for_speech_vision_text(
                input_ids,
                position_ids,
                attention_mask,
                past_key_values,
                labels,
                speech,
                speech_lengths,
                speech_chunks,
                speech_wav,
                images,
                modalities,
                image_sizes,
                images_highres
            )

        if labels is None:
            return super().forward(
                input_ids=input_ids,
                attention_mask=attention_mask,
                position_ids=position_ids,
                past_key_values=past_key_values,
                inputs_embeds=inputs_embeds,
                use_cache=use_cache,
                output_attentions=output_attentions,
                output_hidden_states=output_hidden_states,
                return_dict=return_dict
            )
        else:
            return self.forward_llm_efficient(
                input_ids=input_ids,
                attention_mask=attention_mask,
                position_ids=position_ids,
                past_key_values=past_key_values,
                inputs_embeds=inputs_embeds,
                labels=labels,
                use_cache=use_cache,
                output_attentions=output_attentions,
                output_hidden_states=output_hidden_states,
                return_dict=return_dict
            )

    def forward_llm_efficient(self, input_ids, attention_mask, position_ids, past_key_values, inputs_embeds, labels, use_cache, output_attentions, output_hidden_states, return_dict):
        """
        Compute the LM loss while projecting only supervised positions to the vocabulary.

        Positions whose shifted label is negative are dropped before ``lm_head``,
        so the returned ``logits`` cover only the supervised tokens (flattened).
        """
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
        outputs = self.model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        hidden_states = outputs[0]
        hidden_dim = hidden_states.size(-1)
        # Standard next-token shift: position t predicts label t + 1.
        shift_labels = labels[..., 1:].contiguous().reshape(-1)
        shift_hidden_states = hidden_states[..., :-1, :].contiguous().reshape(-1, hidden_dim)
        assert shift_labels.size(0) == shift_hidden_states.size(0)
        # Keep only supervised positions (ignored labels are negative).
        mask = shift_labels > -1
        assert mask.float().sum() > 0
        shift_labels = shift_labels[mask]
        shift_hidden_states = shift_hidden_states[mask, :]
        logits = self.lm_head(shift_hidden_states)
        logits = logits.float()
        loss_fct = nn.CrossEntropyLoss()
        loss = loss_fct(logits, shift_labels)

        if not return_dict:
            # The loss is always computed on this path.
            return (loss,) + (logits,) + outputs[1:]

        return CausalLMOutputWithPast(
            loss=loss,
            logits=logits,
            past_key_values=outputs.past_key_values,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )

    @torch.no_grad()
    def generate(
        self,
        inputs: Optional[torch.Tensor] = None,
        speech: Optional[torch.Tensor] = None,
        speech_lengths: Optional[torch.Tensor] = None,
        speech_chunks: Optional[torch.Tensor] = None,
        speech_wav: Optional[torch.FloatTensor] = None,
        images: Optional[torch.Tensor] = None,
        images_highres: Optional[List[torch.FloatTensor]] = None,
        image_sizes: Optional[torch.Tensor] = None,
        modalities: Optional[List[str]] = None,
        **kwargs,
    ) -> Union[GenerateOutput, torch.LongTensor]:
        """
        Embed the multimodal prompt, then delegate to the stock ``generate``.

        ``modalities`` defaults to ``["image"]`` when omitted.

        Raises:
            NotImplementedError: If ``inputs_embeds`` is passed by the caller.
        """
        if modalities is None:
            # Avoid a shared mutable default in the signature.
            modalities = ["image"]

        position_ids = kwargs.pop("position_ids", None)
        attention_mask = kwargs.pop("attention_mask", None)
        if "inputs_embeds" in kwargs:
            raise NotImplementedError("`inputs_embeds` is not supported")

        (
            inputs,
            position_ids,
            attention_mask,
            _,
            inputs_embeds,
            _
        ) = self.prepare_inputs_labels_for_speech_vision_text(
            inputs,
            position_ids,
            attention_mask,
            None,
            None,
            speech,
            speech_lengths,
            speech_chunks,
            speech_wav,
            images,
            modalities,
            image_sizes,
            images_highres
        )

        return super().generate(
            position_ids=position_ids,
            attention_mask=attention_mask,
            inputs_embeds=inputs_embeds,
            **kwargs
        )

    def prepare_inputs_for_generation(self, input_ids, past_key_values=None,
                                      inputs_embeds=None, **kwargs):
        """Extend the stock per-step inputs with any speech / image kwargs supplied."""
        speech = kwargs.pop("speech", None)
        speech_lengths = kwargs.pop("speech_lengths", None)
        speech_chunks = kwargs.pop("speech_chunks", None)
        images = kwargs.pop("images", None)
        image_sizes = kwargs.pop("image_sizes", None)
        inputs = super().prepare_inputs_for_generation(
            input_ids, past_key_values=past_key_values, inputs_embeds=inputs_embeds, **kwargs
        )
        if speech is not None:
            inputs['speech'] = speech
            inputs['speech_lengths'] = speech_lengths
            inputs['speech_chunks'] = speech_chunks
        if images is not None:
            inputs["images"] = images
        if image_sizes is not None:
            inputs["image_sizes"] = image_sizes
        return inputs
class OlaMLP(nn.Module):
    """
    Vision projector: linear -> 2x pooling -> GELU -> linear, plus learned
    layout tokens (row newline, begin, end and, with ``twoview``, a separator).

    Args:
        in_channels (int): Feature size of the incoming vision tokens.
        out_channels (int): Feature size of the projected tokens.
        twoview (bool): Create ``image_sep``, needed whenever a second view
            (``x2``) or the video path is used.
    """

    def __init__(self, in_channels, out_channels, twoview=False):
        super().__init__()

        self.proj1 = nn.Linear(in_channels, out_channels)
        self.proj2 = nn.Linear(out_channels, out_channels)
        self.act = nn.GELU()
        self.pooler = NormalizedDwPooler(out_channels)

        embed_std = 1 / math.sqrt(out_channels)
        self.image_newline = nn.Parameter(
            torch.randn(out_channels) * embed_std
        )
        self.image_begin = nn.Parameter(
            torch.randn(out_channels) * embed_std
        )
        self.image_end = nn.Parameter(
            torch.randn(out_channels) * embed_std
        )

        if twoview:
            self.image_sep = nn.Parameter(
                torch.randn(out_channels) * embed_std
            )

    def _project_rows(self, feats, size, dtype):
        """Project a (batch, h*w, dim) grid and append a newline token to each row."""
        rows, cols = size
        feats = feats.reshape(feats.shape[0], rows, cols, -1)
        feats = self.proj1(feats)
        feats = self.pooler(feats, forward_type='2x')
        feats = self.act(feats)
        feats = self.proj2(feats)

        batch, out_rows, _, channels = feats.shape
        newline = self.image_newline.reshape(1, 1, 1, channels).expand(batch, out_rows, 1, channels).to(dtype)
        feats = torch.cat([feats, newline], dim=2)
        return feats.reshape(batch, -1, channels)

    def forward(self, x, size=(16,16), x2=None, size2=(16, 16), modalities='image'):
        """
        Project vision tokens and wrap them in layout tokens.

        Args:
            x: Tokens of the primary view, reshaped to ``size`` = (h, w).
            size: Grid shape of ``x``.
            x2: Optional second view (image/text) or the frame features (video).
            size2: Grid shape of ``x2``.
            modalities (str): ``'image'``, ``'text'`` or ``'video'``.

        Returns:
            image/text: ``(batch, tokens, channels)`` as begin, view 1,
            [separator, view 2], end. video: ``(1, tokens, channels)`` with
            all frames concatenated, each followed by a separator.

        Raises:
            ValueError: For an unknown ``modalities`` value.
        """
        if modalities in ['image', 'text']:
            dtype = x.dtype
            x = self._project_rows(x, size, dtype)
            b, _, c = x.shape

            if x2 is not None:
                # Each view is projected with its own batch size; mismatched
                # batches fail at the concatenation below.
                x2 = self._project_rows(x2, size2, dtype)
                sep = self.image_sep.reshape(1, 1, -1).expand(b, 1, c).to(dtype)
                x = torch.cat([x, sep, x2], dim=1)

            begin = self.image_begin.reshape(1, 1, -1).expand(b, 1, c).to(dtype)
            end = self.image_end.reshape(1, 1, -1).expand(b, 1, c).to(dtype)
            x = torch.cat([begin, x, end], dim=1)
            return x
        elif modalities in ['video']:
            # x2 is the true feature; x only supplies the dtype. The original
            # also projected x and discarded the result (dead computation).
            dtype = x.dtype
            x2 = self._project_rows(x2, size2, dtype)
            b2, _, c = x2.shape

            sep = self.image_sep.reshape(1, 1, -1).expand(b2, 1, c).to(dtype)
            x2 = torch.cat([x2, sep], dim=1)

            # Merge all frames into one token sequence.
            x2 = x2.flatten(0, 1)

            begin = self.image_begin.reshape(1, -1).expand(1, c).to(dtype)
            end = self.image_end.reshape(1, -1).expand(1, c).to(dtype)
            x2 = torch.cat([begin, x2, end], dim=0)
            x2 = x2.unsqueeze(0)
            return x2
        else:
            raise ValueError(f'Unknown modalities: {modalities}')
class NormalizedDwPooler(nn.Module):
    """
    Downsample a (B, H, W, C) grid by a softmax-weighted sum over each window.

    Per-channel weights come from a small MLP applied to every window element
    concatenated with its window mean.

    Args:
        dim (int): Channel size ``C`` of the input features.
    """

    # Supported ``forward_type`` values mapped to the window side length.
    _WINDOW = {'1x': 1, '2x': 2, '4x': 4}

    def __init__(self, dim):
        super().__init__()
        self.dim = dim
        self.predictor = nn.Sequential(
            nn.Linear(dim*2, dim),
            nn.GELU(),
            nn.Linear(dim, dim),
        )

    def forward(self, x, forward_type='2x'):
        """
        Pool ``x`` with the window selected by ``forward_type``.

        Args:
            x: Tensor of shape (B, H, W, C); H and W must be divisible by the
                window side, otherwise the reshape fails.
            forward_type (str): ``'1x'``, ``'2x'`` or ``'4x'``.

        Returns:
            Tensor of shape (B, H // k, W // k, C) for window side ``k``.

        Raises:
            ValueError: If ``forward_type`` is not supported (previously this
                surfaced as an ``UnboundLocalError``).
        """
        try:
            k = self._WINDOW[forward_type]
        except KeyError:
            raise ValueError(f'Unknown forward_type: {forward_type}') from None

        B, H, W, C = x.shape
        # Gather the k*k elements of every window on one axis.
        new_x = x.reshape(B, H//k, k, W//k, k, C).permute(0, 1, 3, 2, 4, 5).reshape(B, H//k, W//k, k*k, C)
        pooled_x = new_x.mean(-2, keepdim=True).expand(-1, -1, -1, k*k, -1)
        fused_x = torch.cat([new_x, pooled_x], dim=-1)

        score = self.predictor(fused_x)
        # Normalize over the window elements so the weights sum to 1 per channel.
        normalized_score = F.softmax(score, dim=-2)
        new_x = (new_x * normalized_score).sum(dim=-2)
        return new_x
class OlaMetaModel:
    """
    Mixin that attaches speech and vision components to a language-model backbone.

    Under FSDP the encoders are stored wrapped in one-element lists; the
    accessors below unwrap them.
    """

    def __init__(self, config):
        super(OlaMetaModel, self).__init__(config)

        if hasattr(config, "speech_encoder"):
            self.speech_encoder = build_speech_encoder(config)
            self.speech_projector = build_speech_projector(config)

        if hasattr(config, "mm_vision_tower"):
            self.vision_tower = build_vision_tower(config, delay_load=True)
            self.vision_resampler = build_vision_resampler(config, vision_tower=self.vision_tower)
            self.mm_projector = build_vision_projector(config, vision_cfg=self.vision_tower.config)

    def get_speech_encoder(self):
        """Return the speech encoder (unwrapped from its FSDP list), or None."""
        speech_encoder = getattr(self, 'speech_encoder', None)
        if type(speech_encoder) is list:
            speech_encoder = speech_encoder[0]
        return speech_encoder

    def get_vision_tower(self):
        """Return the vision tower (unwrapped from its FSDP list), or None."""
        vision_tower = getattr(self, 'vision_tower', None)
        if type(vision_tower) is list:
            vision_tower = vision_tower[0]
        return vision_tower

    def initialize_speech_modules(self, model_args, fsdp=None):
        """
        Copy speech settings into the config and build any missing speech modules.

        Args:
            model_args: Object carrying the speech options; must define
                ``pretrain_speech_projector`` (path or None).
            fsdp: When non-empty, a newly built encoder is stored wrapped in a list.
        """
        self.config.speech_encoder = getattr(model_args, "speech_encoder", None)
        self.config.speech_encoder_type = getattr(model_args, "speech_encoder_type", None)
        self.config.speech_projector_type = getattr(model_args, 'speech_projector_type', 'linear')
        self.config.speech_encoder_ds_rate = getattr(model_args, 'speech_encoder_ds_rate', 5)
        self.config.speech_encoder_hidden_size = getattr(model_args, 'speech_encoder_hidden_size', 1280)
        self.config.music_encoder = getattr(model_args, 'music_encoder', None)

        if self.get_speech_encoder() is None:
            speech_encoder = build_speech_encoder(self.config)
            if fsdp is not None and len(fsdp) > 0:
                self.speech_encoder = [speech_encoder]
            else:
                self.speech_encoder = speech_encoder

        if getattr(self, 'speech_projector', None) is None:
            self.speech_projector = build_speech_projector(self.config)
        else:
            # In case it is frozen by LoRA
            for p in self.speech_projector.parameters():
                p.requires_grad = True

        if model_args.pretrain_speech_projector is not None:
            # SECURITY: torch.load unpickles the file; only load trusted checkpoints.
            pretrain_speech_projector_weights = torch.load(model_args.pretrain_speech_projector, map_location='cpu')

            def get_w(weights, keyword):
                # Keep entries for `keyword`, stripping everything up to "<keyword>.".
                return {k.split(keyword + '.')[1]: v for k, v in weights.items() if keyword in k}
            print('Loading pretrain speech projector weights')

            msg = self.speech_projector.load_state_dict(get_w(pretrain_speech_projector_weights, 'speech_projector'), strict=False)
            print(msg)

    def initialize_vision_modules(self, model_args, fsdp=None):
        """
        Build or reload the vision tower, resampler and projector.

        Args:
            model_args: Object carrying ``vision_tower``, ``mm_vision_select_layer``,
                ``mm_vision_select_feature`` and ``pretrain_mm_mlp_adapter``.
            fsdp: When non-empty, newly built tower/resampler are stored wrapped
                in one-element lists.
        """
        vision_tower = model_args.vision_tower
        mm_vision_select_layer = model_args.mm_vision_select_layer
        mm_vision_select_feature = model_args.mm_vision_select_feature
        pretrain_mm_mlp_adapter = model_args.pretrain_mm_mlp_adapter

        self.config.mm_vision_tower = vision_tower

        if self.get_vision_tower() is None:
            vision_tower = build_vision_tower(model_args)
            vision_resampler = build_vision_resampler(model_args, vision_tower=vision_tower)
            ## Get the mm_spatial_pool_mode and mm_spatial_pool_stride
            for k, v in vision_resampler.config.items():
                setattr(self.config, k, v)

            if fsdp is not None and len(fsdp) > 0:
                self.vision_tower = [vision_tower]
                self.vision_resampler = [vision_resampler]
            else:
                self.vision_tower = vision_tower
                self.vision_resampler = vision_resampler
        else:
            # Unwrap by actual storage type, so modules created in __init__
            # (never wrapped) also work when fsdp is enabled.
            vision_resampler = self.vision_resampler
            if isinstance(vision_resampler, list):
                vision_resampler = vision_resampler[0]
            vision_tower = self.get_vision_tower()
            vision_tower.load_model()

            # In case it is frozen by LoRA. Use the unwrapped module:
            # self.vision_resampler may be a list under FSDP.
            for p in vision_resampler.parameters():
                p.requires_grad = True

        self.config.use_mm_proj = True
        self.config.mm_projector_type = getattr(model_args, 'mm_projector_type', 'linear')
        self.config.mm_hidden_size = getattr(vision_resampler, 'hidden_size', vision_tower.hidden_size)

        self.config.mm_vision_select_layer = mm_vision_select_layer
        self.config.mm_vision_select_feature = mm_vision_select_feature

        if getattr(self, 'mm_projector', None) is None:
            self.mm_projector = build_vision_projector(self.config, vision_cfg=vision_tower.config)
        else:
            for p in self.mm_projector.parameters():
                p.requires_grad = True

        if pretrain_mm_mlp_adapter is not None:
            # SECURITY: torch.load unpickles the file; only load trusted checkpoints.
            mm_projector_weights = torch.load(pretrain_mm_mlp_adapter, map_location='cpu')

            def get_w(weights, keyword):
                # Keep entries for `keyword`, stripping everything up to "<keyword>.".
                return {k.split(keyword + '.')[1]: v for k, v in weights.items() if keyword in k}

            self.mm_projector.load_state_dict(get_w(mm_projector_weights, 'mm_projector'))
            print('Loading pretrain mm projector weights')
            # Again use the unwrapped resampler rather than the FSDP list.
            incompatible_keys = vision_resampler.load_state_dict(get_w(mm_projector_weights, 'vision_resampler'), strict=False)
            print(incompatible_keys)
def prepare_inputs_labels_for_speech_vision_text(
    self, input_ids, position_ids, attention_mask, past_key_values, labels,
    speech, speech_lengths, speech_chunks, speech_wav, images, modalities, image_sizes=None, images_highres=None
):
    """Replace speech/image placeholder tokens with encoded features.

    Encodes the speech and image inputs, embeds the text tokens, and
    splices the encoded features into the embedded sequence at the
    positions of ``SPEECH_TOKEN_INDEX`` / ``IMAGE_TOKEN_INDEX``. The
    per-sample sequences are then truncated, padded to a common length and
    stacked.

    Returns:
        A 6-tuple ``(input_ids, position_ids, attention_mask,
        past_key_values, inputs_embeds, labels)``. On the early-exit path
        (no speech encoder, no vision tower, or ``input_ids.shape[1] == 1``)
        the inputs are returned unchanged with ``inputs_embeds=None``.
        Otherwise ``input_ids`` is ``None`` and ``inputs_embeds`` holds the
        merged embeddings; ``labels``, ``attention_mask`` and
        ``position_ids`` are ``None`` if they were ``None`` on entry.

    Note:
        ``image_sizes`` is accepted but not used in this function.
    """
    speech_encoder = self.get_speech_encoder()
    vision_tower = self.get_vision_tower()

    # Nothing to merge: no encoder available, or a single-token input.
    if speech_encoder is None or input_ids.shape[1] == 1:
        return input_ids, position_ids, attention_mask, past_key_values, None, labels

    if vision_tower is None or input_ids.shape[1] == 1:
        return input_ids, position_ids, attention_mask, past_key_values, None, labels
    # encode speech
    # A non-list input is split along dim 0 into per-entry groups whose
    # sizes are given by speech_chunks.
    if not isinstance(speech, list):
        speech = torch.split(speech, speech_chunks.tolist(), dim=0)
        speech_lengths = torch.split(speech_lengths, speech_chunks.tolist(), dim=0)
        speech_wav = torch.split(speech_wav, speech_chunks.tolist(), dim=0)
    speech_features = []
    for idx in range(len(speech)):
        speech_features.append(self.encode_speech(speech[idx], speech_lengths[idx], speech_wav[idx]))

    # encode vision
    if isinstance(modalities, str):
        modalities = [modalities]

    # Indices of the batch entries whose modality string contains 'video'.
    video_idx_in_batch = []
    for modal in range(len(modalities)):
        if 'video' in modalities[modal]:
            video_idx_in_batch.append(modal)

    # For video entries the low-res input is replaced by an all-zero
    # 1x3x128x128 tensor created from the last image (aimg.new).
    aimg = images[-1]
    lowres_img = []
    for idx, img_feat in enumerate(images):
        if idx in video_idx_in_batch:
            img_feat = aimg.new(1, 3, 128, 128).fill_(0)
        lowres_img.append(img_feat)

    # The vision tower is called once on the whole low-res list and once per
    # high-res entry; each call yields (features, sizes).
    lowres_img_features, lowres_img_sizes = self.get_model().get_vision_tower()(lowres_img)
    highres_img_features = []
    highres_img_sizes = []
    for idx, img_feat in enumerate(images_highres):
        if img_feat.ndim == 5:
            img_feat = img_feat.squeeze(1)
        highres_img_feature, highres_img_size = self.get_model().get_vision_tower()(img_feat)
        highres_img_features.append(highres_img_feature)
        highres_img_sizes.append(highres_img_size)
    image_features = []
    for idx in range(len(modalities)):
        img_feat = self.get_model().mm_projector(lowres_img_features[idx],
                                                lowres_img_sizes[idx],
                                                highres_img_features[idx],
                                                highres_img_sizes[idx],
                                                modalities[idx])
        # Merge the first two dims of the projector output into one.
        image_features.append(img_feat.flatten(0, 1))

    # Remember which optional inputs were None so the same ones can be
    # returned as None at the end.
    _labels = labels
    _position_ids = position_ids
    _attention_mask = attention_mask
    if attention_mask is None:
        attention_mask = torch.ones_like(input_ids, dtype=torch.bool)
    else:
        attention_mask = attention_mask.bool()
    if position_ids is None:
        position_ids = torch.arange(0, input_ids.shape[1], dtype=torch.long, device=input_ids.device)
    if labels is None:
        labels = torch.full_like(input_ids, IGNORE_INDEX)

    # remove the padding using attention_mask -- FIXME
    _input_ids = input_ids  # NOTE(review): not read again in this function
    input_ids = [cur_input_ids[cur_attention_mask] for cur_input_ids, cur_attention_mask in zip(input_ids, attention_mask)]
    labels = [cur_labels[cur_attention_mask] for cur_labels, cur_attention_mask in zip(labels, attention_mask)]

    new_input_embeds = []
    new_labels = []
    # Running cursors into speech_features / image_features across the batch.
    cur_speech_idx = 0
    cur_image_idx = 0
    for batch_idx, cur_input_ids in enumerate(input_ids):

        num_speech = (cur_input_ids == SPEECH_TOKEN_INDEX).sum()
        num_images = (cur_input_ids == IMAGE_TOKEN_INDEX).sum()

        num_speech_images = (cur_input_ids == IMAGE_TOKEN_INDEX).sum() + (cur_input_ids == SPEECH_TOKEN_INDEX).sum()

        if num_speech_images == 0:
            # Text-only sample: concatenate zero-length slices of the next
            # speech/image features (adds no rows) and advance both cursors.
            # NOTE(review): presumably this keeps the encoders in the
            # autograd graph for every sample -- confirm.
            cur_speech_features = speech_features[cur_speech_idx]
            cur_images_features = image_features[cur_image_idx]
            cur_input_embeds_1 = self.get_model().embed_tokens(cur_input_ids)
            cur_input_embeds = torch.cat([cur_input_embeds_1, cur_speech_features[0:0], cur_images_features[0:0]], dim=0)
            new_input_embeds.append(cur_input_embeds)
            new_labels.append(labels[batch_idx])
            cur_speech_idx += 1
            cur_image_idx += 1
            continue
        # Placeholder positions, with sentinels -1 and len so that every
        # text segment is the slice between two consecutive entries.
        speech_image_token_indices = [-1] + torch.where((cur_input_ids == SPEECH_TOKEN_INDEX) | (cur_input_ids == IMAGE_TOKEN_INDEX))[0].tolist() + [cur_input_ids.shape[0]]

        cur_input_ids_nospeech_image = []
        cur_labels = labels[batch_idx]
        cur_labels_nospeech_image = []
        for i in range(len(speech_image_token_indices) - 1):
            cur_input_ids_nospeech_image.append(cur_input_ids[speech_image_token_indices[i]+1:speech_image_token_indices[i+1]])
            cur_labels_nospeech_image.append(cur_labels[speech_image_token_indices[i]+1:speech_image_token_indices[i+1]])
        split_sizes = [x.shape[0] for x in cur_labels_nospeech_image]
        # Embed all text segments in one call, then split back per segment.
        cur_input_embeds = self.get_model().embed_tokens(torch.cat(cur_input_ids_nospeech_image))
        cur_input_embeds_no_speech_image = torch.split(cur_input_embeds, split_sizes, dim=0)
        cur_new_input_embeds = []
        cur_new_labels = []

        # Interleave text segments with features. The choice between image
        # and speech features depends only on the placeholder's ordinal
        # (i < num_images), not on which token type sits at that position.
        # NOTE(review): this assumes all image placeholders precede the
        # speech placeholders in the sequence -- confirm against the data
        # preprocessing.
        for i in range(num_speech_images + 1):
            cur_new_input_embeds.append(cur_input_embeds_no_speech_image[i])
            cur_new_labels.append(cur_labels_nospeech_image[i])
            if i < num_speech_images:
                if i < num_images:
                    cur_images_features = image_features[cur_image_idx]
                    cur_image_idx += 1
                    cur_new_input_embeds.append(cur_images_features)
                    # Feature positions are labelled IGNORE_INDEX.
                    cur_new_labels.append(torch.full((cur_images_features.shape[0],), IGNORE_INDEX, device=cur_labels.device, dtype=cur_labels.dtype))
                else:
                    cur_speech_features = speech_features[cur_speech_idx]
                    cur_speech_idx += 1
                    cur_new_input_embeds.append(cur_speech_features)
                    cur_new_labels.append(torch.full((cur_speech_features.shape[0],), IGNORE_INDEX, device=cur_labels.device, dtype=cur_labels.dtype))

        cur_new_input_embeds = [x.to(self.device) for x in cur_new_input_embeds]

        cur_new_input_embeds = torch.cat(cur_new_input_embeds)
        cur_new_labels = torch.cat(cur_new_labels)

        # A sample without image (or speech) placeholders still consumes one
        # entry of the corresponding feature list via a zero-length slice.
        if num_images == 0:
            cur_new_input_embeds = torch.cat([cur_new_input_embeds, image_features[cur_image_idx][0:0]], dim=0)
            cur_image_idx += 1

        if num_speech == 0:
            cur_new_input_embeds = torch.cat([cur_new_input_embeds, speech_features[cur_speech_idx][0:0]], dim=0)
            cur_speech_idx += 1

        new_input_embeds.append(cur_new_input_embeds)
        new_labels.append(cur_new_labels)

    # Truncate sequences to max length as speech features can make the sequence longer
    tokenizer_model_max_length = getattr(self.config, 'tokenizer_model_max_length', None)
    if tokenizer_model_max_length is not None:
        new_input_embeds = [x[:tokenizer_model_max_length] for x in new_input_embeds]
        new_labels = [x[:tokenizer_model_max_length] for x in new_labels]

    # Combine them
    max_len = max(x.shape[0] for x in new_input_embeds)
    batch_size = len(new_input_embeds)

    new_input_embeds_padded = []
    new_labels_padded = torch.full((batch_size, max_len), IGNORE_INDEX, dtype=new_labels[0].dtype, device=new_labels[0].device)
    attention_mask = torch.zeros((batch_size, max_len), dtype=attention_mask.dtype, device=attention_mask.device)
    position_ids = torch.zeros((batch_size, max_len), dtype=position_ids.dtype, device=position_ids.device)

    # Pad each sample with zero embeddings up to max_len, on the side given
    # by config.tokenizer_padding_side (default 'right').
    for i, (cur_new_embed, cur_new_labels) in enumerate(zip(new_input_embeds, new_labels)):
        cur_len = cur_new_embed.shape[0]
        if getattr(self.config, 'tokenizer_padding_side', 'right') == "left":
            new_input_embeds_padded.append(torch.cat((
                torch.zeros((max_len - cur_len, cur_new_embed.shape[1]), dtype=cur_new_embed.dtype, device=cur_new_embed.device),
                cur_new_embed
            ), dim=0))
            if cur_len > 0:
                new_labels_padded[i, -cur_len:] = cur_new_labels
                attention_mask[i, -cur_len:] = True
                position_ids[i, -cur_len:] = torch.arange(0, cur_len, dtype=position_ids.dtype, device=position_ids.device)
        else:
            new_input_embeds_padded.append(torch.cat((
                cur_new_embed,
                torch.zeros((max_len - cur_len, cur_new_embed.shape[1]), dtype=cur_new_embed.dtype, device=cur_new_embed.device)
            ), dim=0))
            if cur_len > 0:
                new_labels_padded[i, :cur_len] = cur_new_labels
                attention_mask[i, :cur_len] = True
                position_ids[i, :cur_len] = torch.arange(0, cur_len, dtype=position_ids.dtype, device=position_ids.device)

    new_input_embeds = torch.stack(new_input_embeds_padded, dim=0)

    # Hand back None for every optional input that was None on entry.
    if _labels is None:
        new_labels = None
    else:
        new_labels = new_labels_padded

    if _attention_mask is None:
        attention_mask = None
    else:
        attention_mask = attention_mask.to(dtype=_attention_mask.dtype)

    if _position_ids is None:
        position_ids = None

    return None, position_ids, attention_mask, past_key_values, new_input_embeds, new_labels
dim=0, keepdim=True) 371 | output_embeddings_avg = output_embeddings[:-num_new_tokens].mean( 372 | dim=0, keepdim=True) 373 | 374 | input_embeddings[-num_new_tokens:] = input_embeddings_avg 375 | output_embeddings[-num_new_tokens:] = output_embeddings_avg 376 | 377 | if model_args.tune_mm_mlp_adapter: 378 | for p in self.get_input_embeddings().parameters(): 379 | p.requires_grad = True 380 | for p in self.get_output_embeddings().parameters(): 381 | p.requires_grad = False 382 | 383 | if model_args.pretrain_mm_mlp_adapter: 384 | mm_projector_weights = torch.load(model_args.pretrain_mm_mlp_adapter, map_location='cpu') 385 | embed_tokens_weight = mm_projector_weights['model.embed_tokens.weight'] 386 | assert num_new_tokens == 2 387 | if input_embeddings.shape == embed_tokens_weight.shape: 388 | input_embeddings[-num_new_tokens:] = embed_tokens_weight[-num_new_tokens:] 389 | elif embed_tokens_weight.shape[0] == num_new_tokens: 390 | input_embeddings[-num_new_tokens:] = embed_tokens_weight 391 | else: 392 | raise ValueError(f"Unexpected embed_tokens_weight shape. Pretrained: {embed_tokens_weight.shape}. Current: {input_embeddings.shape}. 
class BEATsConfig:
    """Hyper-parameter container for the BEATs model.

    Every option is stored as a plain instance attribute. The defaults below
    are installed first; the entries of ``cfg`` (when given) are then applied
    on top via :meth:`update`, overriding defaults or adding new attributes.
    """

    def __init__(self, cfg=None):
        defaults = {
            "input_patch_size": -1,  # patch size of patch embedding
            "embed_dim": 512,  # patch embedding dimension
            "conv_bias": False,  # include bias in conv encoder

            "encoder_layers": 12,  # num encoder layers in the transformer
            "encoder_embed_dim": 768,  # encoder embedding dimension
            "encoder_ffn_embed_dim": 3072,  # encoder embedding dimension for FFN
            "encoder_attention_heads": 12,  # num encoder attention heads
            "activation_fn": "gelu",  # activation function to use

            "layer_wise_gradient_decay_ratio": 1.0,  # ratio for layer-wise gradient decay
            "layer_norm_first": False,  # apply layernorm first in the transformer
            "deep_norm": False,  # apply deep_norm first in the transformer

            # dropouts
            "dropout": 0.1,  # dropout probability for the transformer
            "attention_dropout": 0.1,  # dropout probability for attention weights
            "activation_dropout": 0.0,  # dropout probability after activation in FFN
            "encoder_layerdrop": 0.0,  # probability of dropping a transformer layer
            "dropout_input": 0.0,  # dropout to apply to the input (after feat extr)

            # positional embeddings
            "conv_pos": 128,  # number of filters for convolutional positional embeddings
            "conv_pos_groups": 16,  # number of groups for convolutional positional embedding

            # relative position embedding
            "relative_position_embedding": False,  # apply relative position embedding
            "num_buckets": 320,  # number of buckets for relative position embedding
            "max_distance": 1280,  # maximum distance for relative position embedding
            "gru_rel_pos": False,  # apply gated relative position embedding

            # label predictor
            "finetuned_model": False,  # whether the model is a fine-tuned model.
            "predictor_dropout": 0.1,  # dropout probability for the predictor
            "predictor_class": 527,  # target class number for the predictor
        }
        # Install the defaults in declaration order, then apply overrides.
        vars(self).update(defaults)

        if cfg is not None:
            self.update(cfg)

    def update(self, cfg: dict):
        """Copy every key/value pair of ``cfg`` onto this instance."""
        vars(self).update(cfg)
kaldi_fbank(waveform, num_mel_bins=128, sample_frequency=16000, frame_length=25, frame_shift=10) 130 | fbanks.append(fbank) 131 | fbank = torch.stack(fbanks, dim=0) 132 | fbank = (fbank - fbank_mean) / (2 * fbank_std) 133 | return fbank 134 | 135 | def extract_features( 136 | self, 137 | source: torch.Tensor, 138 | padding_mask: Optional[torch.Tensor] = None, 139 | fbank_mean: float = 15.41663, 140 | fbank_std: float = 6.55582, 141 | feature_only=False, 142 | ): 143 | fbank = self.preprocess(source, fbank_mean=fbank_mean, fbank_std=fbank_std).to(torch.float32) 144 | 145 | if padding_mask is not None: 146 | padding_mask = self.forward_padding_mask(fbank, padding_mask) 147 | 148 | fbank = fbank.unsqueeze(1) 149 | features = self.patch_embedding(fbank) 150 | features = features.reshape(features.shape[0], features.shape[1], -1) 151 | features = features.transpose(1, 2) 152 | features = self.layer_norm(features) 153 | 154 | if padding_mask is not None: 155 | padding_mask = self.forward_padding_mask(features, padding_mask) 156 | 157 | if self.post_extract_proj is not None: 158 | features = self.post_extract_proj(features) 159 | 160 | x = self.dropout_input(features) 161 | 162 | x, layer_results = self.encoder( 163 | x, 164 | padding_mask=padding_mask, 165 | ) 166 | 167 | if not feature_only and self.predictor is not None: 168 | x = self.predictor_dropout(x) 169 | logits = self.predictor(x) 170 | 171 | if padding_mask is not None and padding_mask.any(): 172 | logits[padding_mask] = 0 173 | logits = logits.sum(dim=1) 174 | logits = logits / (~padding_mask).sum(dim=1).unsqueeze(-1).expand_as(logits) 175 | else: 176 | logits = logits.mean(dim=1) 177 | 178 | lprobs = torch.sigmoid(logits) 179 | 180 | return lprobs, padding_mask 181 | else: 182 | return x, padding_mask -------------------------------------------------------------------------------- /ola/model/speech_encoder/beats/Tokenizers.py: -------------------------------------------------------------------------------- 1 
class TokenizersConfig:
    """Hyper-parameter container for the BEATs acoustic tokenizer.

    Every option is stored as a plain instance attribute. The defaults below
    are installed first; the entries of ``cfg`` (when given) are then applied
    on top via :meth:`update`, overriding defaults or adding new attributes.
    """

    def __init__(self, cfg=None):
        defaults = {
            "input_patch_size": -1,  # patch size of patch embedding
            "embed_dim": 512,  # patch embedding dimension
            "conv_bias": False,  # include bias in conv encoder

            "encoder_layers": 12,  # num encoder layers in the transformer
            "encoder_embed_dim": 768,  # encoder embedding dimension
            "encoder_ffn_embed_dim": 3072,  # encoder embedding dimension for FFN
            "encoder_attention_heads": 12,  # num encoder attention heads
            "activation_fn": "gelu",  # activation function to use

            "layer_norm_first": False,  # apply layernorm first in the transformer
            "deep_norm": False,  # apply deep_norm first in the transformer

            # dropouts
            "dropout": 0.1,  # dropout probability for the transformer
            "attention_dropout": 0.1,  # dropout probability for attention weights
            "activation_dropout": 0.0,  # dropout probability after activation in FFN
            "encoder_layerdrop": 0.0,  # probability of dropping a transformer layer
            "dropout_input": 0.0,  # dropout to apply to the input (after feat extr)

            # positional embeddings
            "conv_pos": 128,  # number of filters for convolutional positional embeddings
            "conv_pos_groups": 16,  # number of groups for convolutional positional embedding

            # relative position embedding
            "relative_position_embedding": False,  # apply relative position embedding
            "num_buckets": 320,  # number of buckets for relative position embedding
            "max_distance": 1280,  # maximum distance for relative position embedding
            "gru_rel_pos": False,  # apply gated relative position embedding

            # quantizer
            "quant_n": 1024,  # codebook number in quantizer
            "quant_dim": 256,  # codebook dimension in quantizer
        }
        # Install the defaults in declaration order, then apply overrides.
        vars(self).update(defaults)

        if cfg is not None:
            self.update(cfg)

    def update(self, cfg: dict):
        """Copy every key/value pair of ``cfg`` onto this instance."""
        vars(self).update(cfg)
cfg.quant_n 105 | self.quantize_layer = nn.Sequential( 106 | nn.Linear(cfg.encoder_embed_dim, cfg.encoder_embed_dim), 107 | nn.Tanh(), 108 | nn.Linear(cfg.encoder_embed_dim, cfg.quant_dim) # for quantize 109 | ) 110 | 111 | def forward_padding_mask( 112 | self, 113 | features: torch.Tensor, 114 | padding_mask: torch.Tensor, 115 | ) -> torch.Tensor: 116 | extra = padding_mask.size(1) % features.size(1) 117 | if extra > 0: 118 | padding_mask = padding_mask[:, :-extra] 119 | padding_mask = padding_mask.view( 120 | padding_mask.size(0), features.size(1), -1 121 | ) 122 | padding_mask = padding_mask.all(-1) 123 | return padding_mask 124 | 125 | def preprocess( 126 | self, 127 | source: torch.Tensor, 128 | fbank_mean: float = 15.41663, 129 | fbank_std: float = 6.55582, 130 | ) -> torch.Tensor: 131 | fbanks = [] 132 | for waveform in source: 133 | waveform = waveform.unsqueeze(0) * 2 ** 15 134 | fbank = kaldi_fbank(waveform, num_mel_bins=128, sample_frequency=16000, frame_length=25, frame_shift=10) 135 | fbanks.append(fbank) 136 | fbank = torch.stack(fbanks, dim=0) 137 | fbank = (fbank - fbank_mean) / (2 * fbank_std) 138 | return fbank 139 | 140 | def extract_labels( 141 | self, 142 | source: torch.Tensor, 143 | padding_mask: Optional[torch.Tensor] = None, 144 | fbank_mean: float = 15.41663, 145 | fbank_std: float = 6.55582, 146 | ): 147 | fbank = self.preprocess(source, fbank_mean=fbank_mean, fbank_std=fbank_std) 148 | 149 | if padding_mask is not None: 150 | padding_mask = self.forward_padding_mask(fbank, padding_mask) 151 | 152 | fbank = fbank.unsqueeze(1) 153 | features = self.patch_embedding(fbank) 154 | features = features.reshape(features.shape[0], features.shape[1], -1) 155 | features = features.transpose(1, 2) 156 | features = self.layer_norm(features) 157 | 158 | if padding_mask is not None: 159 | padding_mask = self.forward_padding_mask(features, padding_mask) 160 | 161 | if self.post_extract_proj is not None: 162 | features = self.post_extract_proj(features) 
163 | 164 | x = self.dropout_input(features) 165 | 166 | x, layer_results = self.encoder( 167 | x, 168 | padding_mask=padding_mask, 169 | ) 170 | 171 | quantize_input = self.quantize_layer(x) 172 | quantize_feature, embed_loss, embed_ind = self.quantize(quantize_input) 173 | 174 | return embed_ind 175 | -------------------------------------------------------------------------------- /ola/model/speech_encoder/beats/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Ola-Omni/Ola/666df3b345a252b218b25b84b2ec4c0e0b631b83/ola/model/speech_encoder/beats/__init__.py -------------------------------------------------------------------------------- /ola/model/speech_encoder/beats/__pycache__/BEATs.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Ola-Omni/Ola/666df3b345a252b218b25b84b2ec4c0e0b631b83/ola/model/speech_encoder/beats/__pycache__/BEATs.cpython-310.pyc -------------------------------------------------------------------------------- /ola/model/speech_encoder/beats/__pycache__/BEATs.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Ola-Omni/Ola/666df3b345a252b218b25b84b2ec4c0e0b631b83/ola/model/speech_encoder/beats/__pycache__/BEATs.cpython-38.pyc -------------------------------------------------------------------------------- /ola/model/speech_encoder/beats/__pycache__/__init__.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Ola-Omni/Ola/666df3b345a252b218b25b84b2ec4c0e0b631b83/ola/model/speech_encoder/beats/__pycache__/__init__.cpython-310.pyc -------------------------------------------------------------------------------- /ola/model/speech_encoder/beats/__pycache__/__init__.cpython-38.pyc: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/Ola-Omni/Ola/666df3b345a252b218b25b84b2ec4c0e0b631b83/ola/model/speech_encoder/beats/__pycache__/__init__.cpython-38.pyc -------------------------------------------------------------------------------- /ola/model/speech_encoder/beats/__pycache__/backbone.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Ola-Omni/Ola/666df3b345a252b218b25b84b2ec4c0e0b631b83/ola/model/speech_encoder/beats/__pycache__/backbone.cpython-310.pyc -------------------------------------------------------------------------------- /ola/model/speech_encoder/beats/__pycache__/backbone.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Ola-Omni/Ola/666df3b345a252b218b25b84b2ec4c0e0b631b83/ola/model/speech_encoder/beats/__pycache__/backbone.cpython-38.pyc -------------------------------------------------------------------------------- /ola/model/speech_encoder/beats/__pycache__/kaldi.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Ola-Omni/Ola/666df3b345a252b218b25b84b2ec4c0e0b631b83/ola/model/speech_encoder/beats/__pycache__/kaldi.cpython-310.pyc -------------------------------------------------------------------------------- /ola/model/speech_encoder/beats/__pycache__/kaldi.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Ola-Omni/Ola/666df3b345a252b218b25b84b2ec4c0e0b631b83/ola/model/speech_encoder/beats/__pycache__/kaldi.cpython-38.pyc -------------------------------------------------------------------------------- /ola/model/speech_encoder/beats/__pycache__/modules.cpython-310.pyc: 
class GradMultiply(torch.autograd.Function):
    """Autograd function whose backward pass multiplies the gradient by ``scale``.

    ``forward`` stores ``scale`` on the context and returns a new tensor
    built from ``x``; ``backward`` returns ``grad * scale`` for ``x`` and
    ``None`` for ``scale``.
    """

    @staticmethod
    def forward(ctx, x, scale):
        # Keep the factor for the backward pass.
        ctx.scale = scale
        # NOTE(review): legacy ``Tensor.new`` constructor; presumably yields a
        # tensor with the same values as ``x`` -- confirm before replacing.
        res = x.new(x)
        return res

    @staticmethod
    def backward(ctx, grad):
        # One gradient per forward input: scaled for ``x``, none for ``scale``.
        return grad * ctx.scale, None
class GLU_Linear(nn.Module):
    """Linear projection followed by a gated linear unit.

    The input is projected to ``2 * output_dim`` features. The first half is
    the value and the second half is the gate; the output is
    ``value * act(gate)``, or ``value * gate`` when ``glu_type`` is
    ``"bilinear"``.

    Args:
        input_dim: size of the last dimension of the input.
        output_dim: size of the last dimension of the output.
        glu_type: one of ``"sigmoid"``, ``"swish"``, ``"relu"``, ``"gelu"``
            or ``"bilinear"``.
        bias_in_glu: whether the linear projection has a bias term.

    Raises:
        ValueError: if ``glu_type`` is not one of the supported names.
    """

    def __init__(self, input_dim, output_dim, glu_type="sigmoid", bias_in_glu=True):
        super().__init__()

        self.glu_type = glu_type
        self.output_dim = output_dim

        if glu_type == "sigmoid":
            self.glu_act = torch.nn.Sigmoid()
        elif glu_type == "swish":
            self.glu_act = Swish()
        elif glu_type == "relu":
            self.glu_act = torch.nn.ReLU()
        elif glu_type == "gelu":
            self.glu_act = torch.nn.GELU()
        elif glu_type != "bilinear":
            # Fail at construction time instead of with an AttributeError
            # on `self.glu_act` during the first forward pass.
            raise ValueError("Unsupported glu_type: {}".format(glu_type))

        self.linear = nn.Linear(input_dim, output_dim * 2, bias=bool(bias_in_glu))

    def forward(self, x):
        # The feature dimension is always the last one, so slice with an
        # ellipsis; this accepts any number of leading dimensions.
        x = self.linear(x)
        value = x[..., 0:self.output_dim]
        gate = x[..., self.output_dim:self.output_dim * 2]

        if self.glu_type == "bilinear":
            return value * gate
        return value * self.glu_act(gate)
def quant_noise(module, p, block_size):
    """
    Wraps modules and applies quantization noise to the weights for
    subsequent quantization with Iterative Product Quantization as
    described in "Training with Quantization Noise for Extreme Model Compression"

    Args:
        - module: nn.Module
        - p: amount of Quantization Noise
        - block_size: size of the blocks for subsequent quantization with iPQ

    Returns:
        The same ``module``. When ``p > 0`` a forward pre-hook has been
        registered on it; when ``p <= 0`` it is returned untouched.

    Remarks:
        - Module weights must have the right sizes wrt the block size
        - Only Linear, Embedding and Conv2d modules are supported for the moment
        - For more detail on how to quantize by blocks with convolutional weights,
          see "And the Bit Goes Down: Revisiting the Quantization of Neural Networks"
        - We implement the simplest form of noise here as stated in the paper
          which consists in randomly dropping blocks
        - The hook overwrites ``module.weight.data`` on every training-mode
          forward call
    """

    # if no quantization noise, don't register hook
    if p <= 0:
        return module

    # supported modules
    assert isinstance(module, (nn.Linear, nn.Embedding, nn.Conv2d))

    # test whether module.weight has the right sizes wrt block_size
    is_conv = module.weight.ndim == 4

    if not is_conv:
        # 2D matrix
        assert (
            module.weight.size(1) % block_size == 0
        ), "Input features must be a multiple of block sizes"
    elif module.kernel_size == (1, 1):
        # 4D matrix, 1x1 convolutions
        assert (
            module.in_channels % block_size == 0
        ), "Input channels must be a multiple of block sizes"
    else:
        # 4D matrix, regular convolutions
        k = module.kernel_size[0] * module.kernel_size[1]
        assert k % block_size == 0, "Kernel size must be a multiple of block size"

    def _forward_pre_hook(mod, input):
        # no noise for evaluation
        if not mod.training:
            return

        weight = mod.weight

        # split weight matrix into blocks and randomly drop selected blocks
        if not is_conv:
            in_features = weight.size(1)
            out_features = weight.size(0)
            mask = torch.zeros(
                in_features // block_size * out_features, device=weight.device
            )
            mask.bernoulli_(p)
            mask = mask.repeat_interleave(block_size, -1).view(-1, in_features)
        elif mod.kernel_size == (1, 1):
            in_channels = mod.in_channels
            out_channels = mod.out_channels
            mask = torch.zeros(
                int(in_channels // block_size * out_channels),
                device=weight.device,
            )
            mask.bernoulli_(p)
            mask = mask.repeat_interleave(block_size, -1).view(-1, in_channels)
            # Add the trailing 1x1 kernel dimensions so the mask lines up
            # element-for-element with the 4-D weight; a 2-D mask would be
            # broadcast against the kernel dimensions instead.
            mask = mask.unsqueeze(2).unsqueeze(3)
        else:
            mask = torch.zeros(
                weight.size(0), weight.size(1), device=weight.device
            )
            mask.bernoulli_(p)
            mask = (
                mask.unsqueeze(2)
                .unsqueeze(3)
                .repeat(1, 1, mod.kernel_size[0], mod.kernel_size[1])
            )

        # scale weights and apply mask
        mask = mask.to(
            torch.bool
        )  # x.bool() is not currently supported in TorchScript
        s = 1 / (1 - p)
        mod.weight.data = s * weight.masked_fill(mask, 0)

    module.register_forward_pre_hook(_forward_pre_hook)
    return module
def kmeans(samples, num_clusters, num_iters=10, use_cosine_sim=False):
    """Cluster ``samples`` into ``num_clusters`` centroids.

    Centroids are initialised with ``sample_vectors`` and refined for
    ``num_iters`` rounds. A cluster that receives no sample in a round keeps
    its previous centroid.

    Args:
        samples: 2-D tensor with one sample per row.
        num_clusters: number of centroids to produce.
        num_iters: number of refinement rounds.
        use_cosine_sim: if True, assign by largest dot product and
            L2-normalise the updated centroids; otherwise assign by smallest
            squared Euclidean distance.

    Returns:
        ``(means, bins)``: the centroids and the per-cluster sample counts of
        the last round (all zeros when ``num_iters`` is 0).
    """
    dim, dtype, device = samples.shape[-1], samples.dtype, samples.device

    means = sample_vectors(samples, num_clusters)
    # Defined up front so the return value exists even when num_iters == 0.
    bins = torch.zeros(num_clusters, dtype=torch.long, device=device)

    for _ in range(num_iters):
        if use_cosine_sim:
            dists = samples @ means.t()
        else:
            # Plain broadcasting (n, 1, d) - (1, c, d); this avoids depending
            # on einops, whose import is optional at the top of the file.
            diffs = samples.unsqueeze(1) - means.unsqueeze(0)
            dists = -(diffs ** 2).sum(dim=-1)

        # `dists` is a similarity (or a negated distance), so max == nearest.
        buckets = dists.max(dim=-1).indices
        bins = torch.bincount(buckets, minlength=num_clusters)
        zero_mask = bins == 0
        # Avoid dividing by zero for empty clusters.
        bins_min_clamped = bins.masked_fill(zero_mask, 1)

        new_means = buckets.new_zeros(num_clusters, dim, dtype=dtype)
        new_means.scatter_add_(0, buckets.unsqueeze(1).expand(-1, dim), samples)
        new_means = new_means / bins_min_clamped[..., None]

        if use_cosine_sim:
            new_means = l2norm(new_means)

        # Empty clusters keep their previous centroid.
        means = torch.where(zero_mask[..., None], means, new_means)

    return means, bins
def norm_ema_inplace(moving_avg, new, decay):
    """Update ``moving_avg`` in place with an EMA step, then L2-normalise it.

    The stored value becomes ``decay * moving_avg + (1 - decay) * new`` and is
    then normalised to unit L2 norm along the last dimension.
    """
    averaged = moving_avg.data
    averaged.mul_(decay).add_(new, alpha=(1 - decay))
    # Same normalisation as the file's l2norm helper, written out inline.
    averaged.copy_(F.normalize(averaged, p=2, dim=-1))
torch.zeros(self.num_tokens)) 155 | self.cluster_size = self.cluster_size.to(device) 156 | 157 | def forward(self, z): 158 | # reshape z -> (batch, height, width, channel) and flatten 159 | # z, 'b c h w -> b h w c' 160 | # z = rearrange(z, 'b c h w -> b h w c') 161 | # z = z.transpose(1, 2) 162 | z = l2norm(z) 163 | z_flattened = z.reshape(-1, self.codebook_dim) 164 | 165 | self.embedding.init_embed_(z_flattened) 166 | 167 | d = z_flattened.pow(2).sum(dim=1, keepdim=True) + \ 168 | self.embedding.weight.pow(2).sum(dim=1) - 2 * \ 169 | torch.einsum('bd,nd->bn', z_flattened, self.embedding.weight) # 'n d -> d n' 170 | 171 | encoding_indices = torch.argmin(d, dim=1) 172 | 173 | z_q = self.embedding(encoding_indices).view(z.shape) 174 | 175 | encodings = F.one_hot(encoding_indices, self.num_tokens).type(z.dtype) 176 | 177 | if not self.training: 178 | with torch.no_grad(): 179 | cluster_size = encodings.sum(0) 180 | self.all_reduce_fn(cluster_size) 181 | ema_inplace(self.cluster_size, cluster_size, self.decay) 182 | 183 | if self.training and self.embedding.update: 184 | # EMA cluster size 185 | 186 | bins = encodings.sum(0) 187 | self.all_reduce_fn(bins) 188 | 189 | # self.embedding.cluster_size_ema_update(bins) 190 | ema_inplace(self.cluster_size, bins, self.decay) 191 | 192 | zero_mask = (bins == 0) 193 | bins = bins.masked_fill(zero_mask, 1.) 
def build_speech_encoder(config):
    """Build the speech encoder selected by ``config.speech_encoder_type``.

    The type is matched case-insensitively by substring, in this order:
    ``"whisper"``, ``"dual"``, ``"none"``.

    Returns:
        The result of ``WhisperWrappedEncoder.load(config)``, a
        ``DualWrappedEncoder``, or ``None`` for the ``"none"`` type.

    Raises:
        ValueError: if ``speech_encoder_type`` is missing from ``config`` or
            matches none of the known types.
    """
    speech_encoder_type = getattr(config, 'speech_encoder_type', None)
    if speech_encoder_type is None:
        # Without this guard the `.lower()` below raises AttributeError.
        raise ValueError(f'Unknown speech encoder: {speech_encoder_type}')

    encoder_type = speech_encoder_type.lower()
    if "whisper" in encoder_type:
        return WhisperWrappedEncoder.load(config)
    if "dual" in encoder_type:
        return DualWrappedEncoder(config)
    if "none" in encoder_type:
        return None

    raise ValueError(f'Unknown speech encoder: {speech_encoder_type}')
class DualWrappedEncoder(nn.Module):
    """Speech encoder that combines a Whisper encoder with a BEATs model.

    ``forward`` runs both models without gradients, pads the shorter output
    along dimension 1 so both have the same length, concatenates them along
    the last dimension and casts the result to ``torch.bfloat16``.
    """

    def __init__(self, config):
        super().__init__()
        self.config = config
        self.whisper_model = self.load_whisper(config)
        self.beats_model = self.load_beats(config)

    def load_whisper(self, model_config):
        """Load the Whisper model named by ``model_config.speech_encoder`` and return its encoder.

        Every ``whisper.model.LayerNorm`` inside the encoder is replaced by a
        plain ``nn.LayerNorm`` carrying the same parameters.
        """

        def replace_layer_norm(module):
            from whisper.model import LayerNorm
            for name, child in module.named_children():
                if isinstance(child, LayerNorm):
                    old_params = child.state_dict()
                    new_layer_norm = nn.LayerNorm(
                        child.normalized_shape,
                        eps=child.eps,
                        elementwise_affine=child.elementwise_affine,
                    )
                    new_layer_norm.load_state_dict(old_params)
                    setattr(module, name, new_layer_norm)
                else:
                    replace_layer_norm(child)

        encoder = whisper.load_model(name=model_config.speech_encoder, device='cpu').encoder
        replace_layer_norm(encoder)
        return encoder

    def load_beats(self, model_config):
        """Build a BEATs model from the checkpoint at ``model_config.music_encoder``."""
        beats_path = model_config.music_encoder
        print("Loading BEATs Model")
        # NOTE: torch.load unpickles the file; only point this at trusted checkpoints.
        beats_ckpt = torch.load(beats_path, map_location='cpu')
        beats_cfg = BEATsConfig(beats_ckpt['cfg'])
        beats = BEATs(beats_cfg)
        beats.load_state_dict(beats_ckpt['model'])
        return beats

    def forward(self, x, raw_wav=None, audio_padding_mask=None):
        """Encode one batch with both models and return the fused features.

        Args:
            x: input passed to the Whisper encoder.
            raw_wav: waveform tensor passed to BEATs; required.
            audio_padding_mask: optional padding mask forwarded to BEATs.

        Raises:
            ValueError: if ``raw_wav`` is None.
        """
        if raw_wav is None:
            # The BEATs branch always reads raw_wav; fail with a clear message
            # instead of an AttributeError on `None.float()`.
            raise ValueError("DualWrappedEncoder.forward requires `raw_wav`")

        with torch.no_grad():
            self.beats_model = self.beats_model.float()
            speech_embeds = self.whisper_model(x)
            audio_embeds, _ = self.beats_model.extract_features(
                raw_wav.float(), padding_mask=audio_padding_mask, feature_only=True
            )

        # Pad the shorter stream at the end of dimension 1 so they can be concatenated.
        speech_len, audio_len = speech_embeds.size(1), audio_embeds.size(1)
        if audio_len < speech_len:
            audio_embeds = F.pad(audio_embeds, (0, 0, 0, speech_len - audio_len))
        elif audio_len > speech_len:
            speech_embeds = F.pad(speech_embeds, (0, 0, 0, audio_len - speech_len))

        speech_embeds = torch.cat((speech_embeds, audio_embeds), dim=-1)
        speech_embeds = speech_embeds.to(torch.bfloat16)
        return speech_embeds
nn.Parameter( 17 | torch.randn(config.hidden_size) * embed_std 18 | ) 19 | self.speech_begin = nn.Parameter( 20 | torch.randn(config.hidden_size) * embed_std 21 | ) 22 | self.speech_end = nn.Parameter( 23 | torch.randn(config.hidden_size) * embed_std 24 | ) 25 | 26 | def forward(self, x): 27 | batch_size, seq_len, dim = x.size() 28 | num_frames_to_discard = seq_len % self.k 29 | if num_frames_to_discard > 0: 30 | x = x[:, :-num_frames_to_discard, :] 31 | seq_len = x.size(1) 32 | 33 | x = x.contiguous() 34 | x = x.view(batch_size, seq_len // self.k, dim * self.k) 35 | x = self.linear1(x) 36 | x = self.relu(x) 37 | x = self.linear2(x) 38 | x = torch.cat([ 39 | x, 40 | self.speech_newline.reshape(1, 1, -1).expand(batch_size, 1, -1).to(x.dtype) 41 | ], dim=1) 42 | begin = self.speech_begin.reshape(1, -1).to(x.dtype) 43 | end = self.speech_end.reshape(1, -1).to(x.dtype) 44 | x = x.flatten(0, 1) 45 | x = torch.cat([begin, x, end], dim=0) 46 | # x = x.flatten(0, 1) 47 | return x -------------------------------------------------------------------------------- /ola/serve/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Ola-Omni/Ola/666df3b345a252b218b25b84b2ec4c0e0b631b83/ola/serve/__init__.py -------------------------------------------------------------------------------- /ola/serve/controller.py: -------------------------------------------------------------------------------- 1 | """ 2 | A controller manages distributed workers. 3 | It sends worker addresses to clients. 
class Controller:
    """Registry of workers plus the logic that picks a worker for a model.

    ``worker_info`` maps a worker name (used as its base URL) to a
    ``WorkerInfo``. A daemon thread started in ``__init__`` runs
    ``heart_beat_controller`` on this instance.
    """

    def __init__(self, dispatch_method: str):
        # Dict[str -> WorkerInfo]
        self.worker_info = {}
        self.dispatch_method = DispatchMethod.from_str(dispatch_method)

        self.heart_beat_thread = threading.Thread(
            target=heart_beat_controller, args=(self,), daemon=True)
        self.heart_beat_thread.start()

        logger.info("Init controller")

    def register_worker(self, worker_name: str, check_heart_beat: bool,
                        worker_status: dict):
        """Add or replace a worker entry.

        When ``worker_status`` is empty, the status is fetched from the worker.
        Returns True on success and False when no status could be obtained.
        """
        if worker_name not in self.worker_info:
            logger.info(f"Register a new worker: {worker_name}")
        else:
            logger.info(f"Register an existing worker: {worker_name}")

        if not worker_status:
            worker_status = self.get_worker_status(worker_name)
        if not worker_status:
            return False

        self.worker_info[worker_name] = WorkerInfo(
            worker_status["model_names"], worker_status["speed"], worker_status["queue_length"],
            check_heart_beat, time.time())

        logger.info(f"Register done: {worker_name}, {worker_status}")
        return True

    def get_worker_status(self, worker_name: str):
        """POST to the worker's ``/worker_get_status``; return its JSON, or None on failure."""
        try:
            r = requests.post(worker_name + "/worker_get_status", timeout=5)
        except requests.exceptions.RequestException as e:
            logger.error(f"Get status fails: {worker_name}, {e}")
            return None

        if r.status_code != 200:
            logger.error(f"Get status fails: {worker_name}, {r}")
            return None

        return r.json()

    def remove_worker(self, worker_name: str):
        """Delete a worker entry; raises KeyError if it is not registered."""
        del self.worker_info[worker_name]

    def refresh_all_workers(self):
        """Re-register every known worker, dropping those whose status cannot be fetched."""
        old_info = dict(self.worker_info)
        self.worker_info = {}

        for w_name, w_info in old_info.items():
            if not self.register_worker(w_name, w_info.check_heart_beat, None):
                logger.info(f"Remove stale worker: {w_name}")

    def list_models(self):
        """Return the distinct model names served by the registered workers (unordered)."""
        model_names = set()

        for w_info in self.worker_info.values():
            model_names.update(w_info.model_names)

        return list(model_names)

    def get_worker_address(self, model_name: str):
        """Pick a worker that serves ``model_name``.

        Returns the worker name, or ``""`` when no suitable worker exists.

        Raises:
            ValueError: if the configured dispatch method is not handled.
        """
        if self.dispatch_method == DispatchMethod.LOTTERY:
            worker_names = []
            worker_speeds = []
            for w_name, w_info in self.worker_info.items():
                if model_name in w_info.model_names:
                    worker_names.append(w_name)
                    worker_speeds.append(w_info.speed)
            worker_speeds = np.array(worker_speeds, dtype=np.float32)
            norm = np.sum(worker_speeds)
            if norm < 1e-4:
                return ""
            worker_speeds = worker_speeds / norm
            # Sample a worker with probability proportional to its speed and
            # return it directly, without probing its status first.
            pt = np.random.choice(np.arange(len(worker_names)),
                                  p=worker_speeds)
            return worker_names[pt]
        elif self.dispatch_method == DispatchMethod.SHORTEST_QUEUE:
            worker_names = []
            worker_qlen = []
            for w_name, w_info in self.worker_info.items():
                if model_name in w_info.model_names:
                    worker_names.append(w_name)
                    # Queue length normalised by speed.
                    worker_qlen.append(w_info.queue_length / w_info.speed)
            if len(worker_names) == 0:
                return ""
            min_index = np.argmin(worker_qlen)
            w_name = worker_names[min_index]
            # Count the request we are about to dispatch.
            self.worker_info[w_name].queue_length += 1
            logger.info(f"names: {worker_names}, queue_lens: {worker_qlen}, ret: {w_name}")
            return w_name
        else:
            raise ValueError(f"Invalid dispatch method: {self.dispatch_method}")

    def receive_heart_beat(self, worker_name: str, queue_length: int):
        """Record a heart beat; return False when the worker is not registered."""
        if worker_name not in self.worker_info:
            logger.info(f"Receive unknown heart beat. {worker_name}")
            return False

        self.worker_info[worker_name].queue_length = queue_length
        self.worker_info[worker_name].last_heart_beat = time.time()
        logger.info(f"Receive heart beat. {worker_name}")
        return True

    def remove_stable_workers_by_expiration(self):
        """Remove heart-beat-checked workers whose last beat is older than the expiration window."""
        expire = time.time() - CONTROLLER_HEART_BEAT_EXPIRATION
        # Collect first, delete afterwards: the dict must not change size
        # while it is being iterated.
        to_delete = [
            worker_name
            for worker_name, w_info in self.worker_info.items()
            if w_info.check_heart_beat and w_info.last_heart_beat < expire
        ]

        for worker_name in to_delete:
            self.remove_worker(worker_name)

    def worker_api_generate_stream(self, params):
        """Proxy a streaming generation request to a worker.

        Yields the worker's chunks, each terminated by ``b"\\0"``. When no
        worker serves ``params["model"]`` a single error payload with
        ``error_code`` 2 is yielded; when the request to the worker fails, an
        error payload with ``error_code`` 3 is yielded.
        """
        worker_addr = self.get_worker_address(params["model"])
        if not worker_addr:
            logger.info(f"no worker: {params['model']}")
            ret = {
                "text": server_error_msg,
                "error_code": 2,
            }
            yield json.dumps(ret).encode() + b"\0"
            # Stop here: there is no address to post to, and continuing
            # would emit a second, misleading error payload.
            return

        try:
            response = requests.post(worker_addr + "/worker_generate_stream",
                                     json=params, stream=True, timeout=5)
            for chunk in response.iter_lines(decode_unicode=False, delimiter=b"\0"):
                if chunk:
                    yield chunk + b"\0"
        except requests.exceptions.RequestException:
            logger.info(f"worker timeout: {worker_addr}")
            ret = {
                "text": server_error_msg,
                "error_code": 3,
            }
            yield json.dumps(ret).encode() + b"\0"

    # Let the controller act as a worker to achieve hierarchical
    # management. This can be used to connect isolated sub networks.
    def worker_api_get_status(self):
        """Aggregate the status of all reachable workers into one worker-style status dict."""
        model_names = set()
        speed = 0
        queue_length = 0

        for w_name in self.worker_info:
            worker_status = self.get_worker_status(w_name)
            if worker_status is not None:
                model_names.update(worker_status["model_names"])
                speed += worker_status["speed"]
                queue_length += worker_status["queue_length"]

        return {
            "model_names": list(model_names),
            "speed": speed,
            "queue_length": queue_length,
        }
def get_model_list():
    """Ask the controller to refresh its workers, then return its model list.

    Both requests go to ``args.controller_url``.

    Raises:
        requests.HTTPError: if either request returns an error status.
    """
    refresh = requests.post(args.controller_url + "/refresh_all_workers")
    # An explicit check instead of `assert`, which is stripped under `python -O`.
    refresh.raise_for_status()

    listing = requests.post(args.controller_url + "/list_models")
    listing.raise_for_status()
    models = listing.json()["models"]
    logger.info(f"Models: {models}")
    return models
# --------------------------------------------------------------------------------
# /ola/serve/gradio_web_server.py (continued)
# NOTE(review): the digest this code came from had lost its indentation; block
# nesting below (in particular the Gradio layout in `build_demo`) was
# reconstructed from the syntax and should be confirmed against the repository.
# --------------------------------------------------------------------------------

# JavaScript executed in the browser by `demo.load`; it returns the page's URL
# query parameters as an object, which is passed to `load_demo` as `url_params`.
get_window_url_params = """
function() {
    const params = new URLSearchParams(window.location.search);
    url_params = Object.fromEntries(params);
    console.log(url_params);
    return url_params;
    }
"""


def load_demo(url_params, request: gr.Request):
    """Create a fresh conversation and pre-select the model named in the URL.

    Returns `(state, dropdown_update)`; the dropdown value is only set when
    `url_params["model"]` is one of the module-global `models`.
    """
    logger.info(f"load_demo. ip: {request.client.host}. params: {url_params}")

    dropdown_update = gr.Dropdown(visible=True)
    if "model" in url_params:
        model = url_params["model"]
        if model in models:
            dropdown_update = gr.Dropdown(value=model, visible=True)

    state = default_conversation.copy()
    return state, dropdown_update


def load_demo_refresh_model_list(request: gr.Request):
    """Create a fresh conversation and re-query the controller for models."""
    logger.info(f"load_demo. ip: {request.client.host}")
    models = get_model_list()
    state = default_conversation.copy()
    dropdown_update = gr.Dropdown(
        choices=models,
        value=models[0] if len(models) > 0 else ""
    )
    return state, dropdown_update


def clear_history(request: gr.Request):
    """Reset the conversation and blank the audio/text/unit/audio widgets."""
    logger.info(f"clear_history. ip: {request.client.host}")
    state = default_conversation.copy()
    return (state, None, "", "", None)


def add_speech(state, speech, request: gr.Request):
    """Start a new conversation whose first user turn carries `speech`.

    `speech` is the value of the audio input widget; `http_bot` later unpacks
    it as `(sample_rate, samples)`. When nothing was recorded or uploaded the
    widget value is None, so the turn is flagged with `skip_next` and
    `http_bot` returns early instead of failing on the unpack.
    """
    text = "Please directly answer the questions in the user's speech."
    # NOTE(review): the prefix is a bare newline here; presumably a speech
    # placeholder token preceded it in the original file -- confirm.
    text = '\n' + text
    text = (text, speech)
    state = default_conversation.copy()
    state.append_message(state.roles[0], text)
    state.append_message(state.roles[1], None)
    state.skip_next = speech is None
    return (state)


def _units_to_wav(units):
    """Run the global vocoder on a list of unit ids and return a numpy array.

    Returns None when no vocoder was loaded (`build_vocoder` leaves the global
    as None if `--vocoder` is not given).
    """
    if vocoder is None:
        return None
    x = {"code": torch.LongTensor(units).view(1, -1).cuda()}
    wav = vocoder(x, True)
    return wav.detach().cpu().numpy()


def http_bot(state, model_selector, temperature, top_p, max_new_tokens, chunk_size, request: gr.Request):
    """Stream one model response for the conversation in `state`.

    Generator used as a Gradio event handler. Every yielded value is
    `(state, text_output, unit_output, audio_output)` where `audio_output` is
    either None or `(16000, waveform)` built from all vocoder chunks so far.

    The handler asks the controller for a worker serving `model_selector`,
    posts the prompt plus the resampled audio to that worker's
    `/worker_generate_stream` endpoint, and vocodes newly received units each
    time at least `chunk_size` new ones are available.
    """
    logger.info(f"http_bot. ip: {request.client.host}")
    model_name = model_selector

    if state.skip_next:
        # This generate call is skipped due to invalid inputs
        yield (state, "", "", None)
        return

    if len(state.messages) == state.offset + 2:
        # First round of conversation
        template_name = "llama_3"
        new_state = conv_templates[template_name].copy()
        new_state.append_message(new_state.roles[0], state.messages[-2][1])
        new_state.append_message(new_state.roles[1], None)
        state = new_state

    # Query worker address
    controller_url = args.controller_url
    ret = requests.post(controller_url + "/get_worker_address",
                        json={"model": model_name})
    worker_addr = ret.json()["address"]
    logger.info(f"model_name: {model_name}, worker_addr: {worker_addr}")

    # No available worker
    if worker_addr == "":
        state.messages[-1][-1] = server_error_msg
        yield (state, "", "", None)
        return

    # Construct prompt
    prompt = state.get_prompt()

    # The first message holds (text, speech); speech is (sample_rate, samples).
    sr, audio = state.messages[0][1][1]
    resampler = torchaudio.transforms.Resample(orig_freq=sr, new_freq=16000)
    audio = torch.tensor(audio.astype(np.float32)).unsqueeze(0)
    audio = resampler(audio).squeeze(0).numpy()
    # Scale by 1/32768 -- presumably the input is int16-range PCM; confirm.
    audio /= 32768.0
    audio = audio.tolist()
    # Make requests
    pload = {
        "model": model_name,
        "prompt": prompt,
        "temperature": float(temperature),
        "top_p": float(top_p),
        "max_new_tokens": min(int(max_new_tokens), 1500),
        "stop": state.sep2,
        "audio": audio,
    }

    yield (state, "", "", None)

    # Initialised up front so the code after the loop is safe even when the
    # worker closes the stream without sending a single successful chunk.
    output = ""
    unit_text = ""
    output_unit = []
    num_generated_units = 0
    wav_list = []

    try:
        # Stream output
        response = requests.post(worker_addr + "/worker_generate_stream",
                                 headers=headers, json=pload, stream=True, timeout=10)
        for chunk in response.iter_lines(decode_unicode=False, delimiter=b"\0"):
            if chunk:
                data = json.loads(chunk.decode())
                if data["error_code"] == 0:
                    output = data["text"][len(prompt):].strip()
                    # Tolerate payloads without "unit" (the worker's
                    # "Exceeds max token length" message has none).
                    unit_text = data.get("unit", "").strip()
                    output_unit = list(map(int, unit_text.split()))
                    state.messages[-1][-1] = (output, unit_text)

                    # vocoder
                    new_units = output_unit[num_generated_units:]
                    if len(new_units) >= chunk_size:
                        num_generated_units = len(output_unit)
                        wav = _units_to_wav(new_units)
                        if wav is not None:
                            wav_list.append(wav)

                    if len(wav_list) > 0:
                        wav_full = np.concatenate(wav_list)
                        return_value = (16000, wav_full)
                    else:
                        return_value = None

                    yield (state, output, unit_text, return_value)
                else:
                    output = data["text"] + f" (error_code: {data['error_code']})"
                    state.messages[-1][-1] = output
                    yield (state, "", "", None)
                    return
                time.sleep(0.03)
    except requests.exceptions.RequestException:
        state.messages[-1][-1] = server_error_msg
        yield (state, "", "", None)
        return

    # Vocode whatever units are left over (fewer than `chunk_size`).
    if num_generated_units < len(output_unit):
        wav = _units_to_wav(output_unit[num_generated_units:])
        if wav is not None:
            wav_list.append(wav)

    if len(wav_list) > 0:
        wav_full = np.concatenate(wav_list)
        return_value = (16000, wav_full)
    else:
        return_value = None

    yield (state, output, unit_text, return_value)

    logger.info(f"{output}")
    logger.info(f"{output_unit}")


title_markdown = ("""
# 🎧 LLaMA-Omni: Seamless Speech Interaction with Large Language Models
""")

block_css = """

#buttons button {
    min-width: min(120px,100%);
}

"""


def build_demo(embed_mode, vocoder, cur_dir=None, concurrency_count=10):
    """Build the Gradio Blocks UI and wire its event handlers.

    Args:
        embed_mode: when true, the title markdown is not rendered.
        vocoder: not used inside this function (the handlers read the
            module-global `vocoder`); kept for the existing call signature.
        cur_dir: directory that contains `examples/*.wav`; defaults to the
            directory of this file.
        concurrency_count: `concurrency_limit` for the `http_bot` handler.

    Reads the module-globals `models` and `args`.

    Raises:
        ValueError: if `args.model_list_mode` is neither "once" nor "reload".
    """
    with gr.Blocks(title="LLaMA-Omni Speech Chatbot", theme=gr.themes.Default(), css=block_css) as demo:
        state = gr.State()

        if not embed_mode:
            gr.Markdown(title_markdown)

        with gr.Row(elem_id="model_selector_row"):
            model_selector = gr.Dropdown(
                choices=models,
                value=models[0] if len(models) > 0 else "",
                interactive=True,
                show_label=False,
                container=False)

        with gr.Row():
            audio_input_box = gr.Audio(sources=["upload", "microphone"], label="Speech Input")
            with gr.Accordion("Parameters", open=True) as parameter_row:
                temperature = gr.Slider(minimum=0.0, maximum=1.0, value=0.0, step=0.1, interactive=True, label="Temperature",)
                top_p = gr.Slider(minimum=0.0, maximum=1.0, value=0.7, step=0.1, interactive=True, label="Top P",)
                max_output_tokens = gr.Slider(minimum=0, maximum=1024, value=512, step=64, interactive=True, label="Max Output Tokens",)
                chunk_size = gr.Slider(minimum=10, maximum=500, value=40, step=10, interactive=True, label="Chunk Size",)

        if cur_dir is None:
            cur_dir = os.path.dirname(os.path.abspath(__file__))
        # vicuna_1..5 followed by helpful_base_1..5, one example per row.
        gr.Examples(examples=[
            [f"{cur_dir}/examples/{stem}_{idx}.wav"]
            for stem in ("vicuna", "helpful_base")
            for idx in range(1, 6)
        ], inputs=[audio_input_box])

        with gr.Row():
            submit_btn = gr.Button(value="Send", variant="primary")
            clear_btn = gr.Button(value="Clear")

        text_output_box = gr.Textbox(label="Text Output", type="text")
        unit_output_box = gr.Textbox(label="Unit Output", type="text")
        audio_output_box = gr.Audio(label="Speech Output")

        url_params = gr.JSON(visible=False)

        submit_btn.click(
            add_speech,
            [state, audio_input_box],
            [state]
        ).then(
            http_bot,
            [state, model_selector, temperature, top_p, max_output_tokens, chunk_size],
            [state, text_output_box, unit_output_box, audio_output_box],
            concurrency_limit=concurrency_count
        )

        clear_btn.click(
            clear_history,
            None,
            [state, audio_input_box, text_output_box, unit_output_box, audio_output_box],
            queue=False
        )

        if args.model_list_mode == "once":
            demo.load(
                load_demo,
                [url_params],
                [state, model_selector],
                js=get_window_url_params
            )
        elif args.model_list_mode == "reload":
            demo.load(
                load_demo_refresh_model_list,
                None,
                [state, model_selector],
                queue=False
            )
        else:
            raise ValueError(f"Unknown model list mode: {args.model_list_mode}")

    return demo


def build_vocoder(args):
    """Load the CodeHiFiGAN vocoder into the module-global `vocoder`.

    Does nothing (the global stays None) when `args.vocoder` is None;
    otherwise the JSON config at `args.vocoder_cfg` is read and the vocoder is
    moved to CUDA. Always returns None.
    """
    global vocoder
    if args.vocoder is None:
        return None
    with open(args.vocoder_cfg) as f:
        vocoder_cfg = json.load(f)
    vocoder = CodeHiFiGANVocoder(args.vocoder, vocoder_cfg).cuda()


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--host", type=str, default="0.0.0.0")
    parser.add_argument("--port", type=int)
    parser.add_argument("--controller-url", type=str, default="http://localhost:21001")
    parser.add_argument("--concurrency-count", type=int, default=16)
    parser.add_argument("--model-list-mode", type=str, default="once",
                        choices=["once", "reload"])
    parser.add_argument("--share", action="store_true")
    parser.add_argument("--moderate", action="store_true")
    parser.add_argument("--embed", action="store_true")
    parser.add_argument("--vocoder", type=str)
    parser.add_argument("--vocoder-cfg", type=str)
    args = parser.parse_args()
    logger.info(f"args: {args}")

    models = get_model_list()
    build_vocoder(args)

    logger.info(args)
    demo = build_demo(args.embed, vocoder, concurrency_count=args.concurrency_count)
    demo.queue(
        api_open=False
    ).launch(
        server_name=args.host,
        server_port=args.port,
        share=args.share
    )

# --------------------------------------------------------------------------------
# /ola/serve/model_worker.py:
# --------------------------------------------------------------------------------
# A model worker executes the model.
import argparse
import asyncio
import json
import time
import threading
import uuid

from fastapi import FastAPI, Request, BackgroundTasks
from fastapi.responses import StreamingResponse
import requests
import torch
import uvicorn
import whisper
import numpy as np
from functools import partial

from transformers import PreTrainedTokenizer

# NOTE(review): as in gradio_web_server.py, these import from `omni_speech`
# although the repository package is `ola` -- confirm the intended package.
from omni_speech.constants import WORKER_HEART_BEAT_INTERVAL
from omni_speech.utils import (build_logger, server_error_msg,
                               pretty_print_semaphore)
from omni_speech.model.builder import load_pretrained_model
from omni_speech.constants import SPEECH_TOKEN_INDEX, DEFAULT_SPEECH_TOKEN
from omni_speech.datasets.preprocess import tokenizer_speech_token
from transformers import TextIteratorStreamer
from threading import Thread


GB = 1 << 30

worker_id = str(uuid.uuid4())[:6]
logger = build_logger("model_worker", f"model_worker_{worker_id}.log")
global_counter = 0

# Created lazily by the `/worker_generate_stream` endpoint.
model_semaphore = None


def heart_beat_worker(controller):
    """Call `controller.send_heart_beat()` forever, once per interval."""
    while True:
        time.sleep(WORKER_HEART_BEAT_INTERVAL)
        controller.send_heart_beat()


def load_speech(audio, input_type, mel_size, speech_normalize):
    """Convert a sequence of audio samples into the model's speech input.

    Args:
        audio: sequence of samples, converted to a float32 numpy array.
        input_type: "raw" returns a torch tensor of the samples (layer-normed
            over its full shape when `speech_normalize` is true); "mel" pads or
            trims with whisper and returns the log-mel spectrogram with its
            two axes swapped. Any other value returns the numpy array as is.
        mel_size: number of mel bins for the "mel" path.
        speech_normalize: whether to normalise in the "raw" path.
    """
    speech = np.array(audio, dtype=np.float32)
    if input_type == "raw":
        speech = torch.from_numpy(speech)
        if speech_normalize:
            speech = torch.nn.functional.layer_norm(speech, speech.shape)
    elif input_type == "mel":
        speech = whisper.pad_or_trim(speech)
        speech = whisper.log_mel_spectrogram(speech, n_mels=mel_size).permute(1, 0)
    return speech


def build_unit_tokenizer(vocab_size):
    """Build a BertTokenizer whose vocabulary is the integers 0..vocab_size.

    The vocabulary file is written into a private temporary directory rather
    than a fixed `unit_vocab.txt` in the working directory, so concurrently
    starting workers cannot clobber each other's file and nothing is left
    behind if tokenizer construction fails.
    """
    import os
    import tempfile
    from transformers import BertTokenizer

    with tempfile.TemporaryDirectory() as tmp_dir:
        vocab_path = os.path.join(tmp_dir, "unit_vocab.txt")
        with open(vocab_path, "w") as f:
            f.writelines(f"{i}\n" for i in range(vocab_size + 1))
        tokenizer = BertTokenizer(vocab_file=vocab_path)
    return tokenizer


class ModelWorker:
    """Holds one loaded model and serves streaming generation requests.

    Unless `no_register` is set, the worker registers itself with the
    controller and starts a daemon thread that sends periodic heart beats.
    """

    def __init__(self, controller_addr, worker_addr,
                 worker_id, no_register,
                 model_path, model_base, model_name,
                 load_8bit, load_4bit, device, input_type, mel_size, s2s, is_lora, use_flash_attn=False):
        self.controller_addr = controller_addr
        self.worker_addr = worker_addr
        self.worker_id = worker_id
        self.device = device
        self.model_name = model_name
        self.input_type = input_type
        self.mel_size = mel_size
        self.tokenizer, self.model, self.context_len = load_pretrained_model(
            model_path, model_base, is_lora=is_lora, s2s=s2s, load_8bit=load_8bit, load_4bit=load_4bit, device=self.device, use_flash_attn=use_flash_attn)
        self.unit_tokenizer = build_unit_tokenizer(self.model.config.unit_vocab_size)

        if not no_register:
            self.register_to_controller()
            self.heart_beat_thread = threading.Thread(
                target=heart_beat_worker, args=(self,), daemon=True)
            self.heart_beat_thread.start()

    def register_to_controller(self):
        """POST this worker's address and status to `/register_worker`.

        Raises:
            RuntimeError: if the controller does not answer with HTTP 200.
        """
        logger.info("Register to controller")

        url = self.controller_addr + "/register_worker"
        data = {
            "worker_name": self.worker_addr,
            "check_heart_beat": True,
            "worker_status": self.get_status()
        }
        r = requests.post(url, json=data)
        # Explicit check rather than `assert`, which is removed under `-O`.
        if r.status_code != 200:
            raise RuntimeError(f"/register_worker returned HTTP {r.status_code}")

    def send_heart_beat(self):
        """Report the queue length to the controller, retrying until it answers.

        Request errors are logged and retried after 5 seconds. If the
        controller replies that it does not know this worker, the worker
        registers again.
        """
        logger.info(f"Send heart beat. Models: {[self.model_name]}. "
                    f"Semaphore: {pretty_print_semaphore(model_semaphore)}. "
                    f"global_counter: {global_counter}")

        url = self.controller_addr + "/receive_heart_beat"

        while True:
            try:
                ret = requests.post(url, json={
                    "worker_name": self.worker_addr,
                    "queue_length": self.get_queue_length()}, timeout=5)
                exist = ret.json()["exist"]
                break
            except requests.exceptions.RequestException as e:
                logger.error(f"heart beat error: {e}")
            time.sleep(5)

        if not exist:
            self.register_to_controller()

    def get_queue_length(self):
        """Return running plus waiting requests, derived from the semaphore.

        Returns 0 before the semaphore exists. Relies on the private
        `_value` and `_waiters` attributes of `asyncio.Semaphore`.
        """
        if model_semaphore is None:
            return 0
        else:
            return args.limit_model_concurrency - model_semaphore._value + (len(
                model_semaphore._waiters) if model_semaphore._waiters is not None else 0)

    def get_status(self):
        """Return the status dict reported to the controller."""
        return {
            "model_names": [self.model_name],
            "speed": 1,
            "queue_length": self.get_queue_length(),
        }

    @torch.inference_mode()
    def generate_stream(self, params):
        """Generate text and speech units for `params`, yielding JSON chunks.

        `params` must contain "prompt"; "audio", "temperature", "top_p",
        "max_new_tokens" (capped at 1024) and "stop" are optional. Each
        yielded item is a UTF-8 JSON object followed by a NUL byte with the
        keys "text" (prompt plus everything generated so far), "unit"
        (space-separated entries of the unit streamer's token cache) and
        "error_code".
        """
        tokenizer, model = self.tokenizer, self.model

        prompt = params["prompt"]
        ori_prompt = prompt
        audio = params.get("audio", None)
        if audio is not None and len(audio) > 0:
            speech = load_speech(audio, self.input_type, self.mel_size, self.model.config.speech_normalize)
            speech_length = torch.LongTensor([speech.shape[0]]).unsqueeze(0).to(self.device)
            speech_tensor = speech.unsqueeze(0).to(self.device, dtype=torch.float16)
            speech_args = {"speech": speech_tensor, "speech_lengths": speech_length}
        else:
            speech_args = {}

        temperature = float(params.get("temperature", 1.0))
        top_p = float(params.get("top_p", 1.0))
        max_new_tokens = min(int(params.get("max_new_tokens", 256)), 1024)
        stop_str = params.get("stop", None)
        do_sample = True if temperature > 0.001 else False

        input_ids = tokenizer_speech_token(prompt, tokenizer, return_tensors='pt').unsqueeze(0).to(self.device)
        streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True, timeout=15)
        streamer_unit = TextIteratorStreamer(self.unit_tokenizer, skip_prompt=False, skip_special_tokens=True, timeout=15)

        if max_new_tokens < 1:
            # "unit" is included so clients that index it do not fail.
            yield json.dumps({"text": ori_prompt + "Exceeds max token length. Please start a new conversation, thanks.", "unit": "", "error_code": 0}).encode() + b"\0"
            return

        # Generation runs in a background thread; this generator consumes the
        # text streamer and forwards each increment to the client.
        thread = Thread(target=model.generate, kwargs=dict(
            inputs=input_ids,
            do_sample=do_sample,
            temperature=temperature,
            top_p=top_p,
            max_new_tokens=max_new_tokens,
            streamer=streamer,
            streamer_unit=streamer_unit,
            streaming_unit_gen=True,
            use_cache=True,
            **speech_args
        ))
        thread.start()

        generated_text = ori_prompt
        for new_text in streamer:
            generated_text += new_text
            generated_unit = " ".join(map(str, streamer_unit.token_cache))
            # "stop" is optional: without the truthiness guard a missing
            # value would raise TypeError in `endswith(None)`, and an empty
            # string would slice the whole text away via `[:-0]`.
            if stop_str and generated_text.endswith(stop_str):
                generated_text = generated_text[:-len(stop_str)]
            yield json.dumps({"text": generated_text, "unit": generated_unit, "error_code": 0}).encode() + b"\0"

    @staticmethod
    def _error_chunk():
        """Return the NUL-terminated JSON chunk that signals a server error."""
        ret = {
            "text": server_error_msg,
            "error_code": 1,
        }
        return json.dumps(ret).encode() + b"\0"

    def generate_stream_gate(self, params):
        """Wrap `generate_stream`, turning any exception into an error chunk."""
        try:
            yield from self.generate_stream(params)
        except ValueError as e:
            print("Caught ValueError:", e)
            yield self._error_chunk()
        except torch.cuda.CudaError as e:
            print("Caught torch.cuda.CudaError:", e)
            yield self._error_chunk()
        except Exception as e:
            # Top-level boundary of the streaming response: report, don't crash.
            print("Caught Unknown Error", e)
            yield self._error_chunk()


app = FastAPI()


def release_model_semaphore(fn=None):
    """Release one semaphore slot, then call `fn` if one is given."""
    model_semaphore.release()
    if fn is not None:
        fn()


@app.post("/worker_generate_stream")
async def generate_stream(request: Request):
    """Stream a generation; concurrency is bounded by `model_semaphore`.

    The semaphore is created on first use and released by a background task
    that runs after the streaming response has finished.
    """
    global model_semaphore, global_counter
    global_counter += 1
    params = await request.json()

    if model_semaphore is None:
        model_semaphore = asyncio.Semaphore(args.limit_model_concurrency)
    await model_semaphore.acquire()
    worker.send_heart_beat()
    generator = worker.generate_stream_gate(params)
    background_tasks = BackgroundTasks()
    background_tasks.add_task(partial(release_model_semaphore, fn=worker.send_heart_beat))
    return StreamingResponse(generator, background=background_tasks)


@app.post("/worker_get_status")
async def get_status(request: Request):
    """Return the worker's status dict."""
    return worker.get_status()


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--host", type=str, default="localhost")
    parser.add_argument("--port", type=int, default=21002)
    parser.add_argument("--worker-address", type=str,
                        default="http://localhost:21002")
    parser.add_argument("--controller-address", type=str,
                        default="http://localhost:21001")
    parser.add_argument("--model-path", type=str, default="facebook/opt-350m")
    parser.add_argument("--model-base", type=str, default=None)
    parser.add_argument("--model-name", type=str)
    parser.add_argument("--device", type=str, default="cuda")
    parser.add_argument("--limit-model-concurrency", type=int, default=5)
    parser.add_argument("--stream-interval", type=int, default=1)
    parser.add_argument("--no-register", action="store_true")
    parser.add_argument("--load-8bit", action="store_true")
    parser.add_argument("--load-4bit", action="store_true")
    parser.add_argument("--use-flash-attn", action="store_true")
    parser.add_argument("--input-type", type=str, default="mel")
    parser.add_argument("--mel-size", type=int, default=128)
    parser.add_argument("--s2s", action="store_true", default=False)
    parser.add_argument("--is-lora", action="store_true", default=False)
    args = parser.parse_args()
    logger.info(f"args: {args}")

    worker = ModelWorker(args.controller_address,
                         args.worker_address,
                         worker_id,
                         args.no_register,
                         args.model_path,
                         args.model_base,
                         args.model_name,
                         args.load_8bit,
                         args.load_4bit,
                         args.device,
                         args.input_type,
                         args.mel_size,
                         args.s2s,
                         args.is_lora,
                         use_flash_attn=args.use_flash_attn)
    uvicorn.run(app, host=args.host, port=args.port, log_level="info")

# --------------------------------------------------------------------------------
# /ola/utils.py:
# --------------------------------------------------------------------------------
# Adopted from https://github.com/haotian-liu/LLaVA. Below is the original copyright:
# Copyright 2023 Haotian Liu
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
import sys
import torch
import logging
import logging.handlers
import transformers

from ola.constants import LOGDIR

server_error_msg = "**NETWORK ERROR DUE TO HIGH TRAFFIC. PLEASE REGENERATE OR REFRESH THIS PAGE.**"
moderation_msg = "YOUR INPUT VIOLATES OUR CONTENT MODERATION GUIDELINES. PLEASE TRY AGAIN."
# /ola/utils.py (continued)

# Shared file handler, created on the first `build_logger` call.
handler = None


def build_logger(logger_name, logger_filename):
    """Configure logging for a server process and return a named logger.

    Side effects: installs a formatter on the first root handler, replaces
    `sys.stdout` / `sys.stderr` with `StreamToLogger` objects, and (on the
    first call only) creates a daily-rotating file handler in LOGDIR that is
    attached to every logger registered at that moment.
    """
    global handler

    formatter = logging.Formatter(
        fmt="%(asctime)s | %(levelname)s | %(name)s | %(message)s",
        datefmt="%Y-%m-%d %H:%M:%S",
    )

    # Set the format of root handlers
    if not logging.getLogger().handlers:
        logging.basicConfig(level=logging.INFO)
    logging.getLogger().handlers[0].setFormatter(formatter)

    # Redirect stdout and stderr to loggers
    stdout_logger = logging.getLogger("stdout")
    stdout_logger.setLevel(logging.INFO)
    sl = StreamToLogger(stdout_logger, logging.INFO)
    sys.stdout = sl

    stderr_logger = logging.getLogger("stderr")
    stderr_logger.setLevel(logging.ERROR)
    sl = StreamToLogger(stderr_logger, logging.ERROR)
    sys.stderr = sl

    # Get logger
    logger = logging.getLogger(logger_name)
    logger.setLevel(logging.INFO)

    # Add a file handler for all loggers
    if handler is None:
        os.makedirs(LOGDIR, exist_ok=True)
        filename = os.path.join(LOGDIR, logger_filename)
        handler = logging.handlers.TimedRotatingFileHandler(
            filename, when='D', utc=True, encoding='UTF-8')
        handler.setFormatter(formatter)

        for name, item in logging.root.manager.loggerDict.items():
            if isinstance(item, logging.Logger):
                item.addHandler(handler)

    return logger


class StreamToLogger(object):
    """
    Fake file-like stream object that redirects writes to a logger instance.

    Complete lines are logged immediately; a trailing partial line is kept in
    `linebuf` until the next write completes it or `flush` is called.
    Attributes not defined here are looked up on the `sys.stdout` that was
    current when the object was created.
    """
    def __init__(self, logger, log_level=logging.INFO):
        self.terminal = sys.stdout
        self.logger = logger
        self.log_level = log_level
        self.linebuf = ''

    def __getattr__(self, attr):
        return getattr(self.terminal, attr)

    def write(self, buf):
        temp_linebuf = self.linebuf + buf
        self.linebuf = ''
        for line in temp_linebuf.splitlines(True):
            # From the io.TextIOWrapper docs:
            #   On output, if newline is None, any '\n' characters written
            #   are translated to the system default line separator.
            # By default sys.stdout.write() expects '\n' newlines and then
            # translates them so this is still cross platform.
            if line[-1] == '\n':
                self.logger.log(self.log_level, line.rstrip())
            else:
                self.linebuf += line

    def flush(self):
        if self.linebuf != '':
            self.logger.log(self.log_level, self.linebuf.rstrip())
        self.linebuf = ''


def maybe_zero_3(param, ignore_status=False, name=None):
    """Return a detached CPU copy of `param`, gathering it first if needed.

    Parameters that carry a `ds_id` attribute (presumably DeepSpeed ZeRO-3
    partitioned parameters -- confirm) are copied inside
    `zero.GatheredParameters`; everything else is copied directly. When such
    a parameter's status is NOT_AVAILABLE and `ignore_status` is false, a
    warning naming the parameter is logged.
    """
    from deepspeed import zero
    from deepspeed.runtime.zero.partition_parameters import ZeroParamStatus
    if hasattr(param, "ds_id"):
        if param.ds_status == ZeroParamStatus.NOT_AVAILABLE:
            if not ignore_status:
                # The message used to say "!=" although this branch is taken
                # exactly when the status == NOT_AVAILABLE.
                logging.warning(f"{name}: param.ds_status == ZeroParamStatus.NOT_AVAILABLE: {param.ds_status}")
        with zero.GatheredParameters([param]):
            param = param.data.detach().cpu().clone()
    else:
        param = param.detach().cpu().clone()
    return param


# Borrowed from peft.utils.get_peft_model_state_dict
def get_peft_state_maybe_zero_3(named_params, bias):
    """Select LoRA (and optionally bias) parameters and copy them to CPU.

    Args:
        named_params: iterable of `(name, parameter)` pairs.
        bias: "none" keeps only names containing "lora_"; "all" also keeps
            every name containing "bias"; "lora_only" keeps LoRA parameters
            plus the bias whose name is `<prefix before "lora_">bias` for
            some kept LoRA parameter.

    Returns:
        dict mapping name to the `maybe_zero_3` copy of the parameter.

    Raises:
        NotImplementedError: for any other `bias` value.
    """
    if bias == "none":
        to_return = {k: t for k, t in named_params if "lora_" in k}
    elif bias == "all":
        to_return = {k: t for k, t in named_params if "lora_" in k or "bias" in k}
    elif bias == "lora_only":
        to_return = {}
        maybe_lora_bias = {}
        lora_bias_names = set()
        for k, t in named_params:
            if "lora_" in k:
                to_return[k] = t
                bias_name = k.split("lora_")[0] + "bias"
                lora_bias_names.add(bias_name)
            elif "bias" in k:
                maybe_lora_bias[k] = t
        # Iterate (name, tensor) pairs and test each bias's own name; the old
        # loop unpacked the dict's keys and re-used the last `bias_name`.
        for k, t in maybe_lora_bias.items():
            if k in lora_bias_names:
                to_return[k] = t
    else:
        raise NotImplementedError
    to_return = {k: maybe_zero_3(v, ignore_status=True) for k, v in to_return.items()}
    return to_return


def get_peft_state_non_lora_maybe_zero_3(named_params, require_grad_only=True):
    """Copy to CPU every parameter whose name does not contain "lora_".

    With `require_grad_only` (default) only parameters with `requires_grad`
    set are kept.
    """
    to_return = {k: t for k, t in named_params if "lora_" not in k}
    if require_grad_only:
        to_return = {k: t for k, t in to_return.items() if t.requires_grad}
    to_return = {k: maybe_zero_3(v, ignore_status=True).cpu() for k, v in to_return.items()}
    return to_return


def get_speech_projector_state_maybe_zero_3(named_params, keys_to_match):
    """Copy to CPU the parameters whose name contains any of `keys_to_match`."""
    to_return = {k: t for k, t in named_params if any(key_match in k for key_match in keys_to_match)}
    to_return = {k: maybe_zero_3(v, ignore_status=True).cpu() for k, v in to_return.items()}
    return to_return


def lengths_to_padding_mask(lens):
    """Return a `(len(lens), max(lens))` bool mask that is True at padding.

    Position `j` of row `i` is True when `j >= lens[i]`.
    """
    bsz, max_lens = lens.size(0), torch.max(lens).item()
    positions = torch.arange(max_lens, device=lens.device).view(1, max_lens)
    mask = positions.expand(bsz, -1) >= lens.view(bsz, 1).expand(-1, max_lens)
    return mask


def lengths_to_mask(lens):
    """Return the inverse of `lengths_to_padding_mask`: True at valid positions."""
    return ~lengths_to_padding_mask(lens)


def disable_torch_init():
    """
    Disable the redundant torch default initialization to accelerate model creation.
    """
    import torch
    setattr(torch.nn.Linear, "reset_parameters", lambda self: None)
    setattr(torch.nn.LayerNorm, "reset_parameters", lambda self: None)


def get_model_name_from_path(model_path):
    """Return the last path component, prefixed by its parent for checkpoints.

    `a/b/checkpoint-10` gives `b_checkpoint-10`; anything else gives the final
    component. Leading and trailing slashes are ignored.
    """
    model_path = model_path.strip("/")
    model_paths = model_path.split("/")
    if model_paths[-1].startswith('checkpoint-'):
        return model_paths[-2] + "_" + model_paths[-1]
    else:
        return model_paths[-1]


def violates_moderation(text):
    """
    Check whether the text violates OpenAI moderation API.

    Returns False when the request fails or the response lacks the expected
    fields. Raises KeyError if the OPENAI_API_KEY environment variable is
    not set.
    """
    # Imported here because this module's import block does not provide them.
    import json
    import requests

    url = "https://api.openai.com/v1/moderations"
    headers = {"Content-Type": "application/json",
               "Authorization": "Bearer " + os.environ["OPENAI_API_KEY"]}
    text = text.replace("\n", "")
    # json.dumps escapes quotes and backslashes; the previous hand-built
    # string produced invalid JSON for such input.
    data = json.dumps({"input": text}).encode("utf-8")
    try:
        ret = requests.post(url, headers=headers, data=data, timeout=5)
        flagged = ret.json()["results"][0]["flagged"]
    except requests.exceptions.RequestException:
        flagged = False
    except (KeyError, IndexError):
        flagged = False

    return flagged


def pretty_print_semaphore(semaphore):
    """Return a short description of `semaphore`, or "None" if it is None."""
    if semaphore is None:
        return "None"
    return f"Semaphore(value={semaphore._value}, locked={semaphore.locked()})"

# The digest line that ended utils.py also carried the start of the next file.
# It is kept verbatim below and continues on the following line of the digest.
# -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [build-system] 2 | requires = ["setuptools>=61.0"] 3 | build-backend = "setuptools.build_meta" 4 | 5 | [project] 6 | name = "ola" 7 | version = "1.0.0" 8 | description = "Omni-Modal Language Model" 9 | readme = "README.md" 10 | requires-python = ">=3.10" 11 | classifiers = [ 12 | "Programming Language :: Python :: 3", 13 | "License :: OSI Approved :: Apache Software License", 14 | ] 15 | dependencies = [ 16 | "torch==2.1.2", "torchvision==0.16.2",
"torchaudio==2.1.2", 17 | "transformers==4.43.4", "tokenizers==0.19.1", "sentencepiece==0.1.99", "shortuuid", 18 | "accelerate==0.33.0", "peft==0.11.1", "bitsandbytes==0.43.1", 19 | "pydantic", "markdown2[all]", "numpy", "scikit-learn==1.2.2", 20 | "gradio==4.43.0", "gradio_client==1.3.0", 21 | "requests", "httpx==0.27.2", "uvicorn", "fastapi", "soundfile", 22 | "einops==0.6.1", "einops-exts==0.0.4", "timm==0.9.16", 23 | "openai-whisper", "setuptools==59.5.0", "omegaconf==2.0.6", "loguru", "av", "librosa", 24 | ] 25 | 26 | [project.optional-dependencies] 27 | train = ["deepspeed==0.12.6", "ninja", "wandb", "tensorboardX"] 28 | build = ["build", "twine"] 29 | 30 | [tool.setuptools.packages.find] 31 | exclude = ["data", "checkpoints", "logs", "models", "fairseq", "flash-attention"] 32 | 33 | [tool.wheel] 34 | exclude = ["data", "checkpoints", "logs", "models", "fairseq", "flash-attention"] 35 | -------------------------------------------------------------------------------- /scripts/.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Ola-Omni/Ola/666df3b345a252b218b25b84b2ec4c0e0b631b83/scripts/.DS_Store -------------------------------------------------------------------------------- /scripts/finetune_ola.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | export LOWRES_RESIZE=384x32 4 | export VIDEO_RESIZE="0x32" 5 | export HIGHRES_BASE="0x32" 6 | export MAXRES=1536 7 | export MINRES=0 8 | export VIDEO_MAXRES=448 9 | export VIDEO_MINRES=288 10 | export PAD2STRIDE=1 11 | export FORCE_NO_DOWNSAMPLE=1 12 | export LOAD_VISION_EARLY=1 13 | 14 | export PYTHONPATH=/path/to/Ola:$PYTHONPATH 15 | 16 | EXP_NAME="ola_7b" 17 | DATA='/path/to/data.json' 18 | 19 | CHECKPOINT='/path/to/Ola_7b' 20 | 21 | echo $MASTER_ADDR; echo $nnode; echo $nrank 22 | 23 | torchrun --nproc_per_node 8 --nnodes=$nnode --node_rank=$nrank --master_addr=$MASTER_ADDR 
--master_port=12324 \ 24 | ola/train/train.py \ 25 | --deepspeed ./scripts/zero2.json \ 26 | --run_name $EXP_NAME \ 27 | --model_name_or_path $CHECKPOINT \ 28 | --pretrain_speech_projector $CHECKPOINT/speech_projector.bin \ 29 | --vision_tower $VISION_TOWER \ 30 | --mm_projector_type ola_mlp \ 31 | --speech_projector_type "linear" \ 32 | --mm_vision_select_layer -1 \ 33 | --mm_use_im_patch_token False \ 34 | --tune_speech_adapter False \ 35 | --version qwen_1_5 \ 36 | --data_path $DATA \ 37 | --bf16 True \ 38 | --output_dir ./checkpoints/$EXP_NAME \ 39 | --sample_independently True \ 40 | --fix_speech_encoder True \ 41 | --freeze_mm_vision_tower True \ 42 | --speech_encoder "./pretrained/large-v3.pt" \ 43 | --music_encoder "./pretrained/BEATs_iter3_plus_AS2M_finetuned_on_AS2M_cpt2.pt" \ 44 | --speech_encoder_type "dual" \ 45 | --speech_encoder_hidden_size 2048 \ 46 | --speech_encoder_ds_rate 10 \ 47 | --num_train_epochs 1 \ 48 | --per_device_train_batch_size 2 \ 49 | --per_device_eval_batch_size 1 \ 50 | --gradient_accumulation_steps 1 \ 51 | --evaluation_strategy "no" \ 52 | --save_strategy "steps" \ 53 | --save_steps 1000 \ 54 | --save_total_limit 1 \ 55 | --learning_rate 1e-5 \ 56 | --weight_decay 0.0 \ 57 | --warmup_ratio 0.05 \ 58 | --min_lr_ratio 0.01 \ 59 | --lr_scheduler_type "cosine" \ 60 | --logging_steps 1 \ 61 | --tf32 True \ 62 | --model_max_length 16384 \ 63 | --gradient_checkpointing True \ 64 | --dataloader_num_workers 8 \ 65 | --frames_upbound 48 \ 66 | --lazy_preprocess True \ 67 | --report_to none -------------------------------------------------------------------------------- /scripts/finetune_ola_image.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | export LOWRES_RESIZE=384x32 4 | export VIDEO_RESIZE="0x32" 5 | export HIGHRES_BASE="0x32" 6 | export MAXRES=1536 7 | export MINRES=0 8 | export VIDEO_MAXRES=448 9 | export VIDEO_MINRES=288 10 | export PAD2STRIDE=1 11 | export 
FORCE_NO_DOWNSAMPLE=1 12 | export LOAD_VISION_EARLY=1 13 | 14 | export PYTHONPATH=/path/to/Ola:$PYTHONPATH 15 | 16 | EXP_NAME="ola_7b_stage2" 17 | DATA='/path/to/data.json' 18 | 19 | CHECKPOINT='/path/to/Ola_Image' 20 | 21 | echo $MASTER_ADDR; echo $nnode; echo $nrank 22 | 23 | torchrun --nproc_per_node 8 --nnodes=$nnode --node_rank=$nrank --master_addr=$MASTER_ADDR --master_port=12324 \ 24 | ola/train/train.py \ 25 | --deepspeed ./scripts/zero2.json \ 26 | --run_name $EXP_NAME \ 27 | --model_name_or_path $CHECKPOINT \ 28 | --pretrain_speech_projector $CHECKPOINT/speech_projector.bin \ 29 | --vision_tower $VISION_TOWER \ 30 | --mm_projector_type ola_mlp \ 31 | --speech_projector_type "linear" \ 32 | --mm_vision_select_layer -1 \ 33 | --mm_use_im_patch_token False \ 34 | --tune_speech_adapter False \ 35 | --version qwen_1_5 \ 36 | --data_path $DATA \ 37 | --bf16 True \ 38 | --output_dir ./checkpoints/$EXP_NAME \ 39 | --sample_independently True \ 40 | --fix_speech_encoder True \ 41 | --freeze_mm_vision_tower True \ 42 | --speech_encoder "./pretrained/large-v3.pt" \ 43 | --music_encoder "./pretrained/BEATs_iter3_plus_AS2M_finetuned_on_AS2M_cpt2.pt" \ 44 | --speech_encoder_type "dual" \ 45 | --speech_encoder_hidden_size 2048 \ 46 | --speech_encoder_ds_rate 10 \ 47 | --num_train_epochs 1 \ 48 | --per_device_train_batch_size 2 \ 49 | --per_device_eval_batch_size 1 \ 50 | --gradient_accumulation_steps 1 \ 51 | --evaluation_strategy "no" \ 52 | --save_strategy "steps" \ 53 | --save_steps 1000 \ 54 | --save_total_limit 1 \ 55 | --learning_rate 1e-5 \ 56 | --weight_decay 0.0 \ 57 | --warmup_ratio 0.05 \ 58 | --min_lr_ratio 0.01 \ 59 | --lr_scheduler_type "cosine" \ 60 | --logging_steps 1 \ 61 | --tf32 True \ 62 | --model_max_length 16384 \ 63 | --gradient_checkpointing True \ 64 | --dataloader_num_workers 8 \ 65 | --frames_upbound 64 \ 66 | --lazy_preprocess True \ 67 | --report_to none -------------------------------------------------------------------------------- 
#!/bin/bash
# Launches ola/train/train.py under torchrun (8 processes per node, DeepSpeed
# config ./scripts/zero2.json) for the run named "ola_7b_stage3".
#
# Inputs taken from the environment:
#   MASTER_ADDR   forwarded to --master_addr   (default: 127.0.0.1)
#   nnode         forwarded to --nnodes        (default: 1)
#   nrank         forwarded to --node_rank     (default: 0)
#   VISION_TOWER  forwarded to --vision_tower  (required: this script never sets it)
#
# Replace the /path/to/... placeholders below before running.

# Exported for the training process. Their interpretation lives in the Python
# code, not in this script.
export LOWRES_RESIZE=384x32
export VIDEO_RESIZE="0x32"
export HIGHRES_BASE="0x32"
export MAXRES=1536
export MINRES=0
export VIDEO_MAXRES=448
export VIDEO_MINRES=288
export PAD2STRIDE=1
export FORCE_NO_DOWNSAMPLE=1
export LOAD_VISION_EARLY=1

export PYTHONPATH="/path/to/Ola:$PYTHONPATH"

EXP_NAME="ola_7b_stage3"
DATA='/path/to/data.json'

CHECKPOINT='/path/to/Ola_Video'

# Fall back to a single-node launch when the scheduler did not provide the
# rendezvous variables; otherwise torchrun receives empty option values.
MASTER_ADDR="${MASTER_ADDR:-127.0.0.1}"
nnode="${nnode:-1}"
nrank="${nrank:-0}"

# An unset VISION_TOWER would leave --vision_tower without an argument, so stop
# here with a clear message instead of failing inside argument parsing.
: "${VISION_TOWER:?VISION_TOWER must be set; it is passed to --vision_tower}"

echo "$MASTER_ADDR"; echo "$nnode"; echo "$nrank"

# All expansions are quoted so that paths containing spaces stay one argument.
torchrun --nproc_per_node 8 --nnodes="$nnode" --node_rank="$nrank" --master_addr="$MASTER_ADDR" --master_port=12324 \
    ola/train/train.py \
    --deepspeed ./scripts/zero2.json \
    --run_name "$EXP_NAME" \
    --model_name_or_path "$CHECKPOINT" \
    --pretrain_speech_projector "$CHECKPOINT/speech_projector.bin" \
    --vision_tower "$VISION_TOWER" \
    --mm_projector_type ola_mlp \
    --speech_projector_type "linear" \
    --mm_vision_select_layer -1 \
    --mm_use_im_patch_token False \
    --tune_speech_adapter False \
    --version qwen_1_5 \
    --data_path "$DATA" \
    --bf16 True \
    --output_dir "./checkpoints/$EXP_NAME" \
    --sample_independently True \
    --fix_speech_encoder True \
    --freeze_mm_vision_tower True \
    --speech_encoder "./pretrained/large-v3.pt" \
    --music_encoder "./pretrained/BEATs_iter3_plus_AS2M_finetuned_on_AS2M_cpt2.pt" \
    --speech_encoder_type "dual" \
    --speech_encoder_hidden_size 2048 \
    --speech_encoder_ds_rate 10 \
    --num_train_epochs 1 \
    --per_device_train_batch_size 2 \
    --per_device_eval_batch_size 1 \
    --gradient_accumulation_steps 1 \
    --evaluation_strategy "no" \
    --save_strategy "steps" \
    --save_steps 1000 \
    --save_total_limit 1 \
    --learning_rate 1e-5 \
    --weight_decay 0.0 \
    --warmup_ratio 0.05 \
    --min_lr_ratio 0.01 \
    --lr_scheduler_type "cosine" \
    --logging_steps 1 \
    --tf32 True \
    --model_max_length 16384 \
    --gradient_checkpointing True \
    --dataloader_num_workers 8 \
    --frames_upbound 48 \
    --lazy_preprocess True \
    --report_to none
"""Extract the audio track of every .mp4/.mkv file under a folder with ffmpeg.

Each video is converted to a single-channel, 16000 Hz ``.wav`` file
(``-ac 1 -ar 16000 -vn``). The directory layout below the video folder is
mirrored under the audio folder. Videos whose target ``.wav`` already exists
are skipped, so the script can be re-run after an interruption.
"""
import argparse
import os
import shutil
import subprocess
import sys

VIDEO_EXTENSIONS = ('.mp4', '.mkv')
DEFAULT_VIDEO_DIR = '/path/to/video_folder'
DEFAULT_AUDIO_DIR = '/path/to/audio_folder'


def find_videos(video_root):
    """Return the paths of all files under *video_root* with a video extension.

    Paths come back in ``os.walk`` order; the extension match is case-sensitive.
    """
    return [
        os.path.join(folder, name)
        for folder, _subdirs, names in os.walk(video_root)
        for name in names
        if name.endswith(VIDEO_EXTENSIONS)
    ]


def target_audio_path(video_path, video_root, audio_root):
    """Map *video_path* to the ``.wav`` path mirrored under *audio_root*.

    Only the final extension and the leading *video_root* are rewritten, so a
    directory or file name that merely contains ``.mp4``/``.mkv`` (or the root
    path) somewhere in the middle is left untouched.
    """
    stem, _extension = os.path.splitext(video_path)
    return os.path.join(audio_root, os.path.relpath(stem + '.wav', video_root))


def build_ffmpeg_command(video_path, audio_path):
    """Return the ffmpeg argument vector for one conversion.

    A list (not a shell string) is used so that spaces and shell
    metacharacters in file names cannot split or alter the command.
    """
    return ['ffmpeg', '-i', video_path, '-ac', '1', '-ar', '16000', '-vn', audio_path]


def main(argv=None):
    """Convert every video under ``--video_dir``; return the number of failures."""
    parser = argparse.ArgumentParser(description='argparse testing')
    # --idx is still accepted so existing invocations keep working; nothing reads it.
    parser.add_argument('--idx', type=int, help='index of the patch', default=0)
    parser.add_argument('--video_dir', default=DEFAULT_VIDEO_DIR,
                        help='folder that is searched recursively for videos')
    parser.add_argument('--audio_dir', default=DEFAULT_AUDIO_DIR,
                        help='folder that receives the mirrored .wav files')
    args = parser.parse_args(argv)

    if shutil.which('ffmpeg') is None:
        sys.exit('ffmpeg was not found on PATH')

    # Imported here so the helpers above stay importable without tqdm installed.
    from tqdm import tqdm

    videos = find_videos(args.video_dir)
    print(len(videos))

    failures = 0
    for video_path in tqdm(videos):
        audio_path = target_audio_path(video_path, args.video_dir, args.audio_dir)
        if os.path.exists(audio_path):
            continue

        os.makedirs(os.path.dirname(audio_path), exist_ok=True)

        result = subprocess.run(build_ffmpeg_command(video_path, audio_path), check=False)
        if result.returncode != 0:
            failures += 1
            print(f'ffmpeg failed ({result.returncode}) for {video_path}', file=sys.stderr)
            # A partial output would be mistaken for a finished file on the next run.
            if os.path.exists(audio_path):
                os.remove(audio_path)
    return failures


if __name__ == '__main__':
    main()
"""Pack every file under a directory tree into large binary "patch" files.

Files are appended back to back to ``patch_000000``, ``patch_000001``, ...
inside the output directory, with a fixed number of files per patch. A JSON
mapping (``patch_mapping.json``) records, for each input file, which patch
holds it, the byte offset where it starts and how many bytes it occupies.
"""
import argparse
import json
import os
import sys

DEFAULT_JSON_OUT_DIR = '/path/to/json'
DEFAULT_IMAGE_IN_DIR = '/path/to/image_in_dir'
DEFAULT_IMAGE_OUT_DIR = '/path/to/image_out_dir'
IMAGES_PER_PATCH = 10000


def list_files(root_dir):
    """Return the sorted paths of all files found recursively under *root_dir*."""
    return sorted(
        os.path.join(folder, name)
        for folder, _subdirs, names in os.walk(root_dir)
        for name in names
    )


def create_patches(image_in_dir, image_out_dir,
                   images_per_patch=IMAGES_PER_PATCH, progress=None):
    """Write the files under *image_in_dir* into patch files and return the mapping.

    Args:
        image_in_dir: directory that is walked recursively for input files.
        image_out_dir: directory that receives the ``patch_NNNNNN`` files; it
            is created when missing.
        images_per_patch: number of input files stored in each patch file.
        progress: optional callable wrapping the list of paths (for example
            ``tqdm``) to report progress.

    Returns:
        A dict keyed by input file basename. Each value has the keys
        ``original_image_path`` and ``image``; the latter is a dict with
        ``patch`` (patch file path), ``start_num`` (byte offset inside the
        patch) and ``size`` (one-element list with the byte count).

    Raises:
        ValueError: if *images_per_patch* is smaller than 1.
    """
    if images_per_patch < 1:
        raise ValueError('images_per_patch must be at least 1')

    # Collect the inputs before creating any output, so an output directory
    # nested inside the input tree is not picked up.
    image_paths = list_files(image_in_dir)
    os.makedirs(image_out_dir, exist_ok=True)

    mapping = {}
    patch_index = 0
    images_in_patch = 0
    patch_path = None
    patch_file = None
    try:
        for image_path in (progress(image_paths) if progress else image_paths):
            # Open a patch only when there is something to write into it, so no
            # empty trailing patch appears when the count divides evenly.
            if patch_file is None:
                # os.path.join supplies the separator that plain string
                # concatenation dropped.
                patch_path = os.path.join(image_out_dir, f'patch_{patch_index:06d}')
                patch_file = open(patch_path, 'wb')

            with open(image_path, 'rb') as source:
                payload = source.read()
            start = patch_file.tell()
            written = patch_file.write(payload)

            key = os.path.basename(image_path)
            if key in mapping:
                # The mapping is keyed by basename, so the later file wins.
                print(f'warning: duplicate file name {key!r}; entry for '
                      f'{mapping[key]["original_image_path"]} is replaced by {image_path}',
                      file=sys.stderr)
            mapping[key] = {
                'original_image_path': image_path,
                'image': {
                    'patch': patch_path,
                    'start_num': start,
                    'size': [written],
                },
            }

            images_in_patch += 1
            if images_in_patch == images_per_patch:
                patch_file.close()
                patch_file = None
                patch_index += 1
                images_in_patch = 0
    finally:
        # Also reached on errors, so the current patch handle is never leaked.
        if patch_file is not None:
            patch_file.close()
    return mapping


def main(argv=None):
    """Build the patches and write ``patch_mapping.json`` into the JSON directory."""
    parser = argparse.ArgumentParser(description=__doc__.splitlines()[0])
    parser.add_argument('--image_in_dir', default=DEFAULT_IMAGE_IN_DIR)
    parser.add_argument('--image_out_dir', default=DEFAULT_IMAGE_OUT_DIR)
    parser.add_argument('--json_out_dir', default=DEFAULT_JSON_OUT_DIR)
    parser.add_argument('--images_per_patch', type=int, default=IMAGES_PER_PATCH)
    args = parser.parse_args(argv)

    # Imported here so create_patches stays importable without tqdm installed.
    from tqdm import tqdm

    mapping = create_patches(args.image_in_dir, args.image_out_dir,
                             args.images_per_patch, progress=tqdm)

    os.makedirs(args.json_out_dir, exist_ok=True)
    with open(os.path.join(args.json_out_dir, 'patch_mapping.json'), 'w') as f:
        json.dump(mapping, f, indent=4)


if __name__ == '__main__':
    main()