├── LICENSE ├── README.md ├── assets ├── framework.jpeg ├── intro.jpeg └── setok.png ├── cog.yaml ├── pyproject.toml ├── scripts ├── extract_mm_projector.py ├── finetune.sh ├── merge_lora_weights.py ├── pretrain_mm_proj.sh ├── train_setok.sh ├── zero2.json ├── zero3.json └── zero3_offload.json └── src ├── __init__.py ├── constants.py ├── conversation.py ├── data_preprocess.py ├── dataset ├── __init__.py ├── base_dataset.py ├── dataset_utils.py ├── editDataset.py ├── instructDataset.py ├── pairDataset.py └── vqa.py ├── mm_utils.py ├── model ├── __init__.py ├── apply_delta.py ├── builder.py ├── consolidate.py ├── diffusion │ ├── __init__.py │ ├── diffusion_utils.py │ ├── gaussian_diffusion.py │ └── respace.py ├── language_model │ └── setokim_llama.py ├── loss │ ├── __init__.py │ ├── diffloss.py │ ├── discriminator.py │ ├── mse.py │ ├── multilabel_constrastive.py │ ├── perceptual.py │ └── segmentation.py ├── make_delta.py ├── multimodal_encoder │ ├── builder.py │ ├── clip_encoder.py │ ├── eva_encoder.py │ ├── openclip_encoder.py │ └── openclip_processor.py ├── multimodal_generator │ └── builder.py ├── multimodal_projector │ └── builder.py ├── setok │ ├── __init__.py │ ├── clip_encoder.py │ ├── detokenizer.py │ ├── model.py │ ├── module.py │ ├── tokenizer.py │ └── utils.py ├── setokim_arch.py └── utils.py ├── train ├── llama_flash_attn_monkey_patch.py ├── llama_xformers_attn_monkey_patch.py ├── setok_trainer.py ├── setokim_trainer.py ├── train_mem.py ├── train_setok.py ├── train_setokim.py ├── train_xformers.py └── training_utils.py └── utils.py /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 
39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. 
You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. 
You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 202 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | 2 |


# Towards Semantic Equivalence of Tokenization in Multimodal LLM


Shengqiong Wu · Hao Fei · Xiangtai Li · Jiayi Ji ·
Hanwang Zhang · Tat-seng Chua · Shuicheng Yan


National University of Singapore · Skywork AI, Singapore ·
Nanyang Technological University


This work was done during an internship at Skywork AI; Hao Fei is the corresponding author.


arXiv PDF · Project Page

![avatar](./assets/intro.jpeg)


### Abstract

Multimodal Large Language Models (MLLMs) have demonstrated exceptional capabilities in processing vision-language tasks. A crux of MLLMs lies in vision tokenization, which involves efficiently transforming input visual signals into feature representations that are most beneficial for LLMs. However, existing vision tokenizers, though essential for the semantic alignment between vision and language, remain problematic: they aggressively fragment the visual input and corrupt its semantic integrity. To address this, we propose a novel dynamic `Semantic-Equivalent Vision Tokenizer` (**SeTok**), which groups visual features into semantic units via a dynamic clustering algorithm, flexibly determining the number of tokens based on image complexity. The resulting vision tokens effectively preserve semantic integrity and capture both low-frequency and high-frequency visual features.

![avatar](./assets/setok.png)

The proposed MLLM (**Setokim**), equipped with SeTok, demonstrates significantly superior performance across various tasks, as evidenced by our experimental results.
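For intuition, the grouping step can be pictured as clustering patch features into coherent units, so the number of vision tokens adapts to image content. The snippet below is only an illustrative stand-in written for this README — the actual tokenizer lives in [`src/model/setok/tokenizer.py`](./src/model/setok/tokenizer.py) — and the greedy threshold rule and mean pooling here are assumptions, not SeTok's exact algorithm.

```python
import torch
import torch.nn.functional as F

def group_patches(patch_feats: torch.Tensor, sim_threshold: float = 0.85) -> torch.Tensor:
    """Toy semantic grouping: greedily merge patches into clusters by cosine similarity.

    patch_feats: (num_patches, dim) features from a vision encoder.
    Returns (num_clusters, dim) mean-pooled cluster tokens; num_clusters
    varies with image complexity instead of being fixed in advance.
    """
    feats = F.normalize(patch_feats, dim=-1)
    centers, members = [], []                      # running cluster centers / member indices
    for i, f in enumerate(feats):
        if centers:
            sims = torch.stack(centers) @ f        # cosine similarity to every center
            best = int(sims.argmax())
            if sims[best] >= sim_threshold:        # similar enough: join that semantic unit
                members[best].append(i)
                centers[best] = F.normalize(feats[members[best]].mean(0), dim=-1)
                continue
        centers.append(f)                          # otherwise open a new semantic unit
        members.append([i])
    return torch.stack([patch_feats[idx].mean(0) for idx in members])
```

Under this picture, a simple image collapses into a handful of tokens while a cluttered scene keeps many more, which is the adaptive behavior the abstract describes.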
## 📂 Getting Started

First, prepare the code and set up the environment:
```
# download the code
git clone https://github.com/ChocoWu/SeTok.git
cd SeTok

# install
conda create -n setok python=3.10 -y
conda activate setok
pip install --upgrade pip
pip install -e .
```

### 📚 Preparing Data

To begin, please prepare the datasets. All datasets should be placed under the [`data/`](./data) directory, with the following structure:
```
data
- ImageNet-1K
  - data.json
  - images
    - xxx.jpg
- OpenImages
- ALLaVA
- GQA
  - data.json
  - images
    - xxx.jpg
- OK-VQA
- ...
- InstructPix2Pix
- Magicbrush
```
For details on how each dataset is processed, please refer to the following scripts:

- [`TextImagePairDataset`](./src/dataset/pairDataset.py)
- [`EditingDataset`](./src/dataset/editDataset.py)
- [`InstructionTuningDataset`](./src/dataset/instructDataset.py)


### 🚀 Training
Our training recipe involves three stages.

- **Stage-1: SeTok tokenizer training**. We use ImageNet-1K for reconstruction learning and OpenImages for both reconstruction and alignment learning.
```
# In [train_mem.py], activate SeTok training:

from train_setok import train
train(attn_implementation="flash_attention_2")

# Set the hyper-parameters in [train_setok.sh].
bash train_setok.sh
```
Make sure the dataset paths are correctly set in your config file or environment variables.


- **Stage-2: Multimodal Pretraining**. In this stage, we focus on enhancing the alignment between text and image. We employ massive multimodal data, including ImageNet-1K and a 28M text-image pair dataset, to train our model for conditional image generation and image captioning.
```
# In [train_mem.py], activate Setokim training:

from train_setokim import train
train(attn_implementation="flash_attention_2")

# Set the hyper-parameters in [pretrain_mm_proj.sh].
bash pretrain_mm_proj.sh
```

- **Stage-3: Instruction Tuning**. Building upon the pretrained weights, we further perform multimodal instruction tuning on public datasets covering multimodal instruction following, fine-grained visual QA, and more. Since `finetune.sh` trains with LoRA enabled, see the weight-merging sketch at the end of this README.
```
# Set the hyper-parameters in [finetune.sh].
bash finetune.sh
```


## ✨ Citation

If you use **SeTok** in your project, please cite:

```bibtex
@article{wu2024towards,
  title={Towards Semantic Equivalence of Tokenization in Multimodal LLM},
  author={Wu, Shengqiong and Fei, Hao and Li, Xiangtai and Ji, Jiayi and Zhang, Hanwang and Chua, Tat-Seng and Yan, Shuicheng},
  publisher={ICLR},
  year={2025}
}
```

## Acknowledgments

This work is built heavily upon [LLaVA](https://github.com/haotian-liu/LLaVA), [GroupViT](https://github.com/NVlabs/GroupViT), [MAR](https://github.com/LTH14/mar), and [BLIP-2](https://github.com/salesforce/LAVIS).
Thanks to all the authors for their great work.
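### 🔧 Merging LoRA Weights (after Stage-3)

Since [`finetune.sh`](./scripts/finetune.sh) trains with LoRA enabled (`--lora_enable True`), the adapter weights need to be merged into the base LLM before the model can be used standalone; the repository provides [`scripts/merge_lora_weights.py`](./scripts/merge_lora_weights.py) for this. Below is a minimal sketch of the same call sequence — the checkpoint paths are placeholders, so substitute your own:

```python
# Mirrors scripts/merge_lora_weights.py; the paths below are hypothetical examples.
from src.mm_utils import get_model_name_from_path
from src.model.builder import load_pretrained_model

model_path = "./checkpoints/setokim-7b-lora"     # LoRA checkpoint dir written by finetune.sh (placeholder)
model_base = "./pretrained_ckpt/vicuna-7b-v1.5"  # base LLM used during finetuning
save_path = "./checkpoints/setokim-7b-merged"    # where to write the merged model

model_name = get_model_name_from_path(model_path)
tokenizer, model, image_processor, context_len = load_pretrained_model(
    model_path, model_base, model_name, device_map='cpu')

model.save_pretrained(save_path)
tokenizer.save_pretrained(save_path)
```

Equivalently, run `python scripts/merge_lora_weights.py --model-path <lora_ckpt> --model-base <base_llm> --save-model-path <merged_dir>` from the repository root.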
-------------------------------------------------------------------------------- /assets/framework.jpeg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ChocoWu/SeTok/55cfc9fa5bdf28955f05bb627df3e300a5693afb/assets/framework.jpeg -------------------------------------------------------------------------------- /assets/intro.jpeg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ChocoWu/SeTok/55cfc9fa5bdf28955f05bb627df3e300a5693afb/assets/intro.jpeg -------------------------------------------------------------------------------- /assets/setok.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ChocoWu/SeTok/55cfc9fa5bdf28955f05bb627df3e300a5693afb/assets/setok.png -------------------------------------------------------------------------------- /cog.yaml: -------------------------------------------------------------------------------- 1 | # Configuration for Cog ⚙️ 2 | # Reference: https://github.com/replicate/cog/blob/main/docs/yaml.md 3 | 4 | build: 5 | gpu: true 6 | 7 | python_version: "3.8" 8 | 9 | python_packages: 10 | - "torch==2.2.1" 11 | - "accelerate==0.27.2" 12 | - "bitsandbytes==0.41.0" 13 | - "deepspeed==0.13.6" 14 | - "einops-exts==0.0.4" 15 | - "einops==0.6.1" 16 | - "gradio==3.35.2" 17 | - "gradio_client==0.2.9" 18 | - "httpx==0.24.0" 19 | - "markdown2==2.4.10" 20 | - "numpy==1.24.4" 21 | - "peft==0.13.2" 22 | - "scikit-learn==1.3.2" 23 | - "sentencepiece==0.1.99" 24 | - "shortuuid==1.0.11" 25 | - "timm==0.9.16" 26 | - "tokenizers==0.20.3" 27 | - "torchvision==0.17.1" 28 | - "transformers==4.46.3" 29 | - "diffusers==0.27.2" 30 | - "scipy==1.10.1" 31 | - "wandb==0.15.12" 32 | - "wavedrom==2.0.3.post3" 33 | - "Pygments==2.16.1" 34 | run: 35 | - curl -o /usr/local/bin/pget -L "https://github.com/replicate/pget/releases/download/v0.0.3/pget" && chmod +x /usr/local/bin/pget 36 | 37 | # predict.py defines how predictions are run on your model 38 | predict: "predict.py:Predictor" 39 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [build-system] 2 | requires = ["setuptools>=61.0"] 3 | build-backend = "setuptools.build_meta" 4 | 5 | [project] 6 | name = "setok" 7 | version = "0.0.0" 8 | description = "Towards Semantic Equivalence of Tokenization in Multimodal LLM" 9 | readme = "README.md" 10 | requires-python = ">=3.8" 11 | classifiers = [ 12 | "Programming Language :: Python :: 3", 13 | "License :: OSI Approved :: Apache Software License", 14 | ] 15 | dependencies = [ 16 | "torch==2.2.1", "torchvision==0.17.1", 17 | "diffdist==0.1", "diffusers==0.27.2", "scipy==1.10.1", 18 | "transformers==4.46.3", "tokenizers==0.20.3", "sentencepiece==0.1.99", "shortuuid", 19 | "accelerate==0.27.2", "peft", "bitsandbytes", 20 | "pydantic", "markdown2[all]", "numpy", "scikit-learn==1.3.2", 21 | "requests", "httpx==0.24.0", "uvicorn", "fastapi", 22 | "einops==0.6.1", "einops-exts==0.0.4", "timm==0.9.16", 23 | ] 24 | 25 | [project.optional-dependencies] 26 | train = ["deepspeed==0.13.6", "ninja", "wandb"] 27 | build = ["build", "twine"] 28 | 29 | [project.urls] 30 | "Homepage" = "https://llava-vl.github.io" 31 | "Bug Tracker" = "https://github.com/haotian-liu/LLaVA/issues" 32 | 33 | [tool.setuptools.packages.find] 34 | exclude = ["assets*", "scripts*", "tests*"] 35 | 36 | 
[tool.wheel] 37 | exclude = ["assets*", "scripts*", "tests*"] 38 | -------------------------------------------------------------------------------- /scripts/extract_mm_projector.py: -------------------------------------------------------------------------------- 1 | """ 2 | This is just a utility that I use to extract the projector for quantized models. 3 | It is NOT necessary at all to train, or run inference/serve demos. 4 | Use this script ONLY if you fully understand its implications. 5 | """ 6 | 7 | 8 | import os 9 | import argparse 10 | import torch 11 | import json 12 | from collections import defaultdict 13 | 14 | 15 | def parse_args(): 16 | parser = argparse.ArgumentParser(description='Extract MMProjector weights') 17 | parser.add_argument('--model-path', type=str, help='model folder') 18 | parser.add_argument('--output', type=str, help='output file') 19 | args = parser.parse_args() 20 | return args 21 | 22 | 23 | if __name__ == '__main__': 24 | args = parse_args() 25 | 26 | keys_to_match = ['mm_in_projector', 'mm_out_projector'] 27 | ckpt_to_key = defaultdict(list) 28 | try: 29 | model_indices = json.load(open(os.path.join(args.model_path, 'pytorch_model.bin.index.json'))) 30 | for k, v in model_indices['weight_map'].items(): 31 | if any(key_match in k for key_match in keys_to_match): 32 | ckpt_to_key[v].append(k) 33 | except FileNotFoundError: 34 | # Smaller models or model checkpoints saved by DeepSpeed. 35 | v = 'pytorch_model.bin' 36 | for k in torch.load(os.path.join(args.model_path, v), map_location='cpu').keys(): 37 | if any(key_match in k for key_match in keys_to_match): 38 | ckpt_to_key[v].append(k) 39 | 40 | loaded_weights = {} 41 | 42 | for ckpt_name, weight_keys in ckpt_to_key.items(): 43 | ckpt = torch.load(os.path.join(args.model_path, ckpt_name), map_location='cpu') 44 | for k in weight_keys: 45 | loaded_weights[k] = ckpt[k] 46 | 47 | torch.save(loaded_weights, args.output) 48 | -------------------------------------------------------------------------------- /scripts/finetune.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | 4 | IMAGE_FOLDER=( 5 | "./data/ALLaVA-4V/" 6 | "./data/llava-150k/images" 7 | "./data/okvqa/train2014" 8 | "./data/okvqa/train2014" 9 | "./data/okvqa/train2014" 10 | "./data/gqa/images" 11 | ) 12 | 13 | DATA_PATH=( 14 | "./data/ALLaVA-4V/allava_laion/ALLaVA-Instruct-LAION-4V_preprocessed.json" 15 | "./data/llava-150k/pandagpt4_visual_instruction_data.json" 16 | "./data/vqa2" 17 | "./data/okvqa" 18 | "./data/okvqa/aokvqa_v1p0_train.json" 19 | "./data/gqa/train_balanced_questions.json" 20 | "" 21 | ) 22 | 23 | DATA_MULTIPLE=( 24 | 1 25 | 1 26 | 1 27 | 1 28 | 1 29 | 1 30 | ) 31 | 32 | DATASET_NAME=( 33 | "ALLaVA-Instruct-LAION-4V" 34 | "LLaVA150K" 35 | "VQAv2" 36 | "OKVQA" 37 | "AOKVQA" 38 | "GQA" 39 | ) 40 | 41 | 42 | IMAGE_FOLDER="${IMAGE_FOLDER[@]}" 43 | DATA_PATH="${DATA_PATH[@]}" 44 | DATASET_NAME="${DATASET_NAME[*]}" 45 | DATA_MULTIPLE="${DATA_MULTIPLE[@]}" 46 | 47 | 48 | deepspeed train_mem.py \ 49 | --lora_enable True --lora_r 128 --lora_alpha 256 --mm_projector_lr 2e-5 \ 50 | --deepspeed ./scripts/zero2.json \ 51 | --model_name_or_path ./pretrained_ckpt/vicuna-7b-v1.5 \ 52 | --version v1 \ 53 | --data_path $DATA_PATH \ 54 | --dataset_name $DATASET_NAME \ 55 | --image_folder $IMAGE_FOLDER \ 56 | --data_multiple $DATA_MULTIPLE \ 57 | --vision_tokenizer setok \ 58 | --vision_tower ./pretrained_ckpt/siglip-so400m-patch14-384 \ 59 | --pretrain_vision_tokenizer ./checkpoints/ \ 60 | 
--pretrain_vision_detokenizer ./checkpoints/ \ 61 | --mm_in_projector_type mlp2x_gelu \ 62 | --tune_mm_in_mlp_adapter False \ 63 | --pretrain_mm_in_mlp_adapter mm_projector.bin \ 64 | --mm_out_projector_type mlp2x_gelu \ 65 | --tune_mm_out_mlp_adapter True \ 66 | --pretrain_mm_out_mlp_adapter mm_projector.bin \ 67 | --mm_vision_select_layer -1 \ 68 | --mm_use_im_start_end True \ 69 | --mm_use_im_patch_token False \ 70 | --feature_mapper_path_or_name ./pretrained_ckpt/bert-base-uncased \ 71 | --bf16 True \ 72 | --output_dir ./checkpoints/ \ 73 | --num_train_epochs 1 \ 74 | --per_device_train_batch_size 32 \ 75 | --per_device_eval_batch_size 4 \ 76 | --gradient_accumulation_steps 1 \ 77 | --evaluation_strategy "no" \ 78 | --save_strategy "steps" \ 79 | --save_steps 4000 \ 80 | --save_total_limit 1 \ 81 | --learning_rate 1e-3 \ 82 | --weight_decay 0. \ 83 | --warmup_ratio 0.03 \ 84 | --lr_scheduler_type "cosine" \ 85 | --logging_steps 1 \ 86 | --tf32 True \ 87 | --model_max_length 2048 \ 88 | --gradient_checkpointing True \ 89 | --dataloader_num_workers 4 \ 90 | --lazy_preprocess True \ 91 | --report_to tensorboard 92 | -------------------------------------------------------------------------------- /scripts/merge_lora_weights.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | from src.model.builder import load_pretrained_model 3 | from src.mm_utils import get_model_name_from_path 4 | 5 | 6 | def merge_lora(args): 7 | model_name = get_model_name_from_path(args.model_path) 8 | tokenizer, model, image_processor, context_len = load_pretrained_model(args.model_path, args.model_base, model_name, device_map='cpu') 9 | 10 | model.save_pretrained(args.save_model_path) 11 | tokenizer.save_pretrained(args.save_model_path) 12 | 13 | 14 | if __name__ == "__main__": 15 | parser = argparse.ArgumentParser() 16 | parser.add_argument("--model-path", type=str, required=True) 17 | parser.add_argument("--model-base", type=str, required=True) 18 | parser.add_argument("--save-model-path", type=str, required=True) 19 | 20 | args = parser.parse_args() 21 | 22 | merge_lora(args) 23 | -------------------------------------------------------------------------------- /scripts/pretrain_mm_proj.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | 4 | IMAGE_FOLDER=( 5 | "./data/LLaVA-Instruct-150K/images" 6 | ) 7 | 8 | DATA_PATH=( 9 | "./data/LLaVA-Instruct-150K/llava_instruct_150k.json" 10 | ) 11 | 12 | DATA_MULTIPLE=( 13 | 1 14 | ) 15 | 16 | DATASET_NAME=( 17 | "LLaVA150K" 18 | ) 19 | 20 | 21 | IMAGE_FOLDER="${IMAGE_FOLDER[@]}" 22 | DATA_PATH="${DATA_PATH[@]}" 23 | DATASET_NAME="${DATASET_NAME[*]}" 24 | DATA_MULTIPLE="${DATA_MULTIPLE[@]}" 25 | 26 | 27 | accelerate launch train_mem.py \ 28 | --deepspeed ./scripts/zero2.json \ 29 | --model_name_or_path ./pretrained_ckpt/vicuna-7b-v1.5 \ 30 | --version plain \ 31 | --data_path $DATA_PATH \ 32 | --dataset_name $DATASET_NAME \ 33 | --image_folder $IMAGE_FOLDER \ 34 | --data_multiple $DATA_MULTIPLE \ 35 | --vision_tokenizer setok \ 36 | --vision_tower ./pretrained_ckpt/siglip-so400m-patch14-384 \ 37 | --pretrain_vision_tokenizer ./checkpoints/ \ 38 | --pretrain_vision_detokenizer ./checkpoints/ \ 39 | --mm_in_projector_type mlp2x_gelu \ 40 | --tune_mm_in_mlp_adapter True \ 41 | --mm_out_projector_type mlp2x_gelu \ 42 | --tune_mm_out_mlp_adapter True \ 43 | --mm_vision_select_layer -1 \ 44 | --mm_use_im_start_end True \ 45 | --mm_use_im_patch_token False \ 46 | 
--feature_mapper_path_or_name ./pretrained_ckpt/bert-base-uncased \ 47 | --bf16 True \ 48 | --output_dir ./checkpoints/ \ 49 | --num_train_epochs 1 \ 50 | --per_device_train_batch_size 32 \ 51 | --per_device_eval_batch_size 4 \ 52 | --gradient_accumulation_steps 1 \ 53 | --evaluation_strategy "no" \ 54 | --save_strategy "steps" \ 55 | --save_steps 4000 \ 56 | --save_total_limit 1 \ 57 | --learning_rate 1e-3 \ 58 | --weight_decay 0. \ 59 | --warmup_ratio 0.03 \ 60 | --lr_scheduler_type "cosine" \ 61 | --logging_steps 1 \ 62 | --tf32 True \ 63 | --model_max_length 2048 \ 64 | --gradient_checkpointing True \ 65 | --dataloader_num_workers 4 \ 66 | --lazy_preprocess True \ 67 | --report_to tensorboard 68 | -------------------------------------------------------------------------------- /scripts/train_setok.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | 4 | 5 | 6 | IMAGE_FOLDER=( 7 | "./data/cc3m/images" # 5240031 8 | ) 9 | 10 | DATA_PATH=( 11 | "./data/cc3m/cc3m.json" 12 | ) 13 | 14 | DATASET_NAME=( 15 | "cc3m" 16 | ) 17 | 18 | DATA_MULTIPLE=( 19 | 1 20 | ) 21 | 22 | echo $IMAGE_FOLDER 23 | echo $DATA_PATH 24 | echo $DATASET_NAME 25 | 26 | IMAGE_FOLDER="${IMAGE_FOLDER[*]}" 27 | DATA_PATH="${DATA_PATH[*]}" 28 | DATASET_NAME="${DATASET_NAME[*]}" 29 | DATA_MULTIPLE="${DATA_MULTIPLE[*]}" 30 | 31 | 32 | echo $IMAGE_FOLDER 33 | echo $DATA_PATH 34 | echo $DATASET_NAME 35 | 36 | 37 | 38 | deepspeed train_mem.py \ 39 | --deepspeed ./scripts/zero2.json \ 40 | --lora_enable False \ 41 | --data_path $DATA_PATH \ 42 | --image_folder $IMAGE_FOLDER \ 43 | --data_multiple $DATA_MULTIPLE \ 44 | --dataset_name $DATASET_NAME \ 45 | --image_size $IMAGE_SIZE \ 46 | --vision_tower ./pretrained_ckpt/siglip-so400m-patch14-384 \ 47 | --feature_mapper_path_or_name ./pretrained_ckpt/bert-base-uncased \ 48 | --bf16 False \ 49 | --output_dir ./checkpoints/ \ 50 | --num_train_epochs 1 \ 51 | --per_device_train_batch_size 24 \ 52 | --per_device_eval_batch_size 4 \ 53 | --gradient_accumulation_steps 1 \ 54 | --evaluation_strategy "no" \ 55 | --save_strategy "steps" \ 56 | --save_steps 24000 \ 57 | --save_total_limit 1 \ 58 | --learning_rate 1e-3 \ 59 | --weight_decay 0. 
\ 60 | --warmup_ratio 0.03 \ 61 | --lr_scheduler_type "cosine" \ 62 | --logging_steps 1 \ 63 | --tf32 True \ 64 | --fp16 False \ 65 | --model_max_length 77 \ 66 | --gradient_checkpointing True \ 67 | --dataloader_num_workers 4 \ 68 | --lazy_preprocess True \ 69 | --report_to tensorboard \ 70 | 71 | 72 | 73 | -------------------------------------------------------------------------------- /scripts/zero2.json: -------------------------------------------------------------------------------- 1 | { 2 | "fp16": { 3 | "enabled": "auto", 4 | "loss_scale": 0, 5 | "loss_scale_window": 1000, 6 | "initial_scale_power": 16, 7 | "hysteresis": 2, 8 | "min_loss_scale": 1 9 | }, 10 | "bf16": { 11 | "enabled": "auto" 12 | }, 13 | "train_micro_batch_size_per_gpu": "auto", 14 | "train_batch_size": "auto", 15 | "gradient_accumulation_steps": "auto", 16 | "zero_optimization": { 17 | "stage": 2, 18 | "overlap_comm": true, 19 | "contiguous_gradients": true, 20 | "sub_group_size": 1e9, 21 | "reduce_bucket_size": "auto" 22 | } 23 | } -------------------------------------------------------------------------------- /scripts/zero3.json: -------------------------------------------------------------------------------- 1 | { 2 | "fp16": { 3 | "enabled": "auto", 4 | "loss_scale": 0, 5 | "loss_scale_window": 1000, 6 | "initial_scale_power": 16, 7 | "hysteresis": 2, 8 | "min_loss_scale": 1 9 | }, 10 | "bf16": { 11 | "enabled": "auto" 12 | }, 13 | "train_micro_batch_size_per_gpu": "auto", 14 | "train_batch_size": "auto", 15 | "gradient_accumulation_steps": "auto", 16 | "zero_optimization": { 17 | "stage": 3, 18 | "overlap_comm": true, 19 | "contiguous_gradients": true, 20 | "sub_group_size": 1e9, 21 | "reduce_bucket_size": "auto", 22 | "stage3_prefetch_bucket_size": "auto", 23 | "stage3_param_persistence_threshold": "auto", 24 | "stage3_max_live_parameters": 1e9, 25 | "stage3_max_reuse_distance": 1e9, 26 | "stage3_gather_16bit_weights_on_model_save": true 27 | } 28 | } -------------------------------------------------------------------------------- /scripts/zero3_offload.json: -------------------------------------------------------------------------------- 1 | { 2 | "fp16": { 3 | "enabled": "auto", 4 | "loss_scale": 0, 5 | "loss_scale_window": 1000, 6 | "initial_scale_power": 16, 7 | "hysteresis": 2, 8 | "min_loss_scale": 1 9 | }, 10 | "bf16": { 11 | "enabled": "auto" 12 | }, 13 | "optimizer": { 14 | "type": "AdamW", 15 | "params": { 16 | "lr": "auto", 17 | "betas": "auto", 18 | "eps": "auto", 19 | "weight_decay": "auto" 20 | } 21 | }, 22 | "scheduler": { 23 | "type": "WarmupLR", 24 | "params": { 25 | "warmup_min_lr": "auto", 26 | "warmup_max_lr": "auto", 27 | "warmup_num_steps": "auto" 28 | } 29 | }, 30 | "zero_optimization": { 31 | "stage": 3, 32 | "offload_optimizer": { 33 | "device": "cpu", 34 | "pin_memory": true 35 | }, 36 | "offload_param": { 37 | "device": "cpu", 38 | "pin_memory": true 39 | }, 40 | "overlap_comm": true, 41 | "contiguous_gradients": true, 42 | "sub_group_size": 1e9, 43 | "reduce_bucket_size": "auto", 44 | "stage3_prefetch_bucket_size": "auto", 45 | "stage3_param_persistence_threshold": "auto", 46 | "stage3_max_live_parameters": 1e9, 47 | "stage3_max_reuse_distance": 1e9, 48 | "gather_16bit_weights_on_model_save": true 49 | }, 50 | "gradient_accumulation_steps": "auto", 51 | "gradient_clipping": "auto", 52 | "train_batch_size": "auto", 53 | "train_micro_batch_size_per_gpu": "auto", 54 | "steps_per_print": 1e5, 55 | "wall_clock_breakdown": false 56 | } 
-------------------------------------------------------------------------------- /src/__init__.py: -------------------------------------------------------------------------------- 1 | from .model.language_model.setokim_llama import SetokimLlamaForCausalLM 2 | -------------------------------------------------------------------------------- /src/constants.py: -------------------------------------------------------------------------------- 1 | CONTROLLER_HEART_BEAT_EXPIRATION = 30 2 | WORKER_HEART_BEAT_INTERVAL = 15 3 | 4 | LOGDIR = "." 5 | 6 | # Model Constants 7 | IGNORE_INDEX = -100 8 | IMAGE_TOKEN_INDEX = -200 9 | DEFAULT_IMAGE_TOKEN = "" 10 | DEFAULT_IMAGE_PATCH_TOKEN = "" 11 | DEFAULT_IM_START_TOKEN = "" 12 | DEFAULT_IM_END_TOKEN = "" 13 | IMAGE_PLACEHOLDER = "" 14 | 15 | TARGET_TOKEN_INDEX = -300 16 | DEFAULT_TARGET_TOKEN = "" -------------------------------------------------------------------------------- /src/data_preprocess.py: -------------------------------------------------------------------------------- 1 | import os 2 | import json 3 | from concurrent.futures import ThreadPoolExecutor 4 | from joblib import Parallel, delayed 5 | 6 | from tqdm import tqdm 7 | 8 | 9 | def check_image_exists(data, image_folder): 10 | """ Helper function to check if an image exists and return the data accordingly. """ 11 | if os.path.exists(os.path.join(image_folder, data['image'])): 12 | return (True, data) 13 | else: 14 | return (False, data['image']) 15 | 16 | 17 | def preprocess(params): 18 | data_path, image_folder, save_path = params 19 | 20 | # Load data from JSON file 21 | with open(data_path, 'r') as f: 22 | datas = json.load(f) 23 | 24 | # Create a thread pool executor to check image existence in parallel 25 | new_datas = [] 26 | unexisted_images = [] 27 | with Parallel(n_jobs=50) as parallel: 28 | results = parallel(delayed(check_image_exists)(data, image_folder) for data in tqdm(datas, desc="Checking images")) 29 | 30 | 31 | # Separate the results into new data and unexisted images 32 | for result in results: 33 | exists, data = result 34 | if exists: 35 | new_datas.append(data) 36 | else: 37 | unexisted_images.append(data) 38 | 39 | # Save the filtered data back to a JSON file 40 | with open(save_path, 'w') as f: 41 | json.dump(new_datas, f, indent=4) 42 | 43 | # Print out the unexisted images 44 | print(f'Unexisted images: {unexisted_images}') 45 | 46 | 47 | 48 | if __name__ == "__main__": 49 | data_path = './ALLaVA-4V/allava_laion/ALLaVA-Instruct-LAION-4V.json' 50 | image_folder = './ALLaVA-4V' 51 | save_path = './ALLaVA-4V/allava_laion/ALLaVA-Instruct-LAION-4V_preprocessed.json' 52 | 53 | import torch 54 | pretrain_mm_mlp_adapter = './checkpoints/vicuna-v1.5-7b-convnext-pretrain/mm_projector.bin' 55 | mm_projector_weights = torch.load(pretrain_mm_mlp_adapter, map_location='cpu') 56 | def get_w(weights, keyword): 57 | return {k.split(keyword + '.')[1]: v for k, v in weights.items() if keyword in k} 58 | 59 | res = get_w(mm_projector_weights, 'mm_projector') -------------------------------------------------------------------------------- /src/dataset/__init__.py: -------------------------------------------------------------------------------- 1 | from .editDataset import InstructPix2Pix_Dataset, MagicBrush_Dataset, EditingDataset 2 | from .pairDataset import TextImagePairDataset 3 | from .instructDataset import InstructionTuningDataset 4 | from .base_dataset import DataCollatorForSupervisedDataset -------------------------------------------------------------------------------- 
/src/dataset/dataset_utils.py: -------------------------------------------------------------------------------- 1 | 2 | import math 3 | import random 4 | from PIL import Image 5 | 6 | 7 | def extend_list(original_list, multiplier): 8 | # Calculate how many elements to replicate and how many to select randomly 9 | replicate_elements = math.floor(multiplier) 10 | random_elements = multiplier - replicate_elements 11 | 12 | # Replicate the list 13 | replicated_list = original_list * replicate_elements 14 | 15 | # Calculate how many elements to randomly select 16 | select_elements = math.ceil(len(original_list) * random_elements) 17 | 18 | # Randomly select elements and append to the replicated list 19 | for _ in range(select_elements): 20 | random_element = random.choice(original_list) 21 | replicated_list.append(random_element) 22 | 23 | return replicated_list 24 | 25 | 26 | def expand2square(pil_img, background_color): 27 | width, height = pil_img.size 28 | if width == height: 29 | return pil_img 30 | elif width > height: 31 | result = Image.new(pil_img.mode, (width, width), background_color) 32 | result.paste(pil_img, (0, (width - height) // 2)) 33 | return result 34 | else: 35 | result = Image.new(pil_img.mode, (height, height), background_color) 36 | result.paste(pil_img, ((height - width) // 2, 0)) 37 | return result -------------------------------------------------------------------------------- /src/dataset/editDataset.py: -------------------------------------------------------------------------------- 1 | import pdb 2 | import token 3 | 4 | from datasets import load_from_disk 5 | import io 6 | import numpy as np 7 | from PIL import Image 8 | import random 9 | import torch 10 | from torchvision import transforms 11 | from torch.utils.data import Dataset, ConcatDataset 12 | from .base_dataset import * 13 | from .dataset_utils import expand2square 14 | 15 | 16 | def convert_to_np(image, resolution): 17 | image = image.convert("RGB") 18 | image = image.resize((resolution, resolution), resample=Image.Resampling.BICUBIC) 19 | return np.array(image).transpose(2, 0, 1) 20 | 21 | 22 | def load_img_for_generator(image, resolution): 23 | # image = Image.open(path).convert("RGB") 24 | # w, h = image.size 25 | # print(f"loaded input image of size ({w}, {h}) from {path}") 26 | # w, h = map(lambda x: x - x % 32, (w, h)) # resize to integer multiple of 32 27 | image = image.resize((resolution), resample=Image.Resampling.LANCZOS) 28 | image = np.array(image).astype(np.float32) / 255.0 29 | image = image.transpose(2, 0, 1) 30 | image = torch.from_numpy(image) 31 | return 2.*image - 1. 32 | 33 | 34 | def get_random_response(): 35 | image_editing_responses = { 36 | "simple": [ 37 | "Here you go.", 38 | "All set.", 39 | "Done.", 40 | "Here it is.", 41 | "Finished.", 42 | "Done. Let me know if it works!" 43 | ], 44 | "polite_professional": [ 45 | "The image has been edited as requested. Please take a look.", 46 | "Here is the updated version of the image you asked for.", 47 | "Attached is the revised image. Let me know if everything looks good.", 48 | "I've completed the edits—feel free to review.", 49 | "Please find the edited image below. Let me know if you'd like any revisions." 50 | ], 51 | "casual_friendly": [ 52 | "All done! Hope you like it.", 53 | "Tada 🎨 Let me know what you think!", 54 | "Voila! Here's your image.", 55 | "Here's the new version—check it out!", 56 | "Done and dusted 😎", 57 | "Boom! Updated and ready." 
58 | ], 59 | "open_to_feedback": [ 60 | "Let me know if you'd like to adjust anything else.", 61 | "Happy to make further edits if needed!", 62 | "If you need a different version, just say the word.", 63 | "Want to tweak anything? I've got you.", 64 | "Tell me if something needs changing!" 65 | ], 66 | "image_generation_context": [ 67 | "Here is the image based on your description.", 68 | "The generated image is ready. Let me know if it matches your vision.", 69 | "Here's what I came up with—does this align with what you had in mind?", 70 | "Based on your prompt, this is the result. Happy to revise!" 71 | ] 72 | } 73 | 74 | all_responses = sum(image_editing_responses.values(), []) 75 | 76 | random_reply = random.choice(all_responses) 77 | return random_reply 78 | 79 | 80 | 81 | class EditingDataset(Dataset): 82 | def __init__(self, data_path, tokenizer, data_args) -> None: 83 | super().__init__() 84 | 85 | instructPix2Pix_dataset = InstructPix2Pix_Dataset(data_path[0], tokenizer=tokenizer, data_args=data_args) 86 | magicBruch_dataset = MagicBrush_Dataset(data_path=data_path[1], tokenizer=tokenizer, data_args=data_args) 87 | self.datasets = ConcatDataset([instructPix2Pix_dataset, magicBruch_dataset]) 88 | 89 | def __len__(self): 90 | return self.datasets.__len__() 91 | 92 | def __getitem__(self, item): 93 | return self.datasets.__getitem__(item) 94 | 95 | 96 | 97 | # InstructPix2Pix dataset 98 | class InstructPix2Pix_Dataset(LazySupervisedDataset): 99 | ''' 100 | according to InstructPix2Pix, the dataset can be used to train models to follow edit instructions. 101 | Edit instructions are available in the 'edit_prompt'. 'original_image' can be used with the 'edit_prompt' and 'edited_image' denotes the image after applying the 'edit_prompt' on the 'original_image'. 
102 | "original_image" + "edited_image" + "edit_prompt" 103 | ''' 104 | def __init__(self, 105 | data_path, 106 | tokenizer, 107 | data_args, 108 | ): 109 | super().__init__(data_args=data_args, data_path=data_path, tokenizer=tokenizer) 110 | 111 | # InstructPix2Pix Dataset path 112 | self.list_data_dict = load_from_disk(data_path) 113 | # 224, 256 114 | self.resolution_for_comp = data_args.image_size 115 | self.resolution_for_gen = data_args.resolution_sd 116 | 117 | # tokenizer 118 | self.tokenizer = tokenizer 119 | 120 | 121 | def __len__(self,): 122 | return len(self.list_data_dict) 123 | 124 | def __getitem__(self, i): 125 | # # {'original_image': , 'edited_image': , 'edit_prompt': 'make the leaves yellow'} 126 | sources = self.list_data_dict[i] 127 | if isinstance(i, int): 128 | sources = [sources] 129 | assert len(sources) == 1, "Don't know why it is wrapped to a list" # FIXME 130 | if 'image' in sources[0]: 131 | original_image_file = self.list_data_dict[i]['original_image'] 132 | # image_folder = self.data_args.image_folder 133 | processor = self.data_args.image_processor 134 | # image = Image.open(os.path.join(image_folder, image_file)).convert('RGB') 135 | original_image = Image.open(io.BytesIO(original_image_file['bytes'])).convert('RGB') 136 | edited_image_file = self.list_data_dict[i]['edited_image'] 137 | edited_image = Image.open(io.BytesIO(edited_image_file['bytes'])).convert('RGB') 138 | if self.data_args.image_aspect_ratio == 'pad': 139 | def expand2square(pil_img, background_color): 140 | width, height = pil_img.size 141 | if width == height: 142 | return pil_img 143 | elif width > height: 144 | result = Image.new(pil_img.mode, (width, width), background_color) 145 | result.paste(pil_img, (0, (width - height) // 2)) 146 | return result 147 | else: 148 | result = Image.new(pil_img.mode, (height, height), background_color) 149 | result.paste(pil_img, ((height - width) // 2, 0)) 150 | return result 151 | original_image = expand2square(original_image, tuple(int(x*255) for x in processor.image_mean)) 152 | comp_image = processor.preprocess(original_image, return_tensors='pt')['pixel_values'][0] 153 | gen_image = load_img_for_generator(edited_image, self.resolution_for_gen) 154 | else: 155 | comp_image = processor.preprocess(original_image, return_tensors='pt')['pixel_values'][0] 156 | gen_image = load_img_for_generator(edited_image, self.resolution_for_gen) 157 | 158 | _source = { 159 | "id": i, 160 | "conversations": [ 161 | {"from": "human", "value": "\n"+self.list_data_dict[i]["edit_prompt"]}, 162 | {"from": "gpt", "value": "\n"+ get_random_response()}, 163 | ] 164 | } 165 | sources = preprocess_multimodal( 166 | [_source], 167 | self.data_args, 168 | target_num=gen_image.shape[0]) 169 | else: 170 | sources = [_source] 171 | data_dict = preprocess( 172 | sources, 173 | self.tokenizer, 174 | has_image=('image' in self.list_data_dict[i])) 175 | if isinstance(i, int): 176 | data_dict = dict(input_ids=data_dict["input_ids"][0], labels=data_dict["labels"][0]) 177 | 178 | # image exist in the data 179 | if 'image' in self.list_data_dict[i]: 180 | data_dict['comp_image'] = comp_image 181 | data_dict['gen_image'] = gen_image 182 | elif self.data_args.is_multimodal: 183 | # image does not exist in the data, but the model is multimodal 184 | crop_size = self.data_args.image_processor.crop_size 185 | data_dict['comp_image'] = torch.zeros(3, crop_size['height'], crop_size['width']) 186 | data_dict['gen_image'] = torch.zeros(3, crop_size['height'], crop_size['width']) 187 | 188 | 
return data_dict 189 | 190 | 191 | 192 | # MagicBrush dataset 193 | class MagicBrush_Dataset(LazySupervisedDataset): 194 | ''' 195 | according to MagicBrush, the dataset can be used to train models to follow edit instructions. 196 | Edit instructions are available in the 'instruction'. 'source_img' can be used with the 'instruction' and 'target_img' denotes the image after applying the 'instruction' on the 'source_img'. 197 | "source_img" + "target_img" + "instruction" 198 | Dataset({features: ['img_id', 'turn_index', 'source_img', 'mask_img', 'instruction', 'target_img'], num_rows: 8807}) 199 | ''' 200 | def __init__(self, 201 | data_path, 202 | tokenizer, 203 | data_args, 204 | ): 205 | super().__init__(data_path=data_path, tokenizer=tokenizer, data_args=data_args) 206 | # MagicBrush Dataset path 207 | # InstructPix2Pix Dataset path 208 | self.list_data_dict = load_from_disk(data_path) 209 | # 224, 256 210 | self.resolution_for_comp = data_args.image_size 211 | self.resolution_for_gen = data_args.resolution_sd 212 | 213 | # tokenizer 214 | self.tokenizer = tokenizer 215 | 216 | def __len__(self,): 217 | return len(self.list_data_dict) 218 | 219 | def __getitem__(self, i): 220 | # {'source_img': , 'target_img': , 'instruction': 'let the asparagus be replaced with sausages'} 221 | 222 | sources = self.list_data_dict[i] 223 | if isinstance(i, int): 224 | sources = [sources] 225 | assert len(sources) == 1, "Don't know why it is wrapped to a list" # FIXME 226 | if 'image' in sources[0]: 227 | original_image_file = self.list_data_dict[i]['source_img'] 228 | # image_folder = self.data_args.image_folder 229 | processor = self.data_args.image_processor 230 | # image = Image.open(os.path.join(image_folder, image_file)).convert('RGB') 231 | original_image = Image.open(io.BytesIO(original_image_file['bytes'])).convert('RGB') 232 | edited_image_file = self.list_data_dict[i]['target_img'] 233 | edited_image = Image.open(io.BytesIO(edited_image_file['bytes'])).convert('RGB') 234 | if self.data_args.image_aspect_ratio == 'pad': 235 | def expand2square(pil_img, background_color): 236 | width, height = pil_img.size 237 | if width == height: 238 | return pil_img 239 | elif width > height: 240 | result = Image.new(pil_img.mode, (width, width), background_color) 241 | result.paste(pil_img, (0, (width - height) // 2)) 242 | return result 243 | else: 244 | result = Image.new(pil_img.mode, (height, height), background_color) 245 | result.paste(pil_img, ((height - width) // 2, 0)) 246 | return result 247 | original_image = expand2square(original_image, tuple(int(x*255) for x in processor.image_mean)) 248 | comp_image = processor.preprocess(original_image, return_tensors='pt')['pixel_values'][0] 249 | gen_image = load_img_for_generator(edited_image, self.resolution_for_gen) 250 | else: 251 | comp_image = processor.preprocess(original_image, return_tensors='pt')['pixel_values'][0] 252 | gen_image = load_img_for_generator(edited_image, self.resolution_for_gen) 253 | 254 | _source = { 255 | "id": i, 256 | "conversations": [ 257 | {"from": "human", "value": "\n"+self.list_data_dict[i]["instruction"]}, 258 | {"from": "gpt", "value": "\n"+ get_random_response()}, 259 | ] 260 | } 261 | sources = preprocess_multimodal( 262 | [_source], 263 | self.data_args, 264 | target_num=gen_image.shape[0]) 265 | else: 266 | sources = [_source] 267 | data_dict = preprocess( 268 | sources, 269 | self.tokenizer, 270 | has_image=('image' in self.list_data_dict[i])) 271 | if isinstance(i, int): 272 | data_dict = 
dict(input_ids=data_dict["input_ids"][0], labels=data_dict["labels"][0]) 273 | 274 | # image exist in the data 275 | if 'image' in self.list_data_dict[i]: 276 | data_dict['comp_image'] = comp_image 277 | data_dict['gen_image'] = gen_image 278 | elif self.data_args.is_multimodal: 279 | # image does not exist in the data, but the model is multimodal 280 | crop_size = self.data_args.image_processor.crop_size 281 | data_dict['comp_image'] = torch.zeros(3, crop_size['height'], crop_size['width']) 282 | data_dict['gen_image'] = torch.zeros(3, crop_size['height'], crop_size['width']) 283 | 284 | return data_dict -------------------------------------------------------------------------------- /src/dataset/vqa.py: -------------------------------------------------------------------------------- 1 | __author__ = 'aagrawal' 2 | __version__ = '0.9' 3 | 4 | # Interface for accessing the VQA dataset. 5 | 6 | # This code is based on the code written by Tsung-Yi Lin for MSCOCO Python API available at the following link: 7 | # (https://github.com/pdollar/coco/blob/master/PythonAPI/pycocotools/coco.py). 8 | 9 | # The following functions are defined: 10 | # VQA - VQA class that loads VQA annotation file and prepares data structures. 11 | # getQuesIds - Get question ids that satisfy given filter conditions. 12 | # getImgIds - Get image ids that satisfy given filter conditions. 13 | # loadQA - Load questions and answers with the specified question ids. 14 | # showQA - Display the specified questions and answers. 15 | # loadRes - Load result file and create result object. 16 | 17 | # Help on each function can be accessed by: "help(COCO.function)" 18 | 19 | import json 20 | import datetime 21 | import copy 22 | 23 | class VQA: 24 | def __init__(self, annotation_file=None, question_file=None): 25 | """ 26 | Constructor of VQA helper class for reading and visualizing questions and answers. 27 | :param annotation_file (str): location of VQA annotation file 28 | :return: 29 | """ 30 | # load dataset 31 | self.dataset = {} 32 | self.questions = {} 33 | self.qa = {} 34 | self.qqa = {} 35 | self.imgToQA = {} 36 | if not annotation_file == None and not question_file == None: 37 | print('loading VQA annotations and questions into memory...') 38 | time_t = datetime.datetime.utcnow() 39 | dataset = json.load(open(annotation_file, 'r')) 40 | questions = json.load(open(question_file, 'r')) 41 | print(datetime.datetime.utcnow() - time_t) 42 | self.dataset = dataset 43 | self.questions = questions 44 | self.createIndex() 45 | 46 | def createIndex(self): 47 | # create index 48 | print('creating index...') 49 | imgToQA = {ann['image_id']: [] for ann in self.dataset['annotations']} 50 | qa = {ann['question_id']: [] for ann in self.dataset['annotations']} 51 | qqa = {ann['question_id']: [] for ann in self.dataset['annotations']} 52 | for ann in self.dataset['annotations']: 53 | imgToQA[ann['image_id']] += [ann] 54 | qa[ann['question_id']] = ann 55 | for ques in self.questions['questions']: 56 | qqa[ques['question_id']] = ques 57 | print('index created!') 58 | 59 | # create class members 60 | self.qa = qa 61 | self.qqa = qqa 62 | self.imgToQA = imgToQA 63 | 64 | def info(self): 65 | """ 66 | Print information about the VQA annotation file. 67 | :return: 68 | """ 69 | for key, value in self.datset['info'].items(): 70 | print('%s: %s'%(key, value)) 71 | 72 | def getQuesIds(self, imgIds=[], quesTypes=[], ansTypes=[]): 73 | """ 74 | Get question ids that satisfy given filter conditions. 
default skips that filter 75 | :param imgIds (int array) : get question ids for given imgs 76 | quesTypes (str array) : get question ids for given question types 77 | ansTypes (str array) : get question ids for given answer types 78 | :return: ids (int array) : integer array of question ids 79 | """ 80 | imgIds = imgIds if type(imgIds) == list else [imgIds] 81 | quesTypes = quesTypes if type(quesTypes) == list else [quesTypes] 82 | ansTypes = ansTypes if type(ansTypes) == list else [ansTypes] 83 | 84 | if len(imgIds) == len(quesTypes) == len(ansTypes) == 0: 85 | anns = self.dataset['annotations'] 86 | else: 87 | if not len(imgIds) == 0: 88 | anns = sum([self.imgToQA[imgId] for imgId in imgIds if imgId in self.imgToQA],[]) 89 | else: 90 | anns = self.dataset['annotations'] 91 | anns = anns if len(quesTypes) == 0 else [ann for ann in anns if ann['question_type'] in quesTypes] 92 | anns = anns if len(ansTypes) == 0 else [ann for ann in anns if ann['answer_type'] in ansTypes] 93 | ids = [ann['question_id'] for ann in anns] 94 | return ids 95 | 96 | def getImgIds(self, quesIds=[], quesTypes=[], ansTypes=[]): 97 | """ 98 | Get image ids that satisfy given filter conditions. default skips that filter 99 | :param quesIds (int array) : get image ids for given question ids 100 | quesTypes (str array) : get image ids for given question types 101 | ansTypes (str array) : get image ids for given answer types 102 | :return: ids (int array) : integer array of image ids 103 | """ 104 | quesIds = quesIds if type(quesIds) == list else [quesIds] 105 | quesTypes = quesTypes if type(quesTypes) == list else [quesTypes] 106 | ansTypes = ansTypes if type(ansTypes) == list else [ansTypes] 107 | 108 | if len(quesIds) == len(quesTypes) == len(ansTypes) == 0: 109 | anns = self.dataset['annotations'] 110 | else: 111 | if not len(quesIds) == 0: 112 | anns = sum([self.qa[quesId] for quesId in quesIds if quesId in self.qa],[]) 113 | else: 114 | anns = self.dataset['annotations'] 115 | anns = anns if len(quesTypes) == 0 else [ann for ann in anns if ann['question_type'] in quesTypes] 116 | anns = anns if len(ansTypes) == 0 else [ann for ann in anns if ann['answer_type'] in ansTypes] 117 | ids = [ann['image_id'] for ann in anns] 118 | return ids 119 | 120 | def loadQA(self, ids=[]): 121 | """ 122 | Load questions and answers with the specified question ids. 123 | :param ids (int array) : integer ids specifying question ids 124 | :return: qa (object array) : loaded qa objects 125 | """ 126 | if type(ids) == list: 127 | return [self.qa[id] for id in ids] 128 | elif type(ids) == int: 129 | return [self.qa[ids]] 130 | 131 | def showQA(self, anns): 132 | """ 133 | Display the specified annotations. 134 | :param anns (array of object): annotations to display 135 | :return: None 136 | """ 137 | if len(anns) == 0: 138 | return 0 139 | for ann in anns: 140 | quesId = ann['question_id'] 141 | print("Question: %s" %(self.qqa[quesId]['question'])) 142 | for ans in ann['answers']: 143 | print("Answer %d: %s" %(ans['answer_id'], ans['answer'])) 144 | 145 | def loadRes(self, resFile, quesFile): 146 | """ 147 | Load result file and return a result object. 
148 | :param resFile (str) : file name of result file 149 | :return: res (obj) : result api object 150 | """ 151 | res = VQA() 152 | res.questions = json.load(open(quesFile)) 153 | res.dataset['info'] = copy.deepcopy(self.questions['info']) 154 | res.dataset['task_type'] = copy.deepcopy(self.questions['task_type']) 155 | res.dataset['data_type'] = copy.deepcopy(self.questions['data_type']) 156 | res.dataset['data_subtype'] = copy.deepcopy(self.questions['data_subtype']) 157 | res.dataset['license'] = copy.deepcopy(self.questions['license']) 158 | 159 | print('Loading and preparing results... ') 160 | time_t = datetime.datetime.utcnow() 161 | anns = json.load(open(resFile)) 162 | assert type(anns) == list, 'results is not an array of objects' 163 | annsQuesIds = [ann['question_id'] for ann in anns] 164 | assert set(annsQuesIds) == set(self.getQuesIds()), \ 165 | 'Results do not correspond to current VQA set. Either the results do not have predictions for all question ids in annotation file or there is atleast one question id that does not belong to the question ids in the annotation file.' 166 | for ann in anns: 167 | quesId = ann['question_id'] 168 | if res.dataset['task_type'] == 'Multiple Choice': 169 | assert ann['answer'] in self.qqa[quesId]['multiple_choices'], 'predicted answer is not one of the multiple choices' 170 | qaAnn = self.qa[quesId] 171 | ann['image_id'] = qaAnn['image_id'] 172 | ann['question_type'] = qaAnn['question_type'] 173 | ann['answer_type'] = qaAnn['answer_type'] 174 | print('DONE (t=%0.2fs)'%((datetime.datetime.utcnow() - time_t).total_seconds())) 175 | 176 | res.dataset['annotations'] = anns 177 | res.createIndex() 178 | return res -------------------------------------------------------------------------------- /src/mm_utils.py: -------------------------------------------------------------------------------- 1 | from PIL import Image 2 | from io import BytesIO 3 | import base64 4 | import torch 5 | import math 6 | import ast 7 | 8 | from transformers import StoppingCriteria 9 | from src.constants import TARGET_TOKEN_INDEX, IMAGE_TOKEN_INDEX 10 | 11 | 12 | def select_best_resolution(original_size, possible_resolutions): 13 | """ 14 | Selects the best resolution from a list of possible resolutions based on the original size. 15 | 16 | Args: 17 | original_size (tuple): The original size of the image in the format (width, height). 18 | possible_resolutions (list): A list of possible resolutions in the format [(width1, height1), (width2, height2), ...]. 19 | 20 | Returns: 21 | tuple: The best fit resolution in the format (width, height). 
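    Example (illustrative values, not taken from any config in this repository):
        >>> select_best_resolution((1000, 600), [(672, 672), (336, 672), (672, 336)])
        (672, 672)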
22 | """ 23 | original_width, original_height = original_size 24 | best_fit = None 25 | max_effective_resolution = 0 26 | min_wasted_resolution = float('inf') 27 | 28 | for width, height in possible_resolutions: 29 | scale = min(width / original_width, height / original_height) 30 | downscaled_width, downscaled_height = int(original_width * scale), int(original_height * scale) 31 | effective_resolution = min(downscaled_width * downscaled_height, original_width * original_height) 32 | wasted_resolution = (width * height) - effective_resolution 33 | 34 | if effective_resolution > max_effective_resolution or (effective_resolution == max_effective_resolution and wasted_resolution < min_wasted_resolution): 35 | max_effective_resolution = effective_resolution 36 | min_wasted_resolution = wasted_resolution 37 | best_fit = (width, height) 38 | 39 | return best_fit 40 | 41 | 42 | def resize_and_pad_image(image, target_resolution): 43 | """ 44 | Resize and pad an image to a target resolution while maintaining aspect ratio. 45 | 46 | Args: 47 | image (PIL.Image.Image): The input image. 48 | target_resolution (tuple): The target resolution (width, height) of the image. 49 | 50 | Returns: 51 | PIL.Image.Image: The resized and padded image. 52 | """ 53 | original_width, original_height = image.size 54 | target_width, target_height = target_resolution 55 | 56 | scale_w = target_width / original_width 57 | scale_h = target_height / original_height 58 | 59 | if scale_w < scale_h: 60 | new_width = target_width 61 | new_height = min(math.ceil(original_height * scale_w), target_height) 62 | else: 63 | new_height = target_height 64 | new_width = min(math.ceil(original_width * scale_h), target_width) 65 | 66 | # Resize the image 67 | resized_image = image.resize((new_width, new_height)) 68 | 69 | new_image = Image.new('RGB', (target_width, target_height), (0, 0, 0)) 70 | paste_x = (target_width - new_width) // 2 71 | paste_y = (target_height - new_height) // 2 72 | new_image.paste(resized_image, (paste_x, paste_y)) 73 | 74 | return new_image 75 | 76 | 77 | def divide_to_patches(image, patch_size): 78 | """ 79 | Divides an image into patches of a specified size. 80 | 81 | Args: 82 | image (PIL.Image.Image): The input image. 83 | patch_size (int): The size of each patch. 84 | 85 | Returns: 86 | list: A list of PIL.Image.Image objects representing the patches. 87 | """ 88 | patches = [] 89 | width, height = image.size 90 | for i in range(0, height, patch_size): 91 | for j in range(0, width, patch_size): 92 | box = (j, i, j + patch_size, i + patch_size) 93 | patch = image.crop(box) 94 | patches.append(patch) 95 | 96 | return patches 97 | 98 | 99 | def get_anyres_image_grid_shape(image_size, grid_pinpoints, patch_size): 100 | """ 101 | Calculate the shape of the image patch grid after the preprocessing for images of any resolution. 102 | 103 | Args: 104 | image_size (tuple): The size of the input image in the format (width, height). 105 | grid_pinpoints (str): A string representation of a list of possible resolutions. 106 | patch_size (int): The size of each image patch. 107 | 108 | Returns: 109 | tuple: The shape of the image patch grid in the format (width, height). 
110 | """ 111 | if type(grid_pinpoints) is list: 112 | possible_resolutions = grid_pinpoints 113 | else: 114 | possible_resolutions = ast.literal_eval(grid_pinpoints) 115 | width, height = select_best_resolution(image_size, possible_resolutions) 116 | return width // patch_size, height // patch_size 117 | 118 | 119 | def process_anyres_image(image, processor, grid_pinpoints): 120 | """ 121 | Process an image with variable resolutions. 122 | 123 | Args: 124 | image (PIL.Image.Image): The input image to be processed. 125 | processor: The image processor object. 126 | grid_pinpoints (str): A string representation of a list of possible resolutions. 127 | 128 | Returns: 129 | torch.Tensor: A tensor containing the processed image patches. 130 | """ 131 | if type(grid_pinpoints) is list: 132 | possible_resolutions = grid_pinpoints 133 | else: 134 | possible_resolutions = ast.literal_eval(grid_pinpoints) 135 | best_resolution = select_best_resolution(image.size, possible_resolutions) 136 | image_padded = resize_and_pad_image(image, best_resolution) 137 | 138 | patches = divide_to_patches(image_padded, processor.crop_size['height']) 139 | 140 | image_original_resize = image.resize((processor.size['shortest_edge'], processor.size['shortest_edge'])) 141 | 142 | image_patches = [image_original_resize] + patches 143 | image_patches = [processor.preprocess(image_patch, return_tensors='pt')['pixel_values'][0] 144 | for image_patch in image_patches] 145 | return torch.stack(image_patches, dim=0) 146 | 147 | 148 | def load_image_from_base64(image): 149 | return Image.open(BytesIO(base64.b64decode(image))) 150 | 151 | 152 | def expand2square(pil_img, background_color): 153 | width, height = pil_img.size 154 | if width == height: 155 | return pil_img 156 | elif width > height: 157 | result = Image.new(pil_img.mode, (width, width), background_color) 158 | result.paste(pil_img, (0, (width - height) // 2)) 159 | return result 160 | else: 161 | result = Image.new(pil_img.mode, (height, height), background_color) 162 | result.paste(pil_img, ((height - width) // 2, 0)) 163 | return result 164 | 165 | 166 | def process_images(images, image_processor, model_cfg): 167 | image_aspect_ratio = getattr(model_cfg, "image_aspect_ratio", None) 168 | new_images = [] 169 | if image_aspect_ratio == 'pad': 170 | for image in images: 171 | image = expand2square(image, tuple(int(x*255) for x in image_processor.image_mean)) 172 | image = image_processor.preprocess(image, return_tensors='pt')['pixel_values'][0] 173 | new_images.append(image) 174 | elif image_aspect_ratio == "anyres": 175 | for image in images: 176 | image = process_anyres_image(image, image_processor, model_cfg.image_grid_pinpoints) 177 | new_images.append(image) 178 | else: 179 | return image_processor(images, return_tensors='pt')['pixel_values'] 180 | if all(x.shape == new_images[0].shape for x in new_images): 181 | new_images = torch.stack(new_images, dim=0) 182 | return new_images 183 | 184 | 185 | def tokenizer_image_token(prompt, tokenizer, image_token_index=IMAGE_TOKEN_INDEX, return_tensors=None): 186 | prompt_chunks = [tokenizer(chunk).input_ids for chunk in prompt.split('')] 187 | 188 | def insert_separator(X, sep): 189 | return [ele for sublist in zip(X, [sep]*len(X)) for ele in sublist][:-1] 190 | 191 | input_ids = [] 192 | offset = 0 193 | if len(prompt_chunks) > 0 and len(prompt_chunks[0]) > 0 and prompt_chunks[0][0] == tokenizer.bos_token_id: 194 | offset = 1 195 | input_ids.append(prompt_chunks[0][0]) 196 | 197 | for x in 
insert_separator(prompt_chunks, [image_token_index] * (offset + 1)): 198 | input_ids.extend(x[offset:]) 199 | 200 | if return_tensors is not None: 201 | if return_tensors == 'pt': 202 | return torch.tensor(input_ids, dtype=torch.long) 203 | raise ValueError(f'Unsupported tensor type: {return_tensors}') 204 | return input_ids 205 | 206 | 207 | def tokenizer_multiple_token(prompt, tokenizer, target_token_indenx=TARGET_TOKEN_INDEX, return_tensors=None): 208 | 209 | input_ids = [] 210 | target_chunks = prompt.split("") 211 | for target_idx, target_ck in enumerate(target_chunks): 212 | _inputs = tokenizer_image_token(target_ck, tokenizer, IMAGE_TOKEN_INDEX, return_tensors=return_tensors) 213 | input_ids.extend(_inputs) 214 | if target_idx < len(target_ck) - 1: 215 | input_ids.append(TARGET_TOKEN_INDEX) 216 | 217 | if return_tensors is not None: 218 | if return_tensors == "pt": 219 | return torch.tensor(input_ids, dtype=torch.long) 220 | raise ValueError(f"Unsupported tensor type: {return_tensors}") 221 | 222 | return input_ids 223 | 224 | def get_model_name_from_path(model_path): 225 | model_path = model_path.strip("/") 226 | model_paths = model_path.split("/") 227 | if model_paths[-1].startswith('checkpoint-'): 228 | return model_paths[-2] + "_" + model_paths[-1] 229 | else: 230 | return model_paths[-1] 231 | 232 | class KeywordsStoppingCriteria(StoppingCriteria): 233 | def __init__(self, keywords, tokenizer, input_ids): 234 | self.keywords = keywords 235 | self.keyword_ids = [] 236 | self.max_keyword_len = 0 237 | for keyword in keywords: 238 | cur_keyword_ids = tokenizer(keyword).input_ids 239 | if len(cur_keyword_ids) > 1 and cur_keyword_ids[0] == tokenizer.bos_token_id: 240 | cur_keyword_ids = cur_keyword_ids[1:] 241 | if len(cur_keyword_ids) > self.max_keyword_len: 242 | self.max_keyword_len = len(cur_keyword_ids) 243 | self.keyword_ids.append(torch.tensor(cur_keyword_ids)) 244 | self.tokenizer = tokenizer 245 | self.start_len = input_ids.shape[1] 246 | 247 | def call_for_batch(self, output_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool: 248 | offset = min(output_ids.shape[1] - self.start_len, self.max_keyword_len) 249 | self.keyword_ids = [keyword_id.to(output_ids.device) for keyword_id in self.keyword_ids] 250 | for keyword_id in self.keyword_ids: 251 | truncated_output_ids = output_ids[0, -keyword_id.shape[0]:] 252 | if torch.equal(truncated_output_ids, keyword_id): 253 | return True 254 | outputs = self.tokenizer.batch_decode(output_ids[:, -offset:], skip_special_tokens=True)[0] 255 | for keyword in self.keywords: 256 | if keyword in outputs: 257 | return True 258 | return False 259 | 260 | def __call__(self, output_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool: 261 | outputs = [] 262 | for i in range(output_ids.shape[0]): 263 | outputs.append(self.call_for_batch(output_ids[i].unsqueeze(0), scores)) 264 | return all(outputs) 265 | -------------------------------------------------------------------------------- /src/model/__init__.py: -------------------------------------------------------------------------------- 1 | try: 2 | from .language_model.setokim_llama import SetokimLlamaForCausalLM, SetokimConfig 3 | from .setok.model import SeTok 4 | except: 5 | pass 6 | -------------------------------------------------------------------------------- /src/model/apply_delta.py: -------------------------------------------------------------------------------- 1 | """ 2 | Usage: 3 | python3 -m fastchat.model.apply_delta --base ~/model_weights/llama-7b 
--target ~/model_weights/vicuna-7b --delta lmsys/vicuna-7b-delta 4 | """ 5 | import argparse 6 | 7 | import torch 8 | from tqdm import tqdm 9 | from transformers import AutoTokenizer, AutoModelForCausalLM 10 | from src import SetokimLlamaForCausalLM 11 | 12 | 13 | def apply_delta(base_model_path, target_model_path, delta_path): 14 | print("Loading base model") 15 | base = AutoModelForCausalLM.from_pretrained( 16 | base_model_path, torch_dtype=torch.float16, low_cpu_mem_usage=True) 17 | 18 | print("Loading delta") 19 | delta = SetokimLlamaForCausalLM.from_pretrained(delta_path, torch_dtype=torch.float16, low_cpu_mem_usage=True) 20 | delta_tokenizer = AutoTokenizer.from_pretrained(delta_path) 21 | 22 | print("Applying delta") 23 | for name, param in tqdm(delta.state_dict().items(), desc="Applying delta"): 24 | if name not in base.state_dict(): 25 | assert name in ['model.mm_projector.weight', 'model.mm_projector.bias'], f'{name} not in base model' 26 | continue 27 | if param.data.shape == base.state_dict()[name].shape: 28 | param.data += base.state_dict()[name] 29 | else: 30 | assert name in ['model.embed_tokens.weight', 'lm_head.weight'], \ 31 | f'{name} dimension mismatch: {param.data.shape} vs {base.state_dict()[name].shape}' 32 | bparam = base.state_dict()[name] 33 | param.data[:bparam.shape[0], :bparam.shape[1]] += bparam 34 | 35 | print("Saving target model") 36 | delta.save_pretrained(target_model_path) 37 | delta_tokenizer.save_pretrained(target_model_path) 38 | 39 | 40 | if __name__ == "__main__": 41 | parser = argparse.ArgumentParser() 42 | parser.add_argument("--base-model-path", type=str, required=True) 43 | parser.add_argument("--target-model-path", type=str, required=True) 44 | parser.add_argument("--delta-path", type=str, required=True) 45 | 46 | args = parser.parse_args() 47 | 48 | apply_delta(args.base_model_path, args.target_model_path, args.delta_path) 49 | -------------------------------------------------------------------------------- /src/model/builder.py: -------------------------------------------------------------------------------- 1 | # Copyright 2023 Haotian Liu 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
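# Illustrative usage of `load_pretrained_model` defined below. The paths and the
# model name are hypothetical placeholders, not shipped checkpoints:
#
#   tokenizer, model, image_processor, context_len = load_pretrained_model(
#       model_path="./checkpoints/setokim-7b-lora",   # hypothetical LoRA checkpoint dir
#       model_base="./checkpoints/base-llama-7b",     # base weights required when merging LoRA
#       model_name="setokim-7b-lora")                 # the name selects the loading branch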
14 | 15 | 16 | import os 17 | import warnings 18 | import shutil 19 | 20 | from transformers import AutoTokenizer, AutoModelForCausalLM, AutoConfig, BitsAndBytesConfig 21 | import torch 22 | from src.model import * 23 | from src.constants import DEFAULT_IMAGE_PATCH_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN 24 | 25 | 26 | def load_pretrained_model(model_path, model_base, model_name, load_8bit=False, load_4bit=False, device_map="auto", device="cuda", use_flash_attn=False, **kwargs): 27 | kwargs = {"device_map": device_map, **kwargs} 28 | 29 | if device != "cuda": 30 | kwargs['device_map'] = {"": device} 31 | 32 | if load_8bit: 33 | kwargs['load_in_8bit'] = True 34 | elif load_4bit: 35 | kwargs['load_in_4bit'] = True 36 | kwargs['quantization_config'] = BitsAndBytesConfig( 37 | load_in_4bit=True, 38 | bnb_4bit_compute_dtype=torch.float16, 39 | bnb_4bit_use_double_quant=True, 40 | bnb_4bit_quant_type='nf4' 41 | ) 42 | else: 43 | kwargs['torch_dtype'] = torch.float16 44 | 45 | if use_flash_attn: 46 | kwargs['attn_implementation'] = 'flash_attention_2' 47 | 48 | if 'setokim' in model_name.lower(): 49 | # Load Setokim model 50 | if 'lora' in model_name.lower() and model_base is None: 51 | warnings.warn('There is `lora` in model name but no `model_base` is provided. If you are loading a LoRA model, please provide the `model_base` argument. Detailed instruction: https://github.com/haotian-liu/Setokim#launch-a-model-worker-lora-weights-unmerged.') 52 | if 'lora' in model_name.lower() and model_base is not None: 53 | from src.model.language_model.setokim_llama import SetokimConfig 54 | lora_cfg_pretrained = SetokimConfig.from_pretrained(model_path) 55 | tokenizer = AutoTokenizer.from_pretrained(model_base, use_fast=False) 56 | print('Loading Setokim from base model...') 57 | model = SetokimLlamaForCausalLM.from_pretrained(model_base, low_cpu_mem_usage=True, config=lora_cfg_pretrained, **kwargs) 58 | token_num, tokem_dim = model.lm_head.out_features, model.lm_head.in_features 59 | if model.lm_head.weight.shape[0] != token_num: 60 | model.lm_head.weight = torch.nn.Parameter(torch.empty(token_num, tokem_dim, device=model.device, dtype=model.dtype)) 61 | model.model.embed_tokens.weight = torch.nn.Parameter(torch.empty(token_num, tokem_dim, device=model.device, dtype=model.dtype)) 62 | 63 | print('Loading additional Setokim weights...') 64 | if os.path.exists(os.path.join(model_path, 'non_lora_trainables.bin')): 65 | non_lora_trainables = torch.load(os.path.join(model_path, 'non_lora_trainables.bin'), map_location='cpu') 66 | else: 67 | # this is probably from HF Hub 68 | from huggingface_hub import hf_hub_download 69 | def load_from_hf(repo_id, filename, subfolder=None): 70 | cache_file = hf_hub_download( 71 | repo_id=repo_id, 72 | filename=filename, 73 | subfolder=subfolder) 74 | return torch.load(cache_file, map_location='cpu') 75 | non_lora_trainables = load_from_hf(model_path, 'non_lora_trainables.bin') 76 | non_lora_trainables = {(k[11:] if k.startswith('base_model.') else k): v for k, v in non_lora_trainables.items()} 77 | if any(k.startswith('model.model.') for k in non_lora_trainables): 78 | non_lora_trainables = {(k[6:] if k.startswith('model.') else k): v for k, v in non_lora_trainables.items()} 79 | model.load_state_dict(non_lora_trainables, strict=False) 80 | 81 | from peft import PeftModel 82 | print('Loading LoRA weights...') 83 | model = PeftModel.from_pretrained(model, model_path) 84 | print('Merging LoRA weights...') 85 | model = model.merge_and_unload() 86 | print('Model is 
loaded...') 87 | elif model_base is not None: 88 | # this may be mm projector only 89 | print('Loading Setokim from base model...') 90 | tokenizer = AutoTokenizer.from_pretrained(model_base, use_fast=False) 91 | cfg_pretrained = AutoConfig.from_pretrained(model_path) 92 | model = SetokimLlamaForCausalLM.from_pretrained(model_base, low_cpu_mem_usage=True, config=cfg_pretrained, **kwargs) 93 | 94 | mm_projector_weights = torch.load(os.path.join(model_path, 'mm_projector.bin'), map_location='cpu') 95 | mm_projector_weights = {k: v.to(torch.float16) for k, v in mm_projector_weights.items()} 96 | model.load_state_dict(mm_projector_weights, strict=False) 97 | else: 98 | tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=False) 99 | model = SetokimLlamaForCausalLM.from_pretrained( 100 | model_path, 101 | low_cpu_mem_usage=True, 102 | **kwargs 103 | ) 104 | else: 105 | # Load language model 106 | if model_base is not None: 107 | # PEFT model 108 | from peft import PeftModel 109 | tokenizer = AutoTokenizer.from_pretrained(model_base, use_fast=False) 110 | model = AutoModelForCausalLM.from_pretrained(model_base, low_cpu_mem_usage=True, **kwargs) 111 | print(f"Loading LoRA weights from {model_path}") 112 | model = PeftModel.from_pretrained(model, model_path) 113 | print(f"Merging weights") 114 | model = model.merge_and_unload() 115 | print('Convert to FP16...') 116 | model.to(torch.float16) 117 | else: 118 | use_fast = False 119 | tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=False) 120 | model = AutoModelForCausalLM.from_pretrained(model_path, low_cpu_mem_usage=True, **kwargs) 121 | 122 | image_processor = None 123 | 124 | if 'Setokim' in model_name.lower(): 125 | mm_use_im_start_end = getattr(model.config, "mm_use_im_start_end", False) 126 | mm_use_im_patch_token = getattr(model.config, "mm_use_im_patch_token", True) 127 | if mm_use_im_patch_token: 128 | tokenizer.add_tokens([DEFAULT_IMAGE_PATCH_TOKEN], special_tokens=True) 129 | if mm_use_im_start_end: 130 | tokenizer.add_tokens([DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN], special_tokens=True) 131 | model.resize_token_embeddings(len(tokenizer)) 132 | 133 | vision_tower = model.get_vision_tower() 134 | if not vision_tower.is_loaded: 135 | vision_tower.load_model(device_map=device_map) 136 | if device_map != 'auto': 137 | vision_tower.to(device=device_map, dtype=torch.float16) 138 | image_processor = vision_tower.image_processor 139 | 140 | if hasattr(model.config, "max_sequence_length"): 141 | context_len = model.config.max_sequence_length 142 | else: 143 | context_len = 2048 144 | 145 | return tokenizer, model, image_processor, context_len 146 | -------------------------------------------------------------------------------- /src/model/consolidate.py: -------------------------------------------------------------------------------- 1 | """ 2 | Usage: 3 | python3 -m llava.model.consolidate --src ~/model_weights/llava-7b --dst ~/model_weights/llava-7b_consolidate 4 | """ 5 | import argparse 6 | 7 | import torch 8 | from transformers import AutoTokenizer, AutoModelForCausalLM 9 | from src.model import * 10 | from src.model.utils import auto_upgrade 11 | 12 | 13 | def consolidate_ckpt(src_path, dst_path): 14 | print("Loading model") 15 | auto_upgrade(src_path) 16 | src_model = AutoModelForCausalLM.from_pretrained(src_path, torch_dtype=torch.float16, low_cpu_mem_usage=True) 17 | src_tokenizer = AutoTokenizer.from_pretrained(src_path, use_fast=False) 18 | src_model.save_pretrained(dst_path) 19 | 
src_tokenizer.save_pretrained(dst_path) 20 | 21 | 22 | if __name__ == "__main__": 23 | parser = argparse.ArgumentParser() 24 | parser.add_argument("--src", type=str, required=True) 25 | parser.add_argument("--dst", type=str, required=True) 26 | 27 | args = parser.parse_args() 28 | 29 | consolidate_ckpt(args.src, args.dst) 30 | -------------------------------------------------------------------------------- /src/model/diffusion/__init__.py: -------------------------------------------------------------------------------- 1 | # Adopted from DiT, which is modified from OpenAI's diffusion repos 2 | # DiT: https://github.com/facebookresearch/DiT/diffusion 3 | # GLIDE: https://github.com/openai/glide-text2im/blob/main/glide_text2im/gaussian_diffusion.py 4 | # ADM: https://github.com/openai/guided-diffusion/blob/main/guided_diffusion 5 | # IDDPM: https://github.com/openai/improved-diffusion/blob/main/improved_diffusion/gaussian_diffusion.py 6 | 7 | from . import gaussian_diffusion as gd 8 | from .respace import SpacedDiffusion, space_timesteps 9 | 10 | 11 | def create_diffusion( 12 | timestep_respacing, 13 | noise_schedule="linear", 14 | use_kl=False, 15 | sigma_small=False, 16 | predict_xstart=False, 17 | learn_sigma=True, 18 | rescale_learned_sigmas=False, 19 | diffusion_steps=1000 20 | ): 21 | betas = gd.get_named_beta_schedule(noise_schedule, diffusion_steps) 22 | if use_kl: 23 | loss_type = gd.LossType.RESCALED_KL 24 | elif rescale_learned_sigmas: 25 | loss_type = gd.LossType.RESCALED_MSE 26 | else: 27 | loss_type = gd.LossType.MSE 28 | if timestep_respacing is None or timestep_respacing == "": 29 | timestep_respacing = [diffusion_steps] 30 | return SpacedDiffusion( 31 | use_timesteps=space_timesteps(diffusion_steps, timestep_respacing), 32 | betas=betas, 33 | model_mean_type=( 34 | gd.ModelMeanType.EPSILON if not predict_xstart else gd.ModelMeanType.START_X 35 | ), 36 | model_var_type=( 37 | ( 38 | gd.ModelVarType.FIXED_LARGE 39 | if not sigma_small 40 | else gd.ModelVarType.FIXED_SMALL 41 | ) 42 | if not learn_sigma 43 | else gd.ModelVarType.LEARNED_RANGE 44 | ), 45 | loss_type=loss_type 46 | # rescale_timesteps=rescale_timesteps, 47 | ) 48 | -------------------------------------------------------------------------------- /src/model/diffusion/diffusion_utils.py: -------------------------------------------------------------------------------- 1 | # Modified from OpenAI's diffusion repos 2 | # GLIDE: https://github.com/openai/glide-text2im/blob/main/glide_text2im/gaussian_diffusion.py 3 | # ADM: https://github.com/openai/guided-diffusion/blob/main/guided_diffusion 4 | # IDDPM: https://github.com/openai/improved-diffusion/blob/main/improved_diffusion/gaussian_diffusion.py 5 | 6 | import torch as th 7 | import numpy as np 8 | 9 | 10 | def normal_kl(mean1, logvar1, mean2, logvar2): 11 | """ 12 | Compute the KL divergence between two gaussians. 13 | Shapes are automatically broadcasted, so batches can be compared to 14 | scalars, among other use cases. 15 | """ 16 | tensor = None 17 | for obj in (mean1, logvar1, mean2, logvar2): 18 | if isinstance(obj, th.Tensor): 19 | tensor = obj 20 | break 21 | assert tensor is not None, "at least one argument must be a Tensor" 22 | 23 | # Force variances to be Tensors. Broadcasting helps convert scalars to 24 | # Tensors, but it does not work for th.exp(). 
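    # In closed form, the expression returned below evaluates elementwise to
    #   KL(N(mean1, exp(logvar1)) || N(mean2, exp(logvar2)))
    #     = 0.5 * (logvar2 - logvar1 + exp(logvar1 - logvar2)
    #              + (mean1 - mean2)^2 * exp(-logvar2) - 1),
    # i.e. the standard Gaussian-to-Gaussian KL divergence, in nats.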
25 | logvar1, logvar2 = [ 26 | x if isinstance(x, th.Tensor) else th.tensor(x).to(tensor) 27 | for x in (logvar1, logvar2) 28 | ] 29 | 30 | return 0.5 * ( 31 | -1.0 32 | + logvar2 33 | - logvar1 34 | + th.exp(logvar1 - logvar2) 35 | + ((mean1 - mean2) ** 2) * th.exp(-logvar2) 36 | ) 37 | 38 | 39 | def approx_standard_normal_cdf(x): 40 | """ 41 | A fast approximation of the cumulative distribution function of the 42 | standard normal. 43 | """ 44 | return 0.5 * (1.0 + th.tanh(np.sqrt(2.0 / np.pi) * (x + 0.044715 * th.pow(x, 3)))) 45 | 46 | 47 | def discretized_gaussian_log_likelihood(x, *, means, log_scales): 48 | """ 49 | Compute the log-likelihood of a Gaussian distribution discretizing to a 50 | given image. 51 | :param x: the target images. It is assumed that this was uint8 values, 52 | rescaled to the range [-1, 1]. 53 | :param means: the Gaussian mean Tensor. 54 | :param log_scales: the Gaussian log stddev Tensor. 55 | :return: a tensor like x of log probabilities (in nats). 56 | """ 57 | assert x.shape == means.shape == log_scales.shape 58 | centered_x = x - means 59 | inv_stdv = th.exp(-log_scales) 60 | plus_in = inv_stdv * (centered_x + 1.0 / 255.0) 61 | cdf_plus = approx_standard_normal_cdf(plus_in) 62 | min_in = inv_stdv * (centered_x - 1.0 / 255.0) 63 | cdf_min = approx_standard_normal_cdf(min_in) 64 | log_cdf_plus = th.log(cdf_plus.clamp(min=1e-12)) 65 | log_one_minus_cdf_min = th.log((1.0 - cdf_min).clamp(min=1e-12)) 66 | cdf_delta = cdf_plus - cdf_min 67 | log_probs = th.where( 68 | x < -0.999, 69 | log_cdf_plus, 70 | th.where(x > 0.999, log_one_minus_cdf_min, th.log(cdf_delta.clamp(min=1e-12))), 71 | ) 72 | assert log_probs.shape == x.shape 73 | return log_probs 74 | -------------------------------------------------------------------------------- /src/model/diffusion/respace.py: -------------------------------------------------------------------------------- 1 | # Modified from OpenAI's diffusion repos 2 | # GLIDE: https://github.com/openai/glide-text2im/blob/main/glide_text2im/gaussian_diffusion.py 3 | # ADM: https://github.com/openai/guided-diffusion/blob/main/guided_diffusion 4 | # IDDPM: https://github.com/openai/improved-diffusion/blob/main/improved_diffusion/gaussian_diffusion.py 5 | 6 | import numpy as np 7 | import torch as th 8 | 9 | from .gaussian_diffusion import GaussianDiffusion 10 | 11 | 12 | def space_timesteps(num_timesteps, section_counts): 13 | """ 14 | Create a list of timesteps to use from an original diffusion process, 15 | given the number of timesteps we want to take from equally-sized portions 16 | of the original process. 17 | For example, if there's 300 timesteps and the section counts are [10,15,20] 18 | then the first 100 timesteps are strided to be 10 timesteps, the second 100 19 | are strided to be 15 timesteps, and the final 100 are strided to be 20. 20 | If the stride is a string starting with "ddim", then the fixed striding 21 | from the DDIM paper is used, and only one section is allowed. 22 | :param num_timesteps: the number of diffusion steps in the original 23 | process to divide up. 24 | :param section_counts: either a list of numbers, or a string containing 25 | comma-separated numbers, indicating the step count 26 | per section. As a special case, use "ddimN" where N 27 | is a number of steps to use the striding from the 28 | DDIM paper. 29 | :return: a set of diffusion steps from the original process to use. 
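    Example (illustrative):
        >>> sorted(space_timesteps(100, [10]))
        [0, 11, 22, 33, 44, 55, 66, 77, 88, 99]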
30 | """ 31 | if isinstance(section_counts, str): 32 | if section_counts.startswith("ddim"): 33 | desired_count = int(section_counts[len("ddim") :]) 34 | for i in range(1, num_timesteps): 35 | if len(range(0, num_timesteps, i)) == desired_count: 36 | return set(range(0, num_timesteps, i)) 37 | raise ValueError( 38 | f"cannot create exactly {num_timesteps} steps with an integer stride" 39 | ) 40 | section_counts = [int(x) for x in section_counts.split(",")] 41 | size_per = num_timesteps // len(section_counts) 42 | extra = num_timesteps % len(section_counts) 43 | start_idx = 0 44 | all_steps = [] 45 | for i, section_count in enumerate(section_counts): 46 | size = size_per + (1 if i < extra else 0) 47 | if size < section_count: 48 | raise ValueError( 49 | f"cannot divide section of {size} steps into {section_count}" 50 | ) 51 | if section_count <= 1: 52 | frac_stride = 1 53 | else: 54 | frac_stride = (size - 1) / (section_count - 1) 55 | cur_idx = 0.0 56 | taken_steps = [] 57 | for _ in range(section_count): 58 | taken_steps.append(start_idx + round(cur_idx)) 59 | cur_idx += frac_stride 60 | all_steps += taken_steps 61 | start_idx += size 62 | return set(all_steps) 63 | 64 | 65 | class SpacedDiffusion(GaussianDiffusion): 66 | """ 67 | A diffusion process which can skip steps in a base diffusion process. 68 | :param use_timesteps: a collection (sequence or set) of timesteps from the 69 | original diffusion process to retain. 70 | :param kwargs: the kwargs to create the base diffusion process. 71 | """ 72 | 73 | def __init__(self, use_timesteps, **kwargs): 74 | self.use_timesteps = set(use_timesteps) 75 | self.timestep_map = [] 76 | self.original_num_steps = len(kwargs["betas"]) 77 | 78 | base_diffusion = GaussianDiffusion(**kwargs) # pylint: disable=missing-kwoa 79 | last_alpha_cumprod = 1.0 80 | new_betas = [] 81 | for i, alpha_cumprod in enumerate(base_diffusion.alphas_cumprod): 82 | if i in self.use_timesteps: 83 | new_betas.append(1 - alpha_cumprod / last_alpha_cumprod) 84 | last_alpha_cumprod = alpha_cumprod 85 | self.timestep_map.append(i) 86 | kwargs["betas"] = np.array(new_betas) 87 | super().__init__(**kwargs) 88 | 89 | def p_mean_variance( 90 | self, model, *args, **kwargs 91 | ): # pylint: disable=signature-differs 92 | return super().p_mean_variance(self._wrap_model(model), *args, **kwargs) 93 | 94 | def training_losses( 95 | self, model, *args, **kwargs 96 | ): # pylint: disable=signature-differs 97 | return super().training_losses(self._wrap_model(model), *args, **kwargs) 98 | 99 | def condition_mean(self, cond_fn, *args, **kwargs): 100 | return super().condition_mean(self._wrap_model(cond_fn), *args, **kwargs) 101 | 102 | def condition_score(self, cond_fn, *args, **kwargs): 103 | return super().condition_score(self._wrap_model(cond_fn), *args, **kwargs) 104 | 105 | def _wrap_model(self, model): 106 | if isinstance(model, _WrappedModel): 107 | return model 108 | return _WrappedModel( 109 | model, self.timestep_map, self.original_num_steps 110 | ) 111 | 112 | def _scale_timesteps(self, t): 113 | # Scaling is done by the wrapped model. 
114 | return t 115 | 116 | 117 | class _WrappedModel: 118 | def __init__(self, model, timestep_map, original_num_steps): 119 | self.model = model 120 | self.timestep_map = timestep_map 121 | # self.rescale_timesteps = rescale_timesteps 122 | self.original_num_steps = original_num_steps 123 | 124 | def __call__(self, x, ts, **kwargs): 125 | map_tensor = th.tensor(self.timestep_map, device=ts.device, dtype=ts.dtype) 126 | new_ts = map_tensor[ts] 127 | # if self.rescale_timesteps: 128 | # new_ts = new_ts.float() * (1000.0 / self.original_num_steps) 129 | return self.model(x, new_ts, **kwargs) 130 | -------------------------------------------------------------------------------- /src/model/loss/__init__.py: -------------------------------------------------------------------------------- 1 | 2 | from .diffloss import DiffLoss 3 | from .mse import WeightedMSELoss 4 | from .perceptual import LPIPS 5 | from .discriminator import GANLoss 6 | from .multilabel_constrastive import MultilabelContrastiveLoss -------------------------------------------------------------------------------- /src/model/loss/diffloss.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from torch.utils.checkpoint import checkpoint 4 | import math 5 | 6 | from ..diffusion import create_diffusion 7 | 8 | 9 | class DiffLoss(nn.Module): 10 | """Diffusion Loss""" 11 | def __init__(self, target_channels, z_channels, depth, width, num_sampling_steps, grad_checkpointing=False): 12 | super(DiffLoss, self).__init__() 13 | self.in_channels = target_channels 14 | self.net = SimpleMLPAdaLN( 15 | in_channels=target_channels, 16 | model_channels=width, 17 | out_channels=target_channels * 2, # for vlb loss 18 | z_channels=z_channels, 19 | num_res_blocks=depth, 20 | grad_checkpointing=grad_checkpointing 21 | ) 22 | 23 | self.train_diffusion = create_diffusion(timestep_respacing="", noise_schedule="cosine") 24 | self.gen_diffusion = create_diffusion(timestep_respacing=num_sampling_steps, noise_schedule="cosine") 25 | 26 | def forward(self, target, z, mask=None): 27 | t = torch.randint(0, self.train_diffusion.num_timesteps, (target.shape[0],), device=target.device) 28 | model_kwargs = dict(c=z) 29 | loss_dict = self.train_diffusion.training_losses(self.net, target, t, model_kwargs) 30 | loss = loss_dict["loss"] 31 | if mask is not None: 32 | loss = (loss * mask).sum() / mask.sum() 33 | return loss.mean() 34 | 35 | def sample(self, z, temperature=1.0, cfg=1.0): 36 | # diffusion loss sampling 37 | if not cfg == 1.0: 38 | noise = torch.randn(z.shape[0] // 2, self.in_channels).cuda() 39 | noise = torch.cat([noise, noise], dim=0) 40 | model_kwargs = dict(c=z, cfg_scale=cfg) 41 | sample_fn = self.net.forward_with_cfg 42 | else: 43 | noise = torch.randn(z.shape[0], self.in_channels).cuda() 44 | model_kwargs = dict(c=z) 45 | sample_fn = self.net.forward 46 | 47 | sampled_token_latent = self.gen_diffusion.p_sample_loop( 48 | sample_fn, noise.shape, noise, clip_denoised=False, model_kwargs=model_kwargs, progress=False, 49 | temperature=temperature 50 | ) 51 | 52 | return sampled_token_latent 53 | 54 | 55 | def modulate(x, shift, scale): 56 | return x * (1 + scale) + shift 57 | 58 | 59 | class TimestepEmbedder(nn.Module): 60 | """ 61 | Embeds scalar timesteps into vector representations. 
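    Example (illustrative; the hidden size is arbitrary):
        >>> emb = TimestepEmbedder(hidden_size=1024)
        >>> emb(torch.arange(4)).shape
        torch.Size([4, 1024])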
62 | """ 63 | def __init__(self, hidden_size, frequency_embedding_size=256): 64 | super().__init__() 65 | self.mlp = nn.Sequential( 66 | nn.Linear(frequency_embedding_size, hidden_size, bias=True), 67 | nn.SiLU(), 68 | nn.Linear(hidden_size, hidden_size, bias=True), 69 | ) 70 | self.frequency_embedding_size = frequency_embedding_size 71 | 72 | @staticmethod 73 | def timestep_embedding(t, dim, max_period=10000): 74 | """ 75 | Create sinusoidal timestep embeddings. 76 | :param t: a 1-D Tensor of N indices, one per batch element. 77 | These may be fractional. 78 | :param dim: the dimension of the output. 79 | :param max_period: controls the minimum frequency of the embeddings. 80 | :return: an (N, D) Tensor of positional embeddings. 81 | """ 82 | # https://github.com/openai/glide-text2im/blob/main/glide_text2im/nn.py 83 | half = dim // 2 84 | freqs = torch.exp( 85 | -math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half 86 | ).to(device=t.device) 87 | args = t[:, None].float() * freqs[None] 88 | embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1) 89 | if dim % 2: 90 | embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1) 91 | return embedding 92 | 93 | def forward(self, t): 94 | t_freq = self.timestep_embedding(t, self.frequency_embedding_size) 95 | t_emb = self.mlp(t_freq) 96 | return t_emb 97 | 98 | 99 | class ResBlock(nn.Module): 100 | """ 101 | A residual block that can optionally change the number of channels. 102 | :param channels: the number of input channels. 103 | """ 104 | 105 | def __init__( 106 | self, 107 | channels 108 | ): 109 | super().__init__() 110 | self.channels = channels 111 | 112 | self.in_ln = nn.LayerNorm(channels, eps=1e-6) 113 | self.mlp = nn.Sequential( 114 | nn.Linear(channels, channels, bias=True), 115 | nn.SiLU(), 116 | nn.Linear(channels, channels, bias=True), 117 | ) 118 | 119 | self.adaLN_modulation = nn.Sequential( 120 | nn.SiLU(), 121 | nn.Linear(channels, 3 * channels, bias=True) 122 | ) 123 | 124 | def forward(self, x, y): 125 | shift_mlp, scale_mlp, gate_mlp = self.adaLN_modulation(y).chunk(3, dim=-1) 126 | h = modulate(self.in_ln(x), shift_mlp, scale_mlp) 127 | h = self.mlp(h) 128 | return x + gate_mlp * h 129 | 130 | 131 | class FinalLayer(nn.Module): 132 | """ 133 | The final layer adopted from DiT. 134 | """ 135 | def __init__(self, model_channels, out_channels): 136 | super().__init__() 137 | self.norm_final = nn.LayerNorm(model_channels, elementwise_affine=False, eps=1e-6) 138 | self.linear = nn.Linear(model_channels, out_channels, bias=True) 139 | self.adaLN_modulation = nn.Sequential( 140 | nn.SiLU(), 141 | nn.Linear(model_channels, 2 * model_channels, bias=True) 142 | ) 143 | 144 | def forward(self, x, c): 145 | shift, scale = self.adaLN_modulation(c).chunk(2, dim=-1) 146 | x = modulate(self.norm_final(x), shift, scale) 147 | x = self.linear(x) 148 | return x 149 | 150 | 151 | class SimpleMLPAdaLN(nn.Module): 152 | """ 153 | The MLP for Diffusion Loss. 154 | :param in_channels: channels in the input Tensor. 155 | :param model_channels: base channel count for the model. 156 | :param out_channels: channels in the output Tensor. 157 | :param z_channels: channels in the condition. 158 | :param num_res_blocks: number of residual blocks per downsample. 
159 | """ 160 | 161 | def __init__( 162 | self, 163 | in_channels, 164 | model_channels, 165 | out_channels, 166 | z_channels, 167 | num_res_blocks, 168 | grad_checkpointing=False 169 | ): 170 | super().__init__() 171 | 172 | self.in_channels = in_channels 173 | self.model_channels = model_channels 174 | self.out_channels = out_channels 175 | self.num_res_blocks = num_res_blocks 176 | self.grad_checkpointing = grad_checkpointing 177 | 178 | self.time_embed = TimestepEmbedder(model_channels) 179 | self.cond_embed = nn.Linear(z_channels, model_channels) 180 | 181 | self.input_proj = nn.Linear(in_channels, model_channels) 182 | 183 | res_blocks = [] 184 | for i in range(num_res_blocks): 185 | res_blocks.append(ResBlock( 186 | model_channels, 187 | )) 188 | 189 | self.res_blocks = nn.ModuleList(res_blocks) 190 | self.final_layer = FinalLayer(model_channels, out_channels) 191 | 192 | self.initialize_weights() 193 | 194 | def initialize_weights(self): 195 | def _basic_init(module): 196 | if isinstance(module, nn.Linear): 197 | torch.nn.init.xavier_uniform_(module.weight) 198 | if module.bias is not None: 199 | nn.init.constant_(module.bias, 0) 200 | self.apply(_basic_init) 201 | 202 | # Initialize timestep embedding MLP 203 | nn.init.normal_(self.time_embed.mlp[0].weight, std=0.02) 204 | nn.init.normal_(self.time_embed.mlp[2].weight, std=0.02) 205 | 206 | # Zero-out adaLN modulation layers 207 | for block in self.res_blocks: 208 | nn.init.constant_(block.adaLN_modulation[-1].weight, 0) 209 | nn.init.constant_(block.adaLN_modulation[-1].bias, 0) 210 | 211 | # Zero-out output layers 212 | nn.init.constant_(self.final_layer.adaLN_modulation[-1].weight, 0) 213 | nn.init.constant_(self.final_layer.adaLN_modulation[-1].bias, 0) 214 | nn.init.constant_(self.final_layer.linear.weight, 0) 215 | nn.init.constant_(self.final_layer.linear.bias, 0) 216 | 217 | def forward(self, x, t, c): 218 | """ 219 | Apply the model to an input batch. 220 | :param x: an [N x C] Tensor of inputs. 221 | :param t: a 1-D batch of timesteps. 222 | :param c: conditioning from AR transformer. 223 | :return: an [N x C] Tensor of outputs. 
224 | """ 225 | x = self.input_proj(x) 226 | t = self.time_embed(t) 227 | c = self.cond_embed(c) 228 | 229 | y = t + c 230 | 231 | if self.grad_checkpointing and not torch.jit.is_scripting(): 232 | for block in self.res_blocks: 233 | x = checkpoint(block, x, y) 234 | else: 235 | for block in self.res_blocks: 236 | x = block(x, y) 237 | 238 | return self.final_layer(x, y) 239 | 240 | def forward_with_cfg(self, x, t, c, cfg_scale): 241 | half = x[: len(x) // 2] 242 | combined = torch.cat([half, half], dim=0) 243 | model_out = self.forward(combined, t, c) 244 | eps, rest = model_out[:, :self.in_channels], model_out[:, self.in_channels:] 245 | cond_eps, uncond_eps = torch.split(eps, len(eps) // 2, dim=0) 246 | half_eps = uncond_eps + cfg_scale * (cond_eps - uncond_eps) 247 | eps = torch.cat([half_eps, half_eps], dim=0) 248 | return torch.cat([eps, rest], dim=1) -------------------------------------------------------------------------------- /src/model/loss/mse.py: -------------------------------------------------------------------------------- 1 | from torch import nn 2 | 3 | class WeightedMSELoss(nn.Module): 4 | def __init__(self, weight=1.0): 5 | super().__init__() 6 | self.mse = nn.MSELoss(reduction='none') 7 | self.weight = weight 8 | 9 | def forward(self, pred, target, loss_mask=None, loss_weight=None): 10 | mse_loss = self.mse(pred, target) 11 | if loss_mask is not None: 12 | mse_loss = (mse_loss * loss_mask) 13 | weight = loss_mask.sum([-2, -1]) 14 | weight += 1 15 | mse_loss = mse_loss.sum([-2, -1]) 16 | mse_loss = (mse_loss / weight) 17 | else: 18 | mse_loss = mse_loss.mean([-3, -2, -1]) 19 | return mse_loss.mean() * self.weight -------------------------------------------------------------------------------- /src/model/loss/multilabel_constrastive.py: -------------------------------------------------------------------------------- 1 | 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | from einops import rearrange, repeat 6 | import torch.distributed as dist 7 | import numpy as np 8 | from typing import Optional, Dict, List 9 | from timm.loss import SoftTargetCrossEntropy 10 | import diffdist.functional as diff_dist 11 | 12 | 13 | 14 | def dist_collect(x): 15 | """ collect all tensor from all GPUs 16 | args: 17 | x: shape (mini_batch, ...) 18 | returns: 19 | shape (mini_batch * num_gpu, ...) 
20 | """ 21 | x = x.contiguous() 22 | out_list = [torch.zeros_like(x, device=x.device, dtype=x.dtype).contiguous() for _ in range(dist.get_world_size())] 23 | out_list = diff_dist.all_gather(out_list, x) 24 | return torch.cat(out_list, dim=0).contiguous() 25 | 26 | 27 | class MultilabelContrastiveLoss(nn.Module): 28 | def __init__(self, 29 | text_encoder: nn.Module, 30 | contrast_temperature: Optional[float]=0.07, 31 | multi_label: Optional[int]=0, 32 | share_temperature: Optional[bool]=False, 33 | multi_label_loss_weight: Optional[float]=1.0, 34 | **kwargs) -> None: 35 | super().__init__(MultilabelContrastiveLoss) 36 | 37 | self.text_encoder = text_encoder 38 | self.contrast_temperature = contrast_temperature 39 | self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / contrast_temperature)) 40 | self.cross_entropy = nn.CrossEntropyLoss() 41 | self.soft_cross_entropy = SoftTargetCrossEntropy() 42 | 43 | self.multi_label = multi_label 44 | self.share_temperature = share_temperature 45 | if self.with_multi_label and not self.share_temperature: 46 | self.multi_label_logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / contrast_temperature)) 47 | self.multi_label_loss_weight = multi_label_loss_weight 48 | 49 | @property 50 | def with_multi_label(self): 51 | return self.multi_label > 0 52 | 53 | 54 | def loss(self, image_x, text_x): 55 | 56 | batch_size = image_x.shape[0] 57 | # get label globally 58 | labels = torch.arange(batch_size, dtype=torch.long, device=image_x.device) + batch_size * dist.get_rank() 59 | 60 | # [B, C] 61 | image_x = F.normalize(image_x, dim=-1) 62 | text_x = F.normalize(text_x, dim=-1) 63 | 64 | logits_per_img = image_x @ dist_collect(text_x).t() 65 | logits_per_text = text_x @ dist_collect(image_x).t() 66 | 67 | logit_scale = torch.clamp(self.logit_scale.exp(), max=100) 68 | loss_img = self.cross_entropy(logits_per_img * logit_scale, labels) 69 | loss_text = self.cross_entropy(logits_per_text * logit_scale, labels) 70 | 71 | loss = 0.5 * (loss_img + loss_text) 72 | 73 | return loss 74 | 75 | def multi_label_loss(self, image_feat, text_feat): 76 | """ 77 | 78 | Args: 79 | image_feat (torch.Tensor): shape [B, L1, C] 80 | text_feat (torch.Tensor): shape [B, L2, C] 81 | 82 | Returns: 83 | 84 | """ 85 | # [B, L1, C], L1 = 1 86 | image_feat = F.normalize(image_feat, dim=-1) 87 | # [B, L2, C] 88 | text_feat = F.normalize(text_feat, dim=-1) 89 | 90 | # [B, L1, L2] 91 | dist_per_img = image_feat @ rearrange(text_feat, 'b l c -> b c l') 92 | # [B, L2, L1] 93 | dist_per_text = text_feat @ rearrange(image_feat, 'b l c -> b c l') 94 | 95 | if self.share_temperature: 96 | logit_scale = torch.clamp(self.logit_scale.exp(), max=100) 97 | else: 98 | logit_scale = torch.clamp(self.multi_label_logit_scale.exp(), max=100) 99 | 100 | batch = image_feat.shape[0] 101 | img_len = image_feat.shape[1] 102 | text_len = text_feat.shape[1] 103 | # [B, L1, L2] 104 | pos_labels_batch_img = rearrange(torch.ones_like(dist_per_text) / dist_per_text.size(1), 'b l2 l1 -> b l1 l2') 105 | # [B, L2, L1] 106 | pos_labels_batch_text = rearrange(torch.ones_like(dist_per_img) / dist_per_img.size(1), 'b l1 l2 -> b l2 l1') 107 | 108 | image_x = rearrange(image_feat, 'b l c -> (b l) c') 109 | text_x = rearrange(text_feat, 'b l c -> (b l) c') 110 | 111 | logits_per_img = image_x @ dist_collect(text_x).t() 112 | logits_per_text = text_x @ dist_collect(image_x).t() 113 | 114 | # get label globally 115 | # [B, L1, B, L2, W] 116 | labels_per_img = F.one_hot( 117 | torch.ones(batch, img_len, batch, text_len, 
dtype=torch.long, device=image_x.device) * dist.get_rank(), 118 | num_classes=dist.get_world_size()).to(image_x.dtype) 119 | labels_per_img *= rearrange(pos_labels_batch_img, 'b l1 l2 -> b l1 1 l2 1') * repeat( 120 | torch.eye(batch, dtype=image_x.dtype, device=image_x.device), 'b1 b2 -> b1 1 b2 1 1') 121 | # [BxL1, WxBxL2] 122 | labels_per_img = rearrange(labels_per_img, 'b1 l1 b2 l2 w -> (b1 l1) (w b2 l2)') 123 | # [B, L2, B, L1, W] 124 | labels_per_text = F.one_hot( 125 | torch.ones(batch, text_len, batch, img_len, dtype=torch.long, device=text_x.device) * dist.get_rank(), 126 | num_classes=dist.get_world_size()).to(text_x.dtype) 127 | labels_per_text *= rearrange(pos_labels_batch_text, 'b l2 l1 -> b l2 1 l1 1') * repeat( 128 | torch.eye(batch, dtype=text_x.dtype, device=image_x.device), 'b2 b1 -> b2 1 b1 1 1') 129 | # [BxL2, WxBxL1] 130 | labels_per_text = rearrange(labels_per_text, 'b2 l2 b1 l1 w -> (b2 l2) (w b1 l1)') 131 | 132 | loss_img = self.soft_cross_entropy(logits_per_img * logit_scale, labels_per_img) 133 | loss_text = self.soft_cross_entropy(logits_per_text * logit_scale, labels_per_text) 134 | 135 | loss = 0.5 * (loss_img + loss_text) 136 | 137 | return loss 138 | 139 | 140 | def forward(self, image_x, text_x): 141 | losses = self.loss(image_x, text_x) 142 | text_outs = self.text_encoder(text_x) 143 | losses_dict = dict(loss=losses.detach().item()) 144 | if self.with_multi_label: 145 | image_multi_label_x = image_x.unsqueeze(1) 146 | text_multi_label_x = text_outs.unsqueeze(1) 147 | multi_label_loss = self.multi_label_loss(image_multi_label_x, 148 | text_multi_label_x) * self.multi_label_loss_weight 149 | losses += multi_label_loss 150 | losses_dict.update({ 151 | "multi_label_loss": multi_label_loss.detach().item() 152 | }) 153 | return losses, losses_dict 154 | 155 | return losses, losses_dict -------------------------------------------------------------------------------- /src/model/loss/perceptual.py: -------------------------------------------------------------------------------- 1 | """Stripped version of https://github.com/richzhang/PerceptualSimilarity/tree/master/models""" 2 | 3 | import torch 4 | import torch.nn as nn 5 | from torchvision import models 6 | from collections import namedtuple 7 | import os, hashlib 8 | import requests 9 | from tqdm import tqdm 10 | # from taming.util import get_ckpt_path 11 | 12 | URL_MAP = { 13 | "vgg_lpips": "https://heibox.uni-heidelberg.de/f/607503859c864bc1b30b/?dl=1" 14 | } 15 | 16 | CKPT_MAP = { 17 | "vgg_lpips": "vgg.pth" 18 | } 19 | 20 | MD5_MAP = { 21 | "vgg_lpips": "d507d7349b931f0638a25a48a722f98a" 22 | } 23 | 24 | 25 | def download(url, local_path, chunk_size=1024): 26 | os.makedirs(os.path.split(local_path)[0], exist_ok=True) 27 | with requests.get(url, stream=True) as r: 28 | total_size = int(r.headers.get("content-length", 0)) 29 | with tqdm(total=total_size, unit="B", unit_scale=True) as pbar: 30 | with open(local_path, "wb") as f: 31 | for data in r.iter_content(chunk_size=chunk_size): 32 | if data: 33 | f.write(data) 34 | pbar.update(chunk_size) 35 | 36 | 37 | def md5_hash(path): 38 | with open(path, "rb") as f: 39 | content = f.read() 40 | return hashlib.md5(content).hexdigest() 41 | 42 | 43 | 44 | def get_ckpt_path(name, root, check=False): 45 | assert name in URL_MAP 46 | path = os.path.join(root, CKPT_MAP[name]) 47 | if not os.path.exists(path) or (check and not md5_hash(path) == MD5_MAP[name]): 48 | print("Downloading {} model from {} to {}".format(name, URL_MAP[name], path)) 49 | download(URL_MAP[name], 
path) 50 | md5 = md5_hash(path) 51 | assert md5 == MD5_MAP[name], md5 52 | return path 53 | 54 | 55 | class LPIPS(nn.Module): 56 | # Learned perceptual metric 57 | def __init__(self, use_dropout=True): 58 | super().__init__() 59 | self.scaling_layer = ScalingLayer() 60 | self.chns = [64, 128, 256, 512, 512] # vg16 features 61 | self.net = vgg16(pretrained=True, requires_grad=False) 62 | self.lin0 = NetLinLayer(self.chns[0], use_dropout=use_dropout) 63 | self.lin1 = NetLinLayer(self.chns[1], use_dropout=use_dropout) 64 | self.lin2 = NetLinLayer(self.chns[2], use_dropout=use_dropout) 65 | self.lin3 = NetLinLayer(self.chns[3], use_dropout=use_dropout) 66 | self.lin4 = NetLinLayer(self.chns[4], use_dropout=use_dropout) 67 | self.load_from_pretrained() 68 | for param in self.parameters(): 69 | param.requires_grad = False 70 | 71 | def load_from_pretrained(self, name="vgg_lpips"): 72 | ckpt = get_ckpt_path(name, "taming/modules/autoencoder/lpips") 73 | self.load_state_dict(torch.load(ckpt, map_location=torch.device("cpu")), strict=False) 74 | print("loaded pretrained LPIPS loss from {}".format(ckpt)) 75 | 76 | @classmethod 77 | def from_pretrained(cls, name="vgg_lpips"): 78 | if name != "vgg_lpips": 79 | raise NotImplementedError 80 | model = cls() 81 | ckpt = get_ckpt_path(name) 82 | model.load_state_dict(torch.load(ckpt, map_location=torch.device("cpu")), strict=False) 83 | return model 84 | 85 | def forward(self, input, target): 86 | in0_input, in1_input = (self.scaling_layer(input), self.scaling_layer(target)) 87 | outs0, outs1 = self.net(in0_input), self.net(in1_input) 88 | feats0, feats1, diffs = {}, {}, {} 89 | lins = [self.lin0, self.lin1, self.lin2, self.lin3, self.lin4] 90 | for kk in range(len(self.chns)): 91 | feats0[kk], feats1[kk] = normalize_tensor(outs0[kk]), normalize_tensor(outs1[kk]) 92 | diffs[kk] = (feats0[kk] - feats1[kk]) ** 2 93 | 94 | res = [spatial_average(lins[kk].model(diffs[kk]), keepdim=True) for kk in range(len(self.chns))] 95 | val = res[0] 96 | for l in range(1, len(self.chns)): 97 | val += res[l] 98 | return val 99 | 100 | 101 | class ScalingLayer(nn.Module): 102 | def __init__(self): 103 | super(ScalingLayer, self).__init__() 104 | self.register_buffer('shift', torch.Tensor([-.030, -.088, -.188])[None, :, None, None]) 105 | self.register_buffer('scale', torch.Tensor([.458, .448, .450])[None, :, None, None]) 106 | 107 | def forward(self, inp): 108 | return (inp - self.shift) / self.scale 109 | 110 | 111 | class NetLinLayer(nn.Module): 112 | """ A single linear layer which does a 1x1 conv """ 113 | def __init__(self, chn_in, chn_out=1, use_dropout=False): 114 | super(NetLinLayer, self).__init__() 115 | layers = [nn.Dropout(), ] if (use_dropout) else [] 116 | layers += [nn.Conv2d(chn_in, chn_out, 1, stride=1, padding=0, bias=False), ] 117 | self.model = nn.Sequential(*layers) 118 | 119 | 120 | class vgg16(torch.nn.Module): 121 | def __init__(self, requires_grad=False, pretrained=True): 122 | super(vgg16, self).__init__() 123 | vgg_pretrained_features = models.vgg16(pretrained=pretrained).features 124 | self.slice1 = torch.nn.Sequential() 125 | self.slice2 = torch.nn.Sequential() 126 | self.slice3 = torch.nn.Sequential() 127 | self.slice4 = torch.nn.Sequential() 128 | self.slice5 = torch.nn.Sequential() 129 | self.N_slices = 5 130 | for x in range(4): 131 | self.slice1.add_module(str(x), vgg_pretrained_features[x]) 132 | for x in range(4, 9): 133 | self.slice2.add_module(str(x), vgg_pretrained_features[x]) 134 | for x in range(9, 16): 135 | 
self.slice3.add_module(str(x), vgg_pretrained_features[x]) 136 | for x in range(16, 23): 137 | self.slice4.add_module(str(x), vgg_pretrained_features[x]) 138 | for x in range(23, 30): 139 | self.slice5.add_module(str(x), vgg_pretrained_features[x]) 140 | if not requires_grad: 141 | for param in self.parameters(): 142 | param.requires_grad = False 143 | 144 | def forward(self, X): 145 | h = self.slice1(X) 146 | h_relu1_2 = h 147 | h = self.slice2(h) 148 | h_relu2_2 = h 149 | h = self.slice3(h) 150 | h_relu3_3 = h 151 | h = self.slice4(h) 152 | h_relu4_3 = h 153 | h = self.slice5(h) 154 | h_relu5_3 = h 155 | vgg_outputs = namedtuple("VggOutputs", ['relu1_2', 'relu2_2', 'relu3_3', 'relu4_3', 'relu5_3']) 156 | out = vgg_outputs(h_relu1_2, h_relu2_2, h_relu3_3, h_relu4_3, h_relu5_3) 157 | return out 158 | 159 | 160 | def normalize_tensor(x,eps=1e-10): 161 | norm_factor = torch.sqrt(torch.sum(x**2,dim=1,keepdim=True)) 162 | return x/(norm_factor+eps) 163 | 164 | 165 | def spatial_average(x, keepdim=True): 166 | return x.mean([2,3],keepdim=keepdim) -------------------------------------------------------------------------------- /src/model/loss/segmentation.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | ALPHA = 0.8 6 | GAMMA = 2 7 | 8 | 9 | class BCELoss(nn.Module): 10 | def forward(self, prediction, target): 11 | loss = F.binary_cross_entropy_with_logits(prediction,target) 12 | return loss, {} 13 | 14 | 15 | class BCELossWithQuant(nn.Module): 16 | def __init__(self, codebook_weight=1.): 17 | super().__init__() 18 | self.codebook_weight = codebook_weight 19 | 20 | def forward(self, qloss, target, prediction, split): 21 | bce_loss = F.binary_cross_entropy_with_logits(prediction,target) 22 | loss = bce_loss + self.codebook_weight*qloss 23 | return loss, {"{}/total_loss".format(split): loss.clone().detach().mean(), 24 | "{}/bce_loss".format(split): bce_loss.detach().mean(), 25 | "{}/quant_loss".format(split): qloss.detach().mean() 26 | } 27 | 28 | """Stripped version of https://github.com/NExT-ChatV/NExT-Chat/blob/main/mllm/models/sam/sam_loss.py""" 29 | 30 | class FocalLoss(nn.Module): 31 | 32 | def __init__(self, weight=None, size_average=True): 33 | super().__init__() 34 | 35 | def forward(self, inputs, targets, alpha=ALPHA, gamma=GAMMA, smooth=1): 36 | # inputs = F.sigmoid(inputs) 37 | # inputs = torch.clamp(inputs, min=0, max=1) 38 | #flatten label and prediction tensors 39 | inputs = inputs.view(-1) 40 | targets = targets.view(-1) 41 | BCE = F.binary_cross_entropy_with_logits(inputs, targets, reduction='mean') 42 | BCE_EXP = torch.exp(-BCE) 43 | focal_loss = alpha * (1 - BCE_EXP)**gamma * BCE 44 | 45 | return focal_loss 46 | 47 | 48 | class DiceLoss(nn.Module): 49 | 50 | def __init__(self, weight=None, size_average=True): 51 | super().__init__() 52 | 53 | def forward(self, inputs, targets, smooth=1): 54 | inputs = F.sigmoid(inputs) 55 | inputs = torch.clamp(inputs, min=0, max=1) 56 | #flatten label and prediction tensors 57 | inputs = inputs.view(-1) 58 | targets = targets.view(-1) 59 | 60 | intersection = (inputs * targets).sum() 61 | dice = (2. 
* intersection + smooth) / (inputs.sum() + targets.sum() + smooth) 62 | 63 | return 1 - dice 64 | 65 | 66 | def calc_iou(pred_mask: torch.Tensor, gt_mask: torch.Tensor): 67 | pred_mask = (pred_mask >= 0.5) 68 | intersection = torch.sum(torch.mul(pred_mask, gt_mask), dim=(1, 2)) 69 | union = torch.sum(pred_mask, dim=(1, 2)) + torch.sum(gt_mask, dim=(1, 2)) - intersection 70 | epsilon = 1e-7 71 | batch_iou = intersection / (union + epsilon) 72 | 73 | # batch_iou = batch_iou.unsqueeze(1) 74 | return batch_iou 75 | 76 | 77 | class SamLoss(nn.Module): 78 | def __init__(self): 79 | super().__init__() 80 | self.focal_loss = FocalLoss() 81 | self.dice_loss = DiceLoss() 82 | 83 | def forward(self, pred_masks, gt_masks, iou_predictions, device): 84 | loss_focal = 0. 85 | loss_dice = 0. 86 | loss_iou = 0. 87 | num_masks = sum(len(pred_mask) for pred_mask in pred_masks) 88 | for pred_mask, gt_mask, iou_prediction in zip(pred_masks, gt_masks, iou_predictions): 89 | gt_mask = gt_mask.to(device) 90 | batch_iou = calc_iou(pred_mask, gt_mask) 91 | loss_focal += self.focal_loss(pred_mask, gt_mask, num_masks) 92 | loss_dice += self.dice_loss(pred_mask, gt_mask, num_masks) 93 | loss_iou += F.mse_loss(iou_prediction, batch_iou, reduction='sum') / num_masks 94 | 95 | loss_total = 20. * loss_focal + loss_dice + loss_iou 96 | return loss_total -------------------------------------------------------------------------------- /src/model/make_delta.py: -------------------------------------------------------------------------------- 1 | """ 2 | Usage: 3 | python3 -m llava.model.make_delta --base ~/model_weights/llama-7b --target ~/model_weights/llava-7b --delta ~/model_weights/llava-7b-delta --hub-repo-id liuhaotian/llava-7b-delta 4 | """ 5 | import argparse 6 | 7 | import torch 8 | from tqdm import tqdm 9 | from transformers import AutoTokenizer, AutoModelForCausalLM 10 | from src.model.utils import auto_upgrade 11 | 12 | 13 | def make_delta(base_model_path, target_model_path, delta_path, hub_repo_id): 14 | print("Loading base model") 15 | base = AutoModelForCausalLM.from_pretrained( 16 | base_model_path, torch_dtype=torch.float16, low_cpu_mem_usage=True) 17 | 18 | print("Loading target model") 19 | auto_upgrade(target_model_path) 20 | target = AutoModelForCausalLM.from_pretrained(target_model_path, torch_dtype=torch.float16, low_cpu_mem_usage=True) 21 | 22 | print("Calculating delta") 23 | for name, param in tqdm(target.state_dict().items(), desc="Calculating delta"): 24 | if name not in base.state_dict(): 25 | assert name in ['model.mm_projector.weight', 'model.mm_projector.bias'], f'{name} not in base model' 26 | continue 27 | if param.data.shape == base.state_dict()[name].shape: 28 | param.data -= base.state_dict()[name] 29 | else: 30 | assert name in ['model.embed_tokens.weight', 'lm_head.weight'], f'{name} dimension mismatch: {param.data.shape} vs {base.state_dict()[name].shape}' 31 | bparam = base.state_dict()[name] 32 | param.data[:bparam.shape[0], :bparam.shape[1]] -= bparam 33 | 34 | print("Saving delta") 35 | if hub_repo_id: 36 | kwargs = {"push_to_hub": True, "repo_id": hub_repo_id} 37 | else: 38 | kwargs = {} 39 | target.save_pretrained(delta_path, **kwargs) 40 | target_tokenizer = AutoTokenizer.from_pretrained(target_model_path) 41 | target_tokenizer.save_pretrained(delta_path, **kwargs) 42 | 43 | 44 | if __name__ == "__main__": 45 | parser = argparse.ArgumentParser() 46 | parser.add_argument("--base-model-path", type=str, required=True) 47 | parser.add_argument("--target-model-path", type=str, 
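# --- Editor's illustration (not part of segmentation.py): how the mask losses above behave
# on a toy batch. Shapes follow the code: logits and ground-truth masks are (B, H, W);
# calc_iou thresholds probabilities at 0.5 and returns one IoU per sample.
import torch

logits = torch.randn(2, 64, 64)                    # raw mask logits
gt = (torch.rand(2, 64, 64) > 0.5).float()         # binary ground truth

focal = FocalLoss()(logits, gt)                    # scalar focal loss
dice = DiceLoss()(logits, gt)                      # scalar, 1 - Dice coefficient
iou = calc_iou(torch.sigmoid(logits), gt)          # per-sample IoU, shape (2,)
# --- end of editor's illustration ---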
required=True) 48 | parser.add_argument("--delta-path", type=str, required=True) 49 | parser.add_argument("--hub-repo-id", type=str, default=None) 50 | args = parser.parse_args() 51 | 52 | make_delta(args.base_model_path, args.target_model_path, args.delta_path, args.hub_repo_id) 53 | -------------------------------------------------------------------------------- /src/model/multimodal_encoder/builder.py: -------------------------------------------------------------------------------- 1 | import os 2 | from dataclasses import asdict, is_dataclass 3 | from .clip_encoder import CLIPVisionTower 4 | from ..setok.tokenizer import SetokTokenizer 5 | 6 | def build_vision_tower(vision_tower_cfg, **kwargs): 7 | vision_tower = getattr(vision_tower_cfg, 'vision_tokenizer', getattr(vision_tower_cfg, 'vision_tower', None)) 8 | # is_absolute_path_exists = os.path.exists(vision_tower) 9 | # if is_absolute_path_exists and (vision_tower.startswith("openai") or vision_tower.startswith("laion") or "ShareGPT4V" in vision_tower or "vit" in vision_tower): 10 | # return CLIPVisionTower(vision_tower, args=vision_tower_cfg, **kwargs) 11 | if is_dataclass(vision_tower_cfg): 12 | vision_tower_cfg = asdict(vision_tower_cfg) 13 | elif isinstance(vision_tower_cfg, dict): 14 | vision_tower_cfg = vision_tower_cfg 15 | else: 16 | vision_tower_cfg = vars(vision_tower_cfg) 17 | 18 | 19 | if 'siglip' in vision_tower: 20 | return SetokTokenizer(**vision_tower_cfg, **kwargs) 21 | 22 | raise ValueError(f'Unknown vision tower: {vision_tower}') 23 | -------------------------------------------------------------------------------- /src/model/multimodal_encoder/clip_encoder.py: -------------------------------------------------------------------------------- 1 | from typing import Optional 2 | import torch 3 | import torch.nn as nn 4 | from transformers import CLIPVisionModel, CLIPImageProcessor, CLIPVisionConfig 5 | 6 | 7 | class CLIPVisionTower(nn.Module): 8 | def __init__(self, vision_tower: Optional[str], 9 | unfreeze_mm_vision_tower: Optional[str] = False, 10 | mm_vision_select_feature: Optional[str] = 'patch', 11 | mm_vision_select_layer: Optional[int] = -2, 12 | delay_load=False): 13 | super().__init__() 14 | 15 | self.is_loaded = False 16 | 17 | self.vision_tower_name = vision_tower 18 | self.select_layer = mm_vision_select_layer 19 | self.select_feature = mm_vision_select_feature 20 | 21 | if not delay_load: 22 | self.load_model() 23 | elif unfreeze_mm_vision_tower: 24 | self.load_model() 25 | else: 26 | self.cfg_only = CLIPVisionConfig.from_pretrained(self.vision_tower_name) 27 | 28 | def load_model(self, device_map=None): 29 | if self.is_loaded: 30 | print('{} is already loaded, `load_model` called again, skipping.'.format(self.vision_tower_name)) 31 | return 32 | 33 | self.image_processor = CLIPImageProcessor.from_pretrained(self.vision_tower_name) 34 | self.vision_tower = CLIPVisionModel.from_pretrained(self.vision_tower_name, device_map=device_map) 35 | self.vision_tower.requires_grad_(False) 36 | 37 | self.is_loaded = True 38 | 39 | def feature_select(self, image_forward_outs): 40 | image_features = image_forward_outs.hidden_states[self.select_layer] 41 | if self.select_feature == 'patch': 42 | image_features = image_features[:, 1:] 43 | elif self.select_feature == 'cls_patch': 44 | image_features = image_features 45 | else: 46 | raise ValueError(f'Unexpected select feature: {self.select_feature}') 47 | return image_features 48 | 49 | @torch.no_grad() 50 | def forward(self, images): 51 | if type(images) is list: 52 | 
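# --- Editor's illustration (not part of builder.py): build_vision_tower accepts a dataclass,
# dict, or argparse-style namespace; any config whose tower name contains 'siglip' is routed
# to SetokTokenizer. Field values below are assumptions for the sketch only, and constructing
# the tokenizer downloads the SigLIP backbone.
from types import SimpleNamespace

cfg = SimpleNamespace(
    vision_tower='google/siglip-so400m-patch14-384',   # 'siglip' in the name -> SetokTokenizer
    hidden_dim=1152,                                   # assumed to match the SigLIP feature width
    token_feat_dim=4096,
)
tokenizer = build_vision_tower(cfg)
# --- end of editor's illustration ---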
image_features = [] 53 | for image in images: 54 | image_forward_out = self.vision_tower(image.to(device=self.device, dtype=self.dtype).unsqueeze(0), output_hidden_states=True) 55 | image_feature = self.feature_select(image_forward_out).to(image.dtype) 56 | image_features.append(image_feature) 57 | else: 58 | image_forward_outs = self.vision_tower(images.to(device=self.device, dtype=self.dtype), output_hidden_states=True) 59 | image_features = self.feature_select(image_forward_outs).to(images.dtype) 60 | 61 | return image_features 62 | 63 | @property 64 | def dummy_feature(self): 65 | return torch.zeros(1, self.hidden_size, device=self.device, dtype=self.dtype) 66 | 67 | @property 68 | def dtype(self): 69 | return self.vision_tower.dtype 70 | 71 | @property 72 | def device(self): 73 | return self.vision_tower.device 74 | 75 | @property 76 | def config(self): 77 | if self.is_loaded: 78 | return self.vision_tower.config 79 | else: 80 | return self.cfg_only 81 | 82 | @property 83 | def hidden_size(self): 84 | return self.config.hidden_size 85 | 86 | @property 87 | def num_patches_per_side(self): 88 | return self.config.image_size // self.config.patch_size 89 | 90 | @property 91 | def num_patches(self): 92 | return (self.config.image_size // self.config.patch_size) ** 2 93 | -------------------------------------------------------------------------------- /src/model/multimodal_encoder/openclip_encoder.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | import os 5 | import json 6 | import logging 7 | import deepspeed 8 | from pathlib import Path 9 | from open_clip.factory import load_state_dict, get_model_config 10 | from open_clip.model import CLIPVisionCfg, CLIPTextCfg, _build_vision_tower, convert_to_custom_text_state_dict, resize_pos_embed 11 | from typing import Dict, Optional 12 | from transformers.deepspeed import deepspeed_config, is_deepspeed_zero3_enabled 13 | # from transformers import CLIPImageProcessor 14 | from .openclip_processor import OpenCLIPImageProcessor 15 | 16 | 17 | class OpenCLIPVisionTower(nn.Module): 18 | def __init__(self, vision_tower, args, delay_load=False): 19 | super().__init__() 20 | 21 | self.is_loaded = False 22 | self.vision_tower_name = vision_tower 23 | self.select_stage = args.mm_vision_select_layer if args.mm_vision_select_layer > -5 else -2 # the output stage to select for extracted features 24 | self.vision_config = json.load(open(os.path.join(vision_tower,'open_clip_config.json'), 'r')) 25 | self.is_optimize = getattr(args, 'optimize_vision_tower_aux', False) 26 | 27 | if not delay_load: 28 | self.load_model() 29 | 30 | def load_model(self): 31 | self.image_processor = OpenCLIPImageProcessor.from_pretrained('/public/models/clip/clip-vit-large-patch14') 32 | ckpt_path = os.path.join(self.vision_tower_name, 'open_clip_pytorch_model.bin') 33 | if 'convnext' in self.vision_tower_name: 34 | if 'large' in self.vision_tower_name and 'd_320' in self.vision_tower_name: 35 | self.model_type = 'convnext_large_d_320' 36 | self.model_channel = [192, 384, 768, 1536] # stage 0-3 37 | elif 'base' in self.vision_tower_name and 'w_320' in self.vision_tower_name: 38 | self.model_type = 'convnext_base_w_320' 39 | self.model_channel = [128, 256, 512, 1024] 40 | elif 'xxlarge' in self.vision_tower_name: 41 | self.model_type = 'convnext_xxlarge' 42 | self.model_channel = [384, 768, 1536, 3072] 43 | 44 | clip_model = CLIP(**get_model_config(self.model_type)) 45 
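# --- Editor's illustration (not part of clip_encoder.py): with the defaults above, features
# come from the penultimate hidden layer and 'patch' selection drops the CLS token. The
# checkpoint name is an assumed example; it is downloaded on construction.
import torch

tower = CLIPVisionTower('openai/clip-vit-large-patch14-336')
images = torch.randn(2, 3, 336, 336)
feats = tower(images)                              # (2, 576, 1024) for this tower
assert tower.num_patches == (336 // 14) ** 2 == 576
# --- end of editor's illustration ---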
| clip_model.visual.trunk.norm_pre = None 46 | clip_model.visual.trunk.head = None 47 | clip_model.visual.head = None 48 | print(f'Loading pretrained weights ({self.model_type}).') 49 | load_checkpoint(clip_model, ckpt_path, strict=False) 50 | 51 | self.is_loaded = True 52 | # decompose stem and stages blocks in vision tower 53 | self.vision_stem = clip_model.visual.trunk.stem 54 | self.vision_stages = clip_model.visual.trunk.stages 55 | self.vision_stem.requires_grad_(False) 56 | self.vision_stages.requires_grad_(False) 57 | 58 | def forward(self, images): 59 | if type(images) is list: 60 | image_features = [] 61 | for image in images: 62 | image_feature = self.backbone(image.to(device=self.device, dtype=self.dtype).unsqueeze(0)) 63 | image_features.append(image_feature) 64 | else: 65 | image_features = self.backbone(images.to(device=self.device, dtype=self.dtype)) 66 | 67 | return image_features 68 | 69 | def backbone(self, images): 70 | # print('images: ', images.shape) 71 | if not self.is_optimize: 72 | with torch.no_grad(): 73 | results = self.basic_forward(images) 74 | else: 75 | results = self.basic_forward(images) 76 | # 448- torch.Size([1, 384, 56, 56]), torch.Size([1, 768, 28, 28]), torch.Size([1, 1536, 14, 14]), torch.Size([1, 3072, 7, 7])] 77 | # 672- torch.Size([16, 384, 168, 168]), torch.Size([16, 768, 84, 84]), torch.Size([16, 1536, 42, 42]), torch.Size([16, 3072, 21, 21]) 78 | # where hidden_size = sum(model_channel) 79 | # print('results: ', [results[_stage].shape for _stage in results]) 80 | 81 | # target_size = (results['stage_0'].shape[-2], results['stage_0'].shape[-1]) 82 | # result_cat = [] 83 | # for _stage in results: 84 | # if _stage == 'stage_0': 85 | # result_cat.append(results[_stage].contiguous()) 86 | # else: 87 | # result_cat.append(F.interpolate(results[_stage].float().contiguous() , 88 | # size=target_size, 89 | # mode='bilinear', 90 | # align_corners=False).to(dtype=results[_stage].dtype)) 91 | # result_cat = torch.cat(result_cat, dim=1) 92 | select_stage = f'stage_{4+self.select_stage}' 93 | result_cat = results[select_stage] 94 | # print("result_cat: ", result_cat.shape) 95 | 96 | return result_cat.contiguous() 97 | 98 | def basic_forward(self, images): 99 | results = {} 100 | x = self.vision_stem(images) 101 | for _idx in range(len(self.vision_stages)): 102 | x = self.vision_stages[_idx](x) 103 | results[f'stage_{_idx}'] = x 104 | return results 105 | 106 | @property 107 | def dummy_feature(self): 108 | return torch.zeros(1, self.hidden_size, device=self.device, dtype=self.dtype) 109 | 110 | @property 111 | def dtype(self): 112 | return self.vision_stem[0].weight.dtype 113 | 114 | @property 115 | def device(self): 116 | return self.vision_stem[0].weight.device 117 | 118 | @property 119 | def config(self): 120 | return self.vision_config 121 | 122 | @property 123 | def hidden_size(self): 124 | return self.model_channel[self.select_stage] 125 | # return sum(self.model_channel) 126 | 127 | # modified function from open_clip to support zero3 stage 128 | def load_checkpoint(model, checkpoint_path, strict=True): 129 | if Path(checkpoint_path).suffix in ('.npz', '.npy'): 130 | from open_clip.big_vision import load_big_vision_weights 131 | load_big_vision_weights(model, checkpoint_path) 132 | return {} 133 | 134 | state_dict = load_state_dict(checkpoint_path) 135 | # detect old format and make compatible with new format 136 | if 'positional_embedding' in state_dict and not hasattr(model, 'positional_embedding'): 137 | state_dict = 
convert_to_custom_text_state_dict(state_dict) 138 | # If loading a non-SigLIP model for SigLIP training. See https://github.com/mlfoundations/open_clip/issues/712 139 | # if 'logit_bias' not in state_dict and model.logit_bias is not None: 140 | # state_dict["logit_bias"] = torch.zeros_like(state_dict["logit_scale"]) 141 | # Certain text transformers no longer expect position_ids after transformers==4.31 142 | position_id_key = 'text.transformer.embeddings.position_ids' 143 | if position_id_key in state_dict and not hasattr(model, position_id_key): 144 | del state_dict[position_id_key] 145 | resize_pos_embed(state_dict, model) 146 | # resize_text_pos_embed(state_dict, model) 147 | #incompatible_keys = model.load_state_dict(state_dict, strict=strict) 148 | if is_deepspeed_zero3_enabled(): 149 | 150 | error_msgs = [] 151 | 152 | def load(module: nn.Module, state_dict, prefix=""): 153 | metadata = None 154 | 155 | local_metadata = {} if metadata is None else metadata.get(prefix[:-1], {}) 156 | args = (state_dict, prefix, local_metadata, True, [], [], error_msgs) 157 | # Parameters of module and children will start with prefix. We can exit early if there are none in this 158 | # state_dict 159 | if len([key for key in state_dict if key.startswith(prefix)]) > 0: 160 | if is_deepspeed_zero3_enabled(): 161 | # In sharded models, each shard has only part of the full state_dict, so only gather 162 | # parameters that are in the current state_dict. 163 | named_parameters = dict(module.named_parameters(prefix=prefix[:-1], recurse=False)) 164 | params_to_gather = [named_parameters[k] for k in state_dict.keys() if k in named_parameters] 165 | if len(params_to_gather) > 0: 166 | # because zero3 puts placeholders in model params, this context 167 | # manager gathers (unpartitions) the params of the current layer, then loads from 168 | # the state dict and then re-partitions them again 169 | with deepspeed.zero.GatheredParameters(params_to_gather, modifier_rank=0): 170 | if torch.distributed.get_rank() == 0: 171 | module._load_from_state_dict(*args) 172 | else: 173 | module._load_from_state_dict(*args) 174 | 175 | for name, child in module._modules.items(): 176 | if child is not None: 177 | load(child, state_dict, prefix + name + ".") 178 | 179 | load(model, state_dict) 180 | incompatible_keys = [] 181 | else: 182 | incompatible_keys = model.load_state_dict(state_dict, strict=strict) 183 | logging.info(f"incompatible_keys.missing_keys: {incompatible_keys.missing_keys}") 184 | return incompatible_keys 185 | 186 | class CLIP(nn.Module): 187 | output_dict: torch.jit.Final[bool] 188 | 189 | def __init__( 190 | self, 191 | embed_dim: int, 192 | vision_cfg: CLIPVisionCfg, 193 | text_cfg: CLIPTextCfg, 194 | quick_gelu: bool = False, 195 | cast_dtype: Optional[torch.dtype] = None, 196 | output_dict: bool = False, 197 | ): 198 | super().__init__() 199 | self.output_dict = output_dict 200 | 201 | self.visual = _build_vision_tower(embed_dim, vision_cfg, quick_gelu, cast_dtype) -------------------------------------------------------------------------------- /src/model/multimodal_encoder/openclip_processor.py: -------------------------------------------------------------------------------- 1 | from transformers import CLIPImageProcessor 2 | from transformers.image_processing_utils import BatchFeature, get_size_dict 3 | from transformers.image_transforms import get_resize_output_image_size 4 | 5 | import torch 6 | import torch.nn.functional as F 7 | 8 | import numpy as np 9 | 10 | 11 | class 
OpenCLIPImageProcessor(CLIPImageProcessor): 12 | 13 | def __init__(self, **kwargs): 14 | super().__init__(**kwargs) 15 | 16 | def preprocess(self, images, **kwargs): 17 | if not isinstance(images, np.ndarray): 18 | return super().preprocess(images=images, **kwargs) 19 | 20 | do_resize = kwargs.get('do_resize', self.do_resize) 21 | size = kwargs.get('size', self.size) 22 | size = get_size_dict(size, param_name="size", default_to_square=False) 23 | do_center_crop = kwargs.get('do_center_crop', self.do_center_crop) 24 | crop_size = kwargs.get('crop_size', self.crop_size) 25 | crop_size = get_size_dict(crop_size, param_name="crop_size", default_to_square=True) 26 | do_rescale = kwargs.get('do_rescale', self.do_rescale) 27 | rescale_factor = kwargs.get('rescale_factor', self.rescale_factor) 28 | do_normalize = kwargs.get('do_normalize', self.do_normalize) 29 | image_mean = kwargs.get('image_mean', self.image_mean) 30 | image_std = kwargs.get('image_std', self.image_std) 31 | return_tensors = kwargs.get('return_tensors', None) 32 | 33 | def resize(images, output_size): 34 | images = images.permute((0, 3, 1, 2)) 35 | images = F.interpolate(images, size=output_size, mode='bicubic') 36 | images = images.permute((0, 2, 3, 1)) 37 | return images 38 | 39 | def center_crop(images, crop_size): 40 | crop_width, crop_height = crop_size["width"], crop_size["height"] 41 | img_width, img_height = images.shape[1:3] 42 | x = (img_width - crop_width) // 2 43 | y = (img_height - crop_height) // 2 44 | images = images[:, x:x+crop_width, y:y+crop_height] 45 | return images 46 | 47 | def rescale(images, rescale_factor): 48 | images = images * rescale_factor 49 | return images 50 | 51 | def normalize(images, mean, std): 52 | mean = torch.tensor(mean) 53 | std = torch.tensor(std) 54 | images = (images - mean) / std 55 | return images 56 | 57 | images = torch.from_numpy(images).float() 58 | 59 | if do_resize: 60 | output_size = get_resize_output_image_size(images[0], size=size["shortest_edge"], default_to_square=False) 61 | images = resize(images, output_size) 62 | 63 | if do_center_crop: 64 | images = center_crop(images, crop_size) 65 | 66 | if do_rescale: 67 | images = rescale(images, rescale_factor) 68 | 69 | if do_normalize: 70 | images = normalize(images, image_mean, image_std) 71 | 72 | images = images.permute((0, 3, 1, 2)) 73 | data = {"pixel_values": images} 74 | return BatchFeature(data=data, tensor_type=return_tensors) -------------------------------------------------------------------------------- /src/model/multimodal_generator/builder.py: -------------------------------------------------------------------------------- 1 | from dataclasses import asdict, is_dataclass 2 | from ..setok import SetokDeTokenizer 3 | 4 | def build_vision_generator(image_generator_cfg, **kwargs): 5 | if is_dataclass(image_generator_cfg): 6 | image_generator_cfg = asdict(image_generator_cfg) 7 | elif isinstance(image_generator_cfg, dict): 8 | image_generator_cfg = image_generator_cfg 9 | else: 10 | image_generator_cfg = vars(image_generator_cfg) 11 | 12 | return SetokDeTokenizer(**image_generator_cfg, **kwargs) -------------------------------------------------------------------------------- /src/model/multimodal_projector/builder.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import re 4 | 5 | 6 | class IdentityMap(nn.Module): 7 | def __init__(self): 8 | super().__init__() 9 | 10 | def forward(self, x, *args, **kwargs): 11 | return x 12 | 13 | 
@property 14 | def config(self): 15 | return {"mm_projector_type": 'identity'} 16 | 17 | 18 | class SimpleResBlock(nn.Module): 19 | def __init__(self, channels): 20 | super().__init__() 21 | self.pre_norm = nn.LayerNorm(channels) 22 | 23 | self.proj = nn.Sequential( 24 | nn.Linear(channels, channels), 25 | nn.GELU(), 26 | nn.Linear(channels, channels) 27 | ) 28 | def forward(self, x): 29 | x = self.pre_norm(x) 30 | return x + self.proj(x) 31 | 32 | 33 | def build_vision_projector(projector_type='linear', mm_hidden_size=4096, hidden_size=3078, delay_load=False, **kwargs): 34 | # projector_type = getattr(config, 'mm_projector_type', 'linear') 35 | 36 | if projector_type == 'linear': 37 | return nn.Linear(mm_hidden_size, hidden_size) 38 | 39 | use_norm = False 40 | if "_Norm" in projector_type: 41 | use_norm = True 42 | projector_type = projector_type.replace("_Norm", "") 43 | 44 | 45 | mlp_gelu_match = re.match(r'^mlp(\d+)x_gelu$', projector_type) 46 | if mlp_gelu_match: 47 | mlp_depth = int(mlp_gelu_match.group(1)) 48 | if use_norm: 49 | modules = [ 50 | nn.Linear(mm_hidden_size, hidden_size), 51 | nn.LayerNorm(hidden_size), 52 | ] 53 | else: 54 | modules = [nn.Linear(mm_hidden_size, hidden_size)] 55 | # modules = [nn.Linear(config.mm_hidden_size, config.hidden_size)] 56 | for _ in range(1, mlp_depth): 57 | modules.append(nn.GELU()) 58 | modules.append(nn.Linear(hidden_size, hidden_size)) 59 | return nn.Sequential(*modules) 60 | 61 | if projector_type == 'identity': 62 | return IdentityMap() 63 | 64 | raise ValueError(f'Unknown projector type: {projector_type}') 65 | -------------------------------------------------------------------------------- /src/model/setok/__init__.py: -------------------------------------------------------------------------------- 1 | from .tokenizer import SetokTokenizer 2 | from .detokenizer import SetokDeTokenizer 3 | from .model import SeTok -------------------------------------------------------------------------------- /src/model/setok/clip_encoder.py: -------------------------------------------------------------------------------- 1 | from typing import Optional 2 | import torch 3 | import torch.nn as nn 4 | 5 | from transformers import AutoModel, AutoProcessor, AutoConfig 6 | 7 | 8 | class CLIPVisionTower(nn.Module): 9 | def __init__(self, vision_tower: Optional[str], 10 | unfreeze_mm_vision_tower: Optional[str] = False, 11 | mm_vision_select_feature: Optional[str] = 'patch', 12 | mm_vision_select_layer: Optional[int] = -2, 13 | delay_load=False): 14 | super().__init__() 15 | 16 | self.is_loaded = False 17 | 18 | self.vision_tower_name = vision_tower 19 | self.select_layer = mm_vision_select_layer 20 | self.select_feature = mm_vision_select_feature 21 | 22 | if not delay_load: 23 | self.load_model() 24 | elif unfreeze_mm_vision_tower: 25 | self.load_model() 26 | else: 27 | self.cfg_only = AutoConfig.from_pretrained(self.vision_tower_name) 28 | 29 | def load_model(self, device_map=None): 30 | if self.is_loaded: 31 | print('{} is already loaded, `load_model` called again, skipping.'.format(self.vision_tower_name)) 32 | return 33 | 34 | self.image_processor = AutoProcessor.from_pretrained(self.vision_tower_name) 35 | self.vision_tower = AutoModel.from_pretrained(self.vision_tower_name, device_map=device_map) 36 | self.vision_tower.requires_grad_(False) 37 | 38 | self.is_loaded = True 39 | 40 | def feature_select(self, image_forward_outs): 41 | image_features = image_forward_outs.hidden_states[self.select_layer] 42 | if self.select_feature == 'patch': 43 | 
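# --- Editor's illustration (not part of multimodal_projector/builder.py): the projector type
# string is parsed by the regex above, so 'mlp2x_gelu' gives Linear -> GELU -> Linear and a
# '_Norm' suffix adds a LayerNorm after the first projection. Sizes are illustrative.
proj = build_vision_projector('mlp2x_gelu', mm_hidden_size=1152, hidden_size=4096)
# -> Sequential(Linear(1152, 4096), GELU(), Linear(4096, 4096))

proj_norm = build_vision_projector('mlp2x_gelu_Norm', mm_hidden_size=1152, hidden_size=4096)
# -> Sequential(Linear(1152, 4096), LayerNorm(4096), GELU(), Linear(4096, 4096))
# --- end of editor's illustration ---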
image_features = image_features[:, 1:] 44 | elif self.select_feature == 'cls_patch': 45 | image_features = image_features 46 | else: 47 | raise ValueError(f'Unexpected select feature: {self.select_feature}') 48 | return image_features 49 | 50 | @torch.no_grad() 51 | def forward(self, images): 52 | if type(images) is list: 53 | image_features = [] 54 | for image in images: 55 | image_forward_out = self.vision_tower(image.to(device=self.device, dtype=self.dtype).unsqueeze(0), output_hidden_states=True) 56 | image_feature = self.feature_select(image_forward_out).to(image.dtype) 57 | image_features.append(image_feature) 58 | else: 59 | image_forward_outs = self.vision_tower(images.to(device=self.device, dtype=self.dtype), output_hidden_states=True) 60 | image_features = self.feature_select(image_forward_outs).to(images.dtype) 61 | 62 | return image_features 63 | 64 | @property 65 | def dummy_feature(self): 66 | return torch.zeros(1, self.hidden_size, device=self.device, dtype=self.dtype) 67 | 68 | @property 69 | def dtype(self): 70 | return self.vision_tower.dtype 71 | 72 | @property 73 | def device(self): 74 | return self.vision_tower.device 75 | 76 | @property 77 | def config(self): 78 | if self.is_loaded: 79 | return self.vision_tower.config 80 | else: 81 | return self.cfg_only 82 | 83 | @property 84 | def hidden_size(self): 85 | return self.config.hidden_size 86 | 87 | @property 88 | def num_patches_per_side(self): 89 | return self.config.image_size // self.config.patch_size 90 | 91 | @property 92 | def num_patches(self): 93 | return (self.config.image_size // self.config.patch_size) ** 2 94 | -------------------------------------------------------------------------------- /src/model/setok/detokenizer.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from typing import Union, Dict, List, Tuple, Optional 5 | from einops import rearrange, repeat 6 | from timm.models.vision_transformer import Block 7 | from torch.utils.checkpoint import checkpoint 8 | from transformers.models.bert import BertConfig 9 | from .module import BertModel, PositionalEncoding2D 10 | from diffusers.models.autoencoders.vae import Decoder 11 | 12 | 13 | 14 | class SetokDeTokenizer(nn.Module): 15 | def __init__(self, 16 | token_feat_dim: Optional[int] = 4096, 17 | hidden_dim: Optional[int] = 4096, 18 | patch_size: Optional[int]=14, 19 | image_size: Optional[int]=256, 20 | decoder_embed_dim: Optional[int]=4096, 21 | decoder_nheads: Optional[int]=16, 22 | proj_drop: Optional[float]=0.2, 23 | attn_drop: Optional[float]=0.2, 24 | decoder_depth: Optional[int]=16, 25 | norm_layer: nn.Module = nn.LayerNorm, 26 | mlp_ratio: Optional[float]=4.0, 27 | feature_mapper_path_or_name: Optional[str]="bert-base-uncased", 28 | num_hidden_layers: Optional[int]=6, 29 | cross_attention_freq: Optional[int]=2, 30 | initializer_range: Optional[float]=0.02, 31 | **kwargs) -> None: 32 | super().__init__() 33 | self.token_feat_dim = token_feat_dim 34 | 35 | self.patch_size = patch_size 36 | self.height = self.weight = image_size // patch_size 37 | self.num_mask_token = self.height * self.weight 38 | self.hidden_dim = hidden_dim 39 | 40 | query_tokens = nn.Parameter(torch.zeros(1, self.num_mask_token, self.hidden_dim)) 41 | query_tokens.data.normal_(mean=0.0, std=initializer_range) 42 | self.mask_tokens = query_tokens 43 | 44 | self.decoder_embed_dim = decoder_embed_dim 45 | self.mapper_fc_in = nn.Linear(self.token_feat_dim, 
self.hidden_dim) 46 | self.decoder_fc_in = nn.Linear(self.hidden_dim, self.decoder_embed_dim) 47 | 48 | self.decoder_norm = norm_layer(self.decoder_embed_dim) 49 | self.pixel_decoder = nn.ModuleList([ 50 | Block(self.decoder_embed_dim, decoder_nheads, mlp_ratio, qkv_bias=True, norm_layer=norm_layer, proj_drop=proj_drop, attn_drop=attn_drop) for _ in range(decoder_depth) 51 | ]) 52 | self.position_embedding = PositionalEncoding2D(self.hidden_dim) 53 | self.initialize_weights() 54 | self.init_feature_mapper(feature_mapper_path_or_name, self.hidden_dim, self.num_mask_token, num_hidden_layers, cross_attention_freq) 55 | 56 | def initialize_weights(self): 57 | self.apply(self._init_weights) 58 | 59 | def _init_weights(self, m): 60 | if isinstance(m, nn.Linear): 61 | # we use xavier_uniform following official JAX ViT: 62 | torch.nn.init.xavier_uniform_(m.weight) 63 | if isinstance(m, nn.Linear) and m.bias is not None: 64 | nn.init.constant_(m.bias, 0) 65 | elif isinstance(m, nn.LayerNorm): 66 | if m.bias is not None: 67 | nn.init.constant_(m.bias, 0) 68 | if m.weight is not None: 69 | nn.init.constant_(m.weight, 1.0) 70 | 71 | def init_feature_mapper( 72 | self, 73 | feature_mapper_path_or_name: str, 74 | vision_width: int, 75 | num_mask_token: int, 76 | num_hidden_layers: int, 77 | cross_attention_freq: int 78 | ): 79 | print("feature_mapper_path_or_name: ", feature_mapper_path_or_name) 80 | mapper_config = BertConfig.from_pretrained(feature_mapper_path_or_name) 81 | 82 | mapper_config.encoder_width = vision_width 83 | # insert cross-attention layer every other block 84 | mapper_config.add_cross_attention = True 85 | 86 | mapper_config.cross_attention_freq = cross_attention_freq 87 | mapper_config.query_length = num_mask_token 88 | mapper_config.num_hidden_layers = num_hidden_layers 89 | 90 | self.mapper = BertModel.from_pretrained(feature_mapper_path_or_name, config=mapper_config) 91 | self.mapper.cls = None 92 | self.mapper.embeddings.word_embeddings = None 93 | self.mapper.embeddings.position_embeddings = None 94 | for layer in self.mapper.encoder.layer: 95 | layer.output = None 96 | layer.intermediate = None 97 | 98 | def load_model(self): 99 | pass 100 | 101 | def forward(self, x, attention_masks): 102 | 103 | mask_tokens = self.mask_tokens.expand(x.shape[0], -1, -1) 104 | x = self.mapper_fc_in(x) 105 | x = self.mapper( 106 | query_embeds=mask_tokens, 107 | encoder_hidden_states=x, 108 | encoder_attention_mask=attention_masks, 109 | return_dict=True).last_hidden_state 110 | 111 | x = self.decoder_fc_in(x) # b, h*w, c 112 | _x = rearrange(x, 'B (h w) C -> B h w C', h=self.height, w=self.weight) 113 | pos_emb = self.position_embedding(_x) 114 | pos_emb = rearrange(pos_emb, 'B h w C -> B (h w) C') 115 | x = x + pos_emb 116 | 117 | for block in self.pixel_decoder: 118 | x = block(x) 119 | 120 | x = self.decoder_norm(x) 121 | 122 | 123 | 124 | -------------------------------------------------------------------------------- /src/model/setok/model.py: -------------------------------------------------------------------------------- 1 | 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as f 5 | from typing import Union, Dict, List,Tuple, Optional 6 | import numpy as np 7 | from dataclasses import dataclass, asdict, is_dataclass 8 | from transformers.utils import ModelOutput 9 | from .utils import * 10 | from .tokenizer import SetokTokenizer 11 | from .detokenizer import SetokDeTokenizer 12 | from ..loss import GANLoss, MultilabelContrastiveLoss 13 | from ..utils import 
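# --- Editor's note (not part of these files): data flow of SetokDeTokenizer.forward as
# written above, for a batch of concept-level visual tokens.
#   x: (B, N_tokens, token_feat_dim) with an attention mask over the N_tokens axis
#   -> mapper_fc_in:                        (B, N_tokens, hidden_dim)
#   -> BERT mapper, whose (image_size // patch_size)**2 learned mask queries
#      cross-attend to the concept tokens:  (B, H*W, hidden_dim)
#   -> decoder_fc_in + 2-D positional encoding, then the ViT pixel-decoder blocks
#      and decoder_norm:                    (B, H*W, decoder_embed_dim) patch-grid features.
# --- end of editor's note ---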
instantiate_from_config 14 | 15 | 16 | @dataclass 17 | class SetokOutput(ModelOutput): 18 | token_emb: torch.FloatTensor = None 19 | predict_emb: torch.FloatTensor = None 20 | loss: torch.FloatTensor = None, 21 | loss_log: Dict = None, 22 | 23 | 24 | 25 | class SeTok(nn.Module): 26 | def __init__(self, 27 | tokenizer_config, 28 | detokenizer_config, 29 | rec_loss_config=None, 30 | contrastive_loss_config=None, 31 | is_training=False, 32 | **kwargs) -> None: 33 | super(SeTok).__init__() 34 | 35 | self.tokenizer_config = tokenizer_config 36 | self.detokenizer_config = detokenizer_config 37 | 38 | self.tokenizer = SetokTokenizer(**asdict(tokenizer_config)) 39 | 40 | self.detokenizer = SetokDeTokenizer(**asdict(detokenizer_config)) 41 | 42 | self.is_training = is_training 43 | if is_training: 44 | self.rec_loss = GANLoss(**asdict(rec_loss_config)) 45 | self.contrastive_loss = MultilabelContrastiveLoss(**asdict(contrastive_loss_config)) 46 | 47 | def get_tokenizer_config(self): 48 | return self.tokenizer_config 49 | 50 | def get_detokenizer_config(self): 51 | return self.detokenizer_config 52 | 53 | def get_tokenizer(self): 54 | return self.tokenizer 55 | 56 | def get_detokenizer(self): 57 | return self.detokenizer 58 | 59 | def tokenize(self, x): 60 | return self.tokenizer(x) 61 | 62 | def detokenize(self, x): 63 | return self.detokenizer(x) 64 | 65 | def sample_orders(self, bsz): 66 | # generate a batch of random generation orders 67 | orders = [] 68 | for _ in range(bsz): 69 | order = np.array(list(range(self.seq_len))) 70 | np.random.shuffle(order) 71 | orders.append(order) 72 | orders = torch.Tensor(np.array(orders)).cuda().long() 73 | return orders 74 | 75 | 76 | 77 | def compute_rec_loss(self, prediction, target, current_step): 78 | loss, log_dict = self.rec_loss(target, prediction, current_step) 79 | return loss, log_dict 80 | 81 | def compute_contrastive_loss(self, prediction, text): 82 | loss, log_dict = self.contrastive_loss(prediction, text) 83 | return loss, log_dict 84 | 85 | 86 | def forward(self, x, gold_image=None, text=None, return_dict=True, current_step=0): 87 | e_tokens, _, _ = self.tokenize(x) 88 | prediction = self.detokenize(e_tokens) 89 | loss = None 90 | loss_log = dict() 91 | if gold_image is not None: 92 | rec_loss, rec_loss_log = self.compute_rec_loss(prediction, gold_image, current_step) 93 | loss = rec_loss 94 | loss_log.update(**rec_loss_log) 95 | 96 | if text is not None: 97 | txt_contrastive_loss, txt_contrastive_log = self.compute_contrastive_loss(e_tokens, text) 98 | loss += txt_contrastive_loss 99 | loss_log.update(**txt_contrastive_log) 100 | if return_dict: 101 | SetokOutput(token_emb=e_tokens, predict_emb=prediction, loss=loss, loss_log=loss_log) 102 | else: 103 | loss, (e_tokens, prediction, loss_log) 104 | 105 | 106 | 107 | -------------------------------------------------------------------------------- /src/model/setok/tokenizer.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from typing import Union, Dict, List, Optional 5 | import numpy as np 6 | from einops import rearrange, repeat 7 | from timm.models.layers import DropPath 8 | import math 9 | from .clip_encoder import CLIPVisionTower 10 | from .module import Block, PositionalEncoding2D 11 | 12 | 13 | class SetokTokenizer(nn.Module): 14 | def __init__(self, 15 | vision_tower: str = 'google/siglip-so400m-patch14-384', 16 | unfreeze_mm_vision_tower: Optional[bool] = False, 17 | 
mm_vision_select_feature: Optional[str] = 'patch', 18 | mm_vision_select_layer: Optional[int] = -2, 19 | delay_load: Optional[bool]= False, 20 | hidden_dim: Optional[int] = 4096, 21 | token_feat_dim: Optional[int] = 4096, 22 | min_cluster_num: Optional[int] = 64, 23 | threshold: Optional[float] = 0.5, 24 | nheads: Optional[int] = 2, 25 | dim_feedforward: Optional[int] = 4096, 26 | proj_drop: Optional[float] = 0.2, 27 | drop_path: Optional[float] = 0.0, 28 | inner_cluster_layers: Optional[int] = 2, 29 | intra_cluster_layers: Optional[int] = 2, 30 | attn_drop: Optional[float] = 0.0, 31 | act_layer: nn.Module = nn.GELU, 32 | norm_layer: nn.Module = nn.LayerNorm, 33 | **kwargs 34 | ) -> None: 35 | super().__init__() 36 | 37 | self.hidden_dim = hidden_dim 38 | self.token_feat_dim = token_feat_dim 39 | 40 | self.inner_encoder = Block(self.hidden_dim, nheads, dim_feedforward, proj_drop=proj_drop, attn_drop=attn_drop, drop_path=drop_path, act_layer=act_layer, norm_layer=norm_layer, depth=inner_cluster_layers) 41 | self.inter_encoder = Block(self.hidden_dim, nheads, dim_feedforward, proj_drop=proj_drop, attn_drop=attn_drop, drop_path=drop_path, act_layer=act_layer, norm_layer=norm_layer, depth=intra_cluster_layers) 42 | self.position_embedding = PositionalEncoding2D(self.hidden_dim) 43 | self.out = nn.Linear(self.hidden_dim, self.token_feat_dim) 44 | 45 | self.min_cluster_num = min_cluster_num 46 | self.threshold = threshold 47 | 48 | self.initialize_weights() 49 | 50 | self.image_feature_encoder = CLIPVisionTower(vision_tower, 51 | unfreeze_mm_vision_tower=unfreeze_mm_vision_tower, 52 | mm_vision_select_feature=mm_vision_select_feature, 53 | mm_vision_select_layer=mm_vision_select_layer, 54 | delay_load=delay_load 55 | ) 56 | self.image_processor = self.image_feature_encoder.image_processor 57 | 58 | 59 | def initialize_weights(self): 60 | self.apply(self._init_weights) 61 | 62 | def _init_weights(self, m): 63 | if isinstance(m, nn.Linear): 64 | # we use xavier_uniform following official JAX ViT: 65 | torch.nn.init.xavier_uniform_(m.weight) 66 | if isinstance(m, nn.Linear) and m.bias is not None: 67 | nn.init.constant_(m.bias, 0) 68 | elif isinstance(m, nn.LayerNorm): 69 | if m.bias is not None: 70 | nn.init.constant_(m.bias, 0) 71 | if m.weight is not None: 72 | nn.init.constant_(m.weight, 1.0) 73 | 74 | @property 75 | def dtype(self): 76 | return self.Linear.weight.dtype 77 | 78 | def cluster_dpc_knn(self, x, k, token_mask=None, threshold=0.53): 79 | with torch.no_grad(): 80 | N, C = x.shape 81 | 82 | dist_matrix = torch.cdist(x, x) / (C ** 0.5) # C * C 83 | 84 | if token_mask is not None: 85 | token_mask = token_mask > 0 86 | dist_matrix = dist_matrix * token_mask[None, :] + (dist_matrix.max() + 1) * (~token_mask[None, :]) 87 | 88 | dist_nearest, index_nearest = torch.topk(dist_matrix, k=k, dim=-1, largest=False) # C * k 89 | 90 | density = (-(dist_nearest ** 2).mean(dim=-1)).exp() # C 91 | density = density + torch.rand(density.shape, device=density.device, dtype=density.dtype) * 1e-6 # C 92 | 93 | if token_mask is not None: 94 | density = density * token_mask 95 | 96 | mask = density[None, :] > density[:, None] # C * C 97 | mask = mask.type(x.dtype) 98 | dist_max = dist_matrix.flatten(1).max(dim=-1)[0][None, None] # C * C 99 | dist, index_parent = (dist_matrix * mask + dist_max * (1 - mask)).min(dim=-1) # 1 * C, 1 * C 100 | 101 | score = dist * density 102 | 103 | index_down = torch.nonzero(score.reshape(-1)>threshold).reshape(-1) # obtain the index of the center 104 | if index_down.numel() 
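# --- Editor's illustration (not part of tokenizer.py): the center-selection rule of
# cluster_dpc_knn above, restated on a toy feature matrix with no token mask. Each token gets
# a kNN-based density and a distance to its nearest higher-density token; centers are tokens
# whose product of the two exceeds the threshold.
import torch

x = torch.randn(16, 8)                                    # 16 patch features of dim 8
dist = torch.cdist(x, x) / (8 ** 0.5)
d_near, _ = torch.topk(dist, k=4, dim=-1, largest=False)
density = (-(d_near ** 2).mean(dim=-1)).exp()             # high where neighbours are close
denser = density[None, :] > density[:, None]              # (i, j): token j is denser than token i
delta = torch.where(denser, dist, dist.max()).min(dim=-1).values
score = delta * density                                   # density peaks score highest
centers = torch.nonzero(score > 0.5).flatten()            # thresholded; the code falls back to top-k
# --- end of editor's illustration ---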
== 0: 105 | _, index_down = torch.topk(score, k=self.min_cluster_num, dim=-1) 106 | index_down = torch.sort(index_down).values 107 | index_down = index_down.reshape(-1) 108 | 109 | # obtain the index of the cluster that each token belongs to 110 | # dist_matrix = index_points(dist_matrix, index_down.squeeze()) # the cluster_num * C 111 | dist_matrix = dist_matrix[index_down, :] # the cluster_num * C 112 | 113 | idx_cluster = dist_matrix.argmin(dim=0) # the cluster_num 114 | 115 | # B = 1 116 | # idx_batch = torch.arange(B, device=x.device)[:, None].expand(cluster_num) 117 | cluster_num = index_down.size(0) 118 | idx_tmp = torch.arange(cluster_num, device=x.device)[None, :] 119 | idx_cluster[index_down] = idx_tmp.reshape(-1) 120 | 121 | return index_down, idx_cluster, score 122 | 123 | def group_encoding(self, x, centers, labels): 124 | """ 125 | We apply transformer within each group to modeling the features. 126 | Specifically, we take the center representation as the initial representation of CLS. 127 | Then, we take the CLS representation as the final representation of the group, i.e., the concept-level visual token. 128 | Args: 129 | x of size (W, C), 130 | centers of size (L, C) 131 | label of size (W) 132 | 133 | Return: 134 | Output: group features of size (L, C) 135 | """ 136 | 137 | W, C = x.size() 138 | L, _ = centers.size() 139 | 140 | # Compute masks for each unique label 141 | unique_labels, label_counts = labels.unique(return_counts=True) 142 | # print('unique_labels: ', unique_labels) 143 | masks = [labels == cur_label for cur_label in unique_labels] # L, W 144 | # centers = centers.index_select(1, unique_labels) 145 | 146 | group_features = [] 147 | for i, m in enumerate(masks): 148 | # _m = m.unsqueeze(1).expand(W, C) 149 | # cur_length = torch.sum(m).item() 150 | _cur_cluster_feat = self.inner_encoder(x[m].unsqueeze(0)) 151 | _cur_cluster_feat = _cur_cluster_feat.squeeze(0).mean(dim=0) 152 | group_features.append(_cur_cluster_feat) 153 | group_features = torch.stack(group_features, dim=0) 154 | 155 | return group_features 156 | 157 | def forward(self, x, k=None, threshold=None, token_mask=None): 158 | """ 159 | Expected Input: x of size (B, h, w, C) 160 | """ 161 | x = self.image_feature_encoder(x) 162 | x = x.unsqueeze(0) 163 | B, hw, C = x.shape 164 | h = w = int(math.sqrt(x.shape[1])) 165 | _x = rearrange(x, 'B (h w) C -> B h w C', h=h, w=w) 166 | pos_emb = self.position_embedding(_x) 167 | pos_emb = rearrange(pos_emb, 'B h w C -> B (h w) C') 168 | x = x + pos_emb 169 | x = x.squeeze(0) 170 | 171 | _threshold = threshold if threshold else self.threshold 172 | _k = k if k else self.min_cluster_num 173 | 174 | index_down, idx_cluster, score = self.cluster_dpc_knn(x, _k, token_mask, _threshold) 175 | # index_down: the center index w.r.t the input x for each cluster 176 | # idx_cluster: the cluster index for each token in x, which is the index of the center in list [index_down] 177 | centers = x[index_down, :] 178 | group_features = self.group_encoding(x, centers, idx_cluster) 179 | group_features = self.inter_encoder(group_features) 180 | group_features = self.out(group_features) 181 | 182 | return group_features, idx_cluster, score -------------------------------------------------------------------------------- /src/model/setok/utils.py: -------------------------------------------------------------------------------- 1 | 2 | import torch 3 | 4 | 5 | def get_emb(sin_inp): 6 | """ 7 | Gets a base embedding for one dimension with sin and cos intertwined 8 | """ 9 | emb = 
torch.stack((sin_inp.sin(), sin_inp.cos()), dim=-1) 10 | return torch.flatten(emb, -2, -1) 11 | 12 | 13 | 14 | def mask_by_order(mask_len, order, bsz, seq_len): 15 | masking = torch.zeros(bsz, seq_len).cuda() 16 | masking = torch.scatter(masking, dim=-1, index=order[:, :mask_len.long()], src=torch.ones(bsz, seq_len).cuda()).bool() 17 | return masking -------------------------------------------------------------------------------- /src/model/utils.py: -------------------------------------------------------------------------------- 1 | from transformers import AutoConfig 2 | import importlib 3 | 4 | 5 | 6 | def instantiate_from_config(config): 7 | if not "target" in config: 8 | if config == '__is_first_stage__': 9 | return None 10 | elif config == "__is_unconditional__": 11 | return None 12 | raise KeyError("Expected key `target` to instantiate.") 13 | return get_obj_from_str(config["target"])(**config.get("params", dict())) 14 | 15 | 16 | def get_obj_from_str(string, reload=False): 17 | module, cls = string.rsplit(".", 1) 18 | if reload: 19 | module_imp = importlib.import_module(module) 20 | importlib.reload(module_imp) 21 | return getattr(importlib.import_module(module, package=None), cls) 22 | 23 | 24 | def auto_upgrade(config): 25 | cfg = AutoConfig.from_pretrained(config) 26 | if 'llava' in config and 'llava' not in cfg.model_type: 27 | assert cfg.model_type == 'llama' 28 | print("You are using newer LLaVA code base, while the checkpoint of v0 is from older code base.") 29 | print("You must upgrade the checkpoint to the new code base (this can be done automatically).") 30 | confirm = input("Please confirm that you want to upgrade the checkpoint. [Y/N]") 31 | if confirm.lower() in ["y", "yes"]: 32 | print("Upgrading checkpoint...") 33 | assert len(cfg.architectures) == 1 34 | setattr(cfg.__class__, "model_type", "llava") 35 | cfg.architectures[0] = 'LlavaLlamaForCausalLM' 36 | cfg.save_pretrained(config) 37 | print("Checkpoint upgraded.") 38 | else: 39 | print("Checkpoint upgrade aborted.") 40 | exit(1) 41 | -------------------------------------------------------------------------------- /src/train/llama_flash_attn_monkey_patch.py: -------------------------------------------------------------------------------- 1 | from typing import Optional, Tuple 2 | import warnings 3 | 4 | import torch 5 | 6 | import transformers 7 | from transformers.models.llama.modeling_llama import apply_rotary_pos_emb, repeat_kv 8 | 9 | try: 10 | from flash_attn.flash_attn_interface import flash_attn_unpadded_qkvpacked_func 11 | except ImportError: 12 | from flash_attn.flash_attn_interface import flash_attn_varlen_qkvpacked_func as flash_attn_unpadded_qkvpacked_func 13 | from flash_attn.bert_padding import unpad_input, pad_input 14 | 15 | 16 | def forward( 17 | self, 18 | hidden_states: torch.Tensor, 19 | attention_mask: Optional[torch.Tensor] = None, 20 | position_ids: Optional[torch.Tensor] = None, 21 | past_key_value: Optional[Tuple[torch.Tensor]] = None, 22 | output_attentions: bool = False, 23 | use_cache: bool = False, 24 | ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: 25 | if output_attentions: 26 | warnings.warn( 27 | "Output attentions is not supported for patched `LlamaAttention`, returning `None` instead." 
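# --- Editor's illustration (not part of utils.py): instantiate_from_config resolves the
# dotted "target" path and forwards "params" as keyword arguments.
cfg = {"target": "torch.nn.Linear", "params": {"in_features": 8, "out_features": 2}}
layer = instantiate_from_config(cfg)        # equivalent to torch.nn.Linear(in_features=8, out_features=2)
# --- end of editor's illustration ---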
28 | ) 29 | 30 | bsz, q_len, _ = hidden_states.size() 31 | 32 | query_states = ( 33 | self.q_proj(hidden_states) 34 | .view(bsz, q_len, self.num_heads, self.head_dim) 35 | .transpose(1, 2) 36 | ) 37 | key_states = ( 38 | self.k_proj(hidden_states) 39 | .view(bsz, q_len, self.num_key_value_heads, self.head_dim) 40 | .transpose(1, 2) 41 | ) 42 | value_states = ( 43 | self.v_proj(hidden_states) 44 | .view(bsz, q_len, self.num_key_value_heads, self.head_dim) 45 | .transpose(1, 2) 46 | ) # shape: (b, num_heads, s, head_dim) 47 | 48 | kv_seq_len = key_states.shape[-2] 49 | if past_key_value is not None: 50 | kv_seq_len += past_key_value[0].shape[-2] 51 | 52 | cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len) 53 | query_states, key_states = apply_rotary_pos_emb( 54 | query_states, key_states, cos, sin, position_ids 55 | ) 56 | 57 | if past_key_value is not None: 58 | # reuse k, v 59 | key_states = torch.cat([past_key_value[0], key_states], dim=2) 60 | value_states = torch.cat([past_key_value[1], value_states], dim=2) 61 | 62 | past_key_value = (key_states, value_states) if use_cache else None 63 | 64 | # repeat k/v heads if n_kv_heads < n_heads 65 | key_states = repeat_kv(key_states, self.num_key_value_groups) 66 | value_states = repeat_kv(value_states, self.num_key_value_groups) 67 | 68 | # Transform the data into the format required by flash attention 69 | qkv = torch.stack([query_states, key_states, value_states], dim=2) 70 | qkv = qkv.transpose(1, 3) # shape: [b, s, 3, num_heads, head_dim] 71 | key_padding_mask = attention_mask 72 | 73 | if key_padding_mask is None: 74 | qkv = qkv.reshape(-1, 3, self.num_heads, self.head_dim) 75 | cu_q_lens = torch.arange( 76 | 0, (bsz + 1) * q_len, step=q_len, dtype=torch.int32, device=qkv.device 77 | ) 78 | max_s = q_len 79 | output = flash_attn_unpadded_qkvpacked_func( 80 | qkv, cu_q_lens, max_s, 0.0, softmax_scale=None, causal=True 81 | ) 82 | output = output.view(bsz, q_len, -1) 83 | else: 84 | qkv = qkv.reshape(bsz, q_len, -1) 85 | qkv, indices, cu_q_lens, max_s = unpad_input(qkv, key_padding_mask) 86 | qkv = qkv.view(-1, 3, self.num_heads, self.head_dim) 87 | output_unpad = flash_attn_unpadded_qkvpacked_func( 88 | qkv, cu_q_lens, max_s, 0.0, softmax_scale=None, causal=True 89 | ) 90 | output_unpad = output_unpad.reshape(-1, self.num_heads * self.head_dim) 91 | output = pad_input(output_unpad, indices, bsz, q_len) 92 | 93 | return self.o_proj(output), None, past_key_value 94 | 95 | 96 | # Disable the transformation of the attention mask in LlamaModel as the flash attention 97 | # requires the attention mask to be the same as the key_padding_mask 98 | def _prepare_decoder_attention_mask( 99 | self, attention_mask, input_shape, inputs_embeds, past_key_values_length 100 | ): 101 | # [bsz, seq_len] 102 | return attention_mask 103 | 104 | 105 | def replace_llama_attn_with_flash_attn(): 106 | cuda_major, cuda_minor = torch.cuda.get_device_capability() 107 | if cuda_major < 8: 108 | warnings.warn( 109 | "Flash attention is only supported on A100 or H100 GPU during training due to head dim > 64 backward." 
110 | "ref: https://github.com/HazyResearch/flash-attention/issues/190#issuecomment-1523359593" 111 | ) 112 | transformers.models.llama.modeling_llama.LlamaModel._prepare_decoder_attention_mask = ( 113 | _prepare_decoder_attention_mask 114 | ) 115 | transformers.models.llama.modeling_llama.LlamaAttention.forward = forward 116 | -------------------------------------------------------------------------------- /src/train/llama_xformers_attn_monkey_patch.py: -------------------------------------------------------------------------------- 1 | """ 2 | Directly copied the code from https://raw.githubusercontent.com/oobabooga/text-generation-webui/main/modules/llama_attn_hijack.py and made some adjustments 3 | """ 4 | 5 | import logging 6 | import math 7 | from typing import Optional, Tuple 8 | 9 | import torch 10 | import transformers.models.llama.modeling_llama 11 | from torch import nn 12 | 13 | try: 14 | import xformers.ops 15 | except ImportError: 16 | logging.error("xformers not found! Please install it before trying to use it.") 17 | 18 | 19 | def replace_llama_attn_with_xformers_attn(): 20 | transformers.models.llama.modeling_llama.LlamaAttention.forward = xformers_forward 21 | 22 | 23 | def xformers_forward( 24 | self, 25 | hidden_states: torch.Tensor, 26 | attention_mask: Optional[torch.Tensor] = None, 27 | position_ids: Optional[torch.LongTensor] = None, 28 | past_key_value: Optional[Tuple[torch.Tensor]] = None, 29 | output_attentions: bool = False, 30 | use_cache: bool = False, 31 | ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: 32 | # pylint: disable=duplicate-code 33 | bsz, q_len, _ = hidden_states.size() 34 | 35 | query_states = ( 36 | self.q_proj(hidden_states) 37 | .view(bsz, q_len, self.num_heads, self.head_dim) 38 | .transpose(1, 2) 39 | ) 40 | key_states = ( 41 | self.k_proj(hidden_states) 42 | .view(bsz, q_len, self.num_heads, self.head_dim) 43 | .transpose(1, 2) 44 | ) 45 | value_states = ( 46 | self.v_proj(hidden_states) 47 | .view(bsz, q_len, self.num_heads, self.head_dim) 48 | .transpose(1, 2) 49 | ) 50 | 51 | kv_seq_len = key_states.shape[-2] 52 | if past_key_value is not None: 53 | kv_seq_len += past_key_value[0].shape[-2] 54 | cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len) 55 | ( 56 | query_states, 57 | key_states, 58 | ) = transformers.models.llama.modeling_llama.apply_rotary_pos_emb( 59 | query_states, key_states, cos, sin, position_ids 60 | ) 61 | # [bsz, nh, t, hd] 62 | 63 | if past_key_value is not None: 64 | # reuse k, v, self_attention 65 | key_states = torch.cat([past_key_value[0], key_states], dim=2) 66 | value_states = torch.cat([past_key_value[1], value_states], dim=2) 67 | 68 | past_key_value = (key_states, value_states) if use_cache else None 69 | 70 | # We only apply xformers optimizations if we don't need to output the whole attention matrix 71 | if not output_attentions: 72 | query_states = query_states.transpose(1, 2) 73 | key_states = key_states.transpose(1, 2) 74 | value_states = value_states.transpose(1, 2) 75 | 76 | # This is a nasty hack. We know attention_mask in transformers is either LowerTriangular or all Zeros. 77 | # We therefore check if one element in the upper triangular portion is zero. If it is, then the mask is all zeros. 
78 | if attention_mask is None or attention_mask[0, 0, 0, 1] == 0: 79 | # input and output should be of form (bsz, q_len, num_heads, head_dim) 80 | attn_output = xformers.ops.memory_efficient_attention( 81 | query_states, key_states, value_states, attn_bias=None 82 | ) 83 | else: 84 | # input and output should be of form (bsz, q_len, num_heads, head_dim) 85 | attn_output = xformers.ops.memory_efficient_attention( 86 | query_states, 87 | key_states, 88 | value_states, 89 | attn_bias=xformers.ops.LowerTriangularMask(), 90 | ) 91 | attn_weights = None 92 | else: 93 | attn_weights = torch.matmul( 94 | query_states, key_states.transpose(2, 3) 95 | ) / math.sqrt(self.head_dim) 96 | 97 | if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len): 98 | raise ValueError( 99 | f"Attention weights should be of size {(bsz * self.num_heads, q_len, kv_seq_len)}, but is" 100 | f" {attn_weights.size()}" 101 | ) 102 | 103 | if attention_mask is not None: 104 | if attention_mask.size() != (bsz, 1, q_len, kv_seq_len): 105 | raise ValueError( 106 | f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}" 107 | ) 108 | attn_weights = attn_weights + attention_mask 109 | attn_weights = torch.max( 110 | attn_weights, torch.tensor(torch.finfo(attn_weights.dtype).min) 111 | ) 112 | 113 | # upcast attention to fp32 114 | attn_weights = nn.functional.softmax( 115 | attn_weights, dim=-1, dtype=torch.float32 116 | ).to(query_states.dtype) 117 | attn_output = torch.matmul(attn_weights, value_states) 118 | 119 | if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim): 120 | raise ValueError( 121 | f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is" 122 | f" {attn_output.size()}" 123 | ) 124 | 125 | attn_output = attn_output.transpose(1, 2) 126 | 127 | attn_output = attn_output.reshape(bsz, q_len, self.hidden_size) 128 | attn_output = self.o_proj(attn_output) 129 | return attn_output, attn_weights, past_key_value 130 | -------------------------------------------------------------------------------- /src/train/setok_trainer.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import torch.nn as nn 4 | import math 5 | from torch.utils.data import Sampler 6 | import time 7 | import sys 8 | 9 | from transformers import Trainer 10 | from transformers.trainer import ( 11 | is_sagemaker_mp_enabled, 12 | get_parameter_names, 13 | has_length, 14 | ALL_LAYERNORM_LAYERS, 15 | logger, 16 | ) 17 | # from transformers.trainer import * 18 | from typing import List, Optional, Dict, Union, Any 19 | 20 | 21 | def maybe_zero_3(param, ignore_status=False, name=None): 22 | from deepspeed import zero 23 | from deepspeed.runtime.zero.partition_parameters import ZeroParamStatus 24 | if hasattr(param, "ds_id"): 25 | if param.ds_status == ZeroParamStatus.NOT_AVAILABLE: 26 | if not ignore_status: 27 | print(name, 'no ignore status') 28 | with zero.GatheredParameters([param]): 29 | param = param.data.detach().cpu().clone() 30 | else: 31 | param = param.detach().cpu().clone() 32 | return param 33 | 34 | 35 | def get_mm_adapter_state_maybe_zero_3(named_params, keys_to_match): 36 | to_return = {k: t for k, t in named_params if any(key_match in k for key_match in keys_to_match)} 37 | to_return = {k: maybe_zero_3(v, ignore_status=True, name=k).cpu() for k, v in to_return.items()} 38 | return to_return 39 | 40 | 41 | def split_to_even_chunks(indices, lengths, num_chunks): 42 | """ 43 | Split 
a list of indices into `chunks` chunks of roughly equal lengths. 44 | """ 45 | 46 | if len(indices) % num_chunks != 0: 47 | return [indices[i::num_chunks] for i in range(num_chunks)] 48 | 49 | num_indices_per_chunk = len(indices) // num_chunks 50 | 51 | chunks = [[] for _ in range(num_chunks)] 52 | chunks_lengths = [0 for _ in range(num_chunks)] 53 | for index in indices: 54 | shortest_chunk = chunks_lengths.index(min(chunks_lengths)) 55 | chunks[shortest_chunk].append(index) 56 | chunks_lengths[shortest_chunk] += lengths[index] 57 | if len(chunks[shortest_chunk]) == num_indices_per_chunk: 58 | chunks_lengths[shortest_chunk] = float("inf") 59 | 60 | return chunks 61 | 62 | 63 | def get_modality_length_grouped_indices(lengths, batch_size, world_size, generator=None): 64 | # We need to use torch for the random part as a distributed sampler will set the random seed for torch. 65 | assert all(l != 0 for l in lengths), "Should not have zero length." 66 | if all(l > 0 for l in lengths) or all(l < 0 for l in lengths): 67 | # all samples are in the same modality 68 | return get_length_grouped_indices(lengths, batch_size, world_size, generator=generator) 69 | mm_indices, mm_lengths = zip(*[(i, l) for i, l in enumerate(lengths) if l > 0]) 70 | lang_indices, lang_lengths = zip(*[(i, -l) for i, l in enumerate(lengths) if l < 0]) 71 | 72 | mm_shuffle = [mm_indices[i] for i in get_length_grouped_indices(mm_lengths, batch_size, world_size, generator=None)] 73 | lang_shuffle = [lang_indices[i] for i in get_length_grouped_indices(lang_lengths, batch_size, world_size, generator=None)] 74 | megabatch_size = world_size * batch_size 75 | mm_megabatches = [mm_shuffle[i : i + megabatch_size] for i in range(0, len(mm_shuffle), megabatch_size)] 76 | lang_megabatches = [lang_shuffle[i : i + megabatch_size] for i in range(0, len(lang_shuffle), megabatch_size)] 77 | 78 | last_mm = mm_megabatches[-1] 79 | last_lang = lang_megabatches[-1] 80 | additional_batch = last_mm + last_lang 81 | megabatches = mm_megabatches[:-1] + lang_megabatches[:-1] 82 | megabatch_indices = torch.randperm(len(megabatches), generator=generator) 83 | megabatches = [megabatches[i] for i in megabatch_indices] 84 | 85 | if len(additional_batch) > 0: 86 | megabatches.append(sorted(additional_batch)) 87 | 88 | return [i for megabatch in megabatches for i in megabatch] 89 | 90 | 91 | def get_length_grouped_indices(lengths, batch_size, world_size, generator=None, merge=True): 92 | # We need to use torch for the random part as a distributed sampler will set the random seed for torch. 93 | indices = torch.randperm(len(lengths), generator=generator) 94 | megabatch_size = world_size * batch_size 95 | megabatches = [indices[i : i + megabatch_size].tolist() for i in range(0, len(lengths), megabatch_size)] 96 | megabatches = [sorted(megabatch, key=lambda i: lengths[i], reverse=True) for megabatch in megabatches] 97 | megabatches = [split_to_even_chunks(megabatch, lengths, world_size) for megabatch in megabatches] 98 | 99 | return [i for megabatch in megabatches for batch in megabatch for i in batch] 100 | 101 | 102 | class LengthGroupedSampler(Sampler): 103 | r""" 104 | Sampler that samples indices in a way that groups together features of the dataset of roughly the same length while 105 | keeping a bit of randomness. 
106 | """ 107 | 108 | def __init__( 109 | self, 110 | batch_size: int, 111 | world_size: int, 112 | lengths: Optional[List[int]] = None, 113 | generator=None, 114 | group_by_modality: bool = False, 115 | ): 116 | if lengths is None: 117 | raise ValueError("Lengths must be provided.") 118 | 119 | self.batch_size = batch_size 120 | self.world_size = world_size 121 | self.lengths = lengths 122 | self.generator = generator 123 | self.group_by_modality = group_by_modality 124 | 125 | def __len__(self): 126 | return len(self.lengths) 127 | 128 | def __iter__(self): 129 | if self.group_by_modality: 130 | indices = get_modality_length_grouped_indices(self.lengths, self.batch_size, self.world_size, generator=self.generator) 131 | else: 132 | indices = get_length_grouped_indices(self.lengths, self.batch_size, self.world_size, generator=self.generator) 133 | return iter(indices) 134 | 135 | 136 | class SetokTrainer(Trainer): 137 | 138 | def _get_train_sampler(self) -> Optional[torch.utils.data.Sampler]: 139 | if self.train_dataset is None or not has_length(self.train_dataset): 140 | return None 141 | 142 | if self.args.group_by_modality_length: 143 | lengths = self.train_dataset.modality_lengths 144 | return LengthGroupedSampler( 145 | self.args.train_batch_size, 146 | world_size=self.args.world_size * self.args.gradient_accumulation_steps, 147 | lengths=lengths, 148 | group_by_modality=True, 149 | ) 150 | else: 151 | return super()._get_train_sampler() 152 | 153 | def create_optimizer(self): 154 | """ 155 | Setup the optimizer. 156 | 157 | We provide a reasonable default that works well. If you want to use something else, you can pass a tuple in the 158 | Trainer's init through `optimizers`, or subclass and override this method in a subclass. 159 | """ 160 | if is_sagemaker_mp_enabled(): 161 | return super().create_optimizer() 162 | 163 | opt_model = self.model 164 | 165 | if self.optimizer is None: 166 | optimizer_cls, optimizer_kwargs = Trainer.get_optimizer_cls_and_kwargs(self.args) 167 | 168 | self.optimizer = optimizer_cls(opt_model.parameters(), **optimizer_kwargs) 169 | if optimizer_cls.__name__ == "Adam8bit": 170 | import bitsandbytes 171 | 172 | manager = bitsandbytes.optim.GlobalOptimManager.get_instance() 173 | 174 | skipped = 0 175 | for module in opt_model.modules(): 176 | if isinstance(module, nn.Embedding): 177 | skipped += sum({p.data_ptr(): p.numel() for p in module.parameters()}.values()) 178 | logger.info(f"skipped {module}: {skipped/2**20}M params") 179 | manager.register_module_override(module, "weight", {"optim_bits": 32}) 180 | logger.debug(f"bitsandbytes: will optimize {module} in fp32") 181 | logger.info(f"skipped: {skipped/2**20}M params") 182 | 183 | return self.optimizer 184 | 185 | def _save_checkpoint(self, model, trial, metrics=None): 186 | if getattr(self.args, 'tune_mm_mlp_adapter', False): 187 | from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR 188 | checkpoint_folder = f"{PREFIX_CHECKPOINT_DIR}-{self.state.global_step}" 189 | 190 | run_dir = self._get_output_dir(trial=trial) 191 | output_dir = os.path.join(run_dir, checkpoint_folder) 192 | 193 | # Only save Adapter 194 | keys_to_match = ['mm_projector', 'vision_resampler'] 195 | if getattr(self.args, "use_im_start_end", False): 196 | keys_to_match.extend(['embed_tokens', 'embed_in']) 197 | 198 | weight_to_save = get_mm_adapter_state_maybe_zero_3(self.model.named_parameters(), keys_to_match) 199 | 200 | if self.args.local_rank == 0 or self.args.local_rank == -1: 201 | 
self.model.config.save_pretrained(output_dir)
202 |                 torch.save(weight_to_save, os.path.join(output_dir, f'mm_projector.bin'))
203 |         else:
204 |             super(SetokTrainer, self)._save_checkpoint(model, trial, metrics)
205 | 
206 |     def _save(self, output_dir: Optional[str] = None, state_dict=None):
207 |         if getattr(self.args, 'tune_mm_mlp_adapter', False):
208 |             pass
209 |         else:
210 |             super(SetokTrainer, self)._save(output_dir, state_dict)
--------------------------------------------------------------------------------
/src/train/setokim_trainer.py:
--------------------------------------------------------------------------------
1 | import os
2 | import torch
3 | import torch.nn as nn
4 | import math
5 | from torch.utils.data import Sampler
6 | import time
7 | import sys
8 | 
9 | from transformers import Trainer
10 | from transformers.trainer import (
11 |     is_sagemaker_mp_enabled,
12 |     get_parameter_names,
13 |     has_length,
14 |     ALL_LAYERNORM_LAYERS,
15 |     logger,
16 | )
17 | # from transformers.trainer import *
18 | from typing import List, Optional, Dict, Union, Any
19 | 
20 | 
21 | def maybe_zero_3(param, ignore_status=False, name=None):
22 |     from deepspeed import zero
23 |     from deepspeed.runtime.zero.partition_parameters import ZeroParamStatus
24 |     if hasattr(param, "ds_id"):
25 |         if param.ds_status == ZeroParamStatus.NOT_AVAILABLE:
26 |             if not ignore_status:
27 |                 print(name, 'no ignore status')
28 |         with zero.GatheredParameters([param]):
29 |             param = param.data.detach().cpu().clone()
30 |     else:
31 |         param = param.detach().cpu().clone()
32 |     return param
33 | 
34 | 
35 | def get_mm_adapter_state_maybe_zero_3(named_params, keys_to_match):
36 |     to_return = {k: t for k, t in named_params if any(key_match in k for key_match in keys_to_match)}
37 |     to_return = {k: maybe_zero_3(v, ignore_status=True, name=k).cpu() for k, v in to_return.items()}
38 |     return to_return
39 | 
40 | 
41 | def split_to_even_chunks(indices, lengths, num_chunks):
42 |     """
43 |     Split a list of indices into `chunks` chunks of roughly equal lengths.
44 |     """
45 | 
46 |     if len(indices) % num_chunks != 0:
47 |         return [indices[i::num_chunks] for i in range(num_chunks)]
48 | 
49 |     num_indices_per_chunk = len(indices) // num_chunks
50 | 
51 |     chunks = [[] for _ in range(num_chunks)]
52 |     chunks_lengths = [0 for _ in range(num_chunks)]
53 |     for index in indices:
54 |         shortest_chunk = chunks_lengths.index(min(chunks_lengths))
55 |         chunks[shortest_chunk].append(index)
56 |         chunks_lengths[shortest_chunk] += lengths[index]
57 |         if len(chunks[shortest_chunk]) == num_indices_per_chunk:
58 |             chunks_lengths[shortest_chunk] = float("inf")
59 | 
60 |     return chunks
61 | 
62 | 
63 | def get_modality_length_grouped_indices(lengths, batch_size, world_size, generator=None):
64 |     # We need to use torch for the random part as a distributed sampler will set the random seed for torch.
65 |     assert all(l != 0 for l in lengths), "Should not have zero length."
66 | if all(l > 0 for l in lengths) or all(l < 0 for l in lengths): 67 | # all samples are in the same modality 68 | return get_length_grouped_indices(lengths, batch_size, world_size, generator=generator) 69 | mm_indices, mm_lengths = zip(*[(i, l) for i, l in enumerate(lengths) if l > 0]) 70 | lang_indices, lang_lengths = zip(*[(i, -l) for i, l in enumerate(lengths) if l < 0]) 71 | 72 | mm_shuffle = [mm_indices[i] for i in get_length_grouped_indices(mm_lengths, batch_size, world_size, generator=None)] 73 | lang_shuffle = [lang_indices[i] for i in get_length_grouped_indices(lang_lengths, batch_size, world_size, generator=None)] 74 | megabatch_size = world_size * batch_size 75 | mm_megabatches = [mm_shuffle[i : i + megabatch_size] for i in range(0, len(mm_shuffle), megabatch_size)] 76 | lang_megabatches = [lang_shuffle[i : i + megabatch_size] for i in range(0, len(lang_shuffle), megabatch_size)] 77 | 78 | last_mm = mm_megabatches[-1] 79 | last_lang = lang_megabatches[-1] 80 | additional_batch = last_mm + last_lang 81 | megabatches = mm_megabatches[:-1] + lang_megabatches[:-1] 82 | megabatch_indices = torch.randperm(len(megabatches), generator=generator) 83 | megabatches = [megabatches[i] for i in megabatch_indices] 84 | 85 | if len(additional_batch) > 0: 86 | megabatches.append(sorted(additional_batch)) 87 | 88 | return [i for megabatch in megabatches for i in megabatch] 89 | 90 | 91 | def get_length_grouped_indices(lengths, batch_size, world_size, generator=None, merge=True): 92 | # We need to use torch for the random part as a distributed sampler will set the random seed for torch. 93 | indices = torch.randperm(len(lengths), generator=generator) 94 | megabatch_size = world_size * batch_size 95 | megabatches = [indices[i : i + megabatch_size].tolist() for i in range(0, len(lengths), megabatch_size)] 96 | megabatches = [sorted(megabatch, key=lambda i: lengths[i], reverse=True) for megabatch in megabatches] 97 | megabatches = [split_to_even_chunks(megabatch, lengths, world_size) for megabatch in megabatches] 98 | 99 | return [i for megabatch in megabatches for batch in megabatch for i in batch] 100 | 101 | 102 | class LengthGroupedSampler(Sampler): 103 | r""" 104 | Sampler that samples indices in a way that groups together features of the dataset of roughly the same length while 105 | keeping a bit of randomness. 
106 |     """
107 | 
108 |     def __init__(
109 |         self,
110 |         batch_size: int,
111 |         world_size: int,
112 |         lengths: Optional[List[int]] = None,
113 |         generator=None,
114 |         group_by_modality: bool = False,
115 |     ):
116 |         if lengths is None:
117 |             raise ValueError("Lengths must be provided.")
118 | 
119 |         self.batch_size = batch_size
120 |         self.world_size = world_size
121 |         self.lengths = lengths
122 |         self.generator = generator
123 |         self.group_by_modality = group_by_modality
124 | 
125 |     def __len__(self):
126 |         return len(self.lengths)
127 | 
128 |     def __iter__(self):
129 |         if self.group_by_modality:
130 |             indices = get_modality_length_grouped_indices(self.lengths, self.batch_size, self.world_size, generator=self.generator)
131 |         else:
132 |             indices = get_length_grouped_indices(self.lengths, self.batch_size, self.world_size, generator=self.generator)
133 |         return iter(indices)
134 | 
135 | 
136 | class SetokimTrainer(Trainer):
137 | 
138 |     def _get_train_sampler(self) -> Optional[torch.utils.data.Sampler]:
139 |         if self.train_dataset is None or not has_length(self.train_dataset):
140 |             return None
141 | 
142 |         if self.args.group_by_modality_length:
143 |             lengths = self.train_dataset.modality_lengths
144 |             return LengthGroupedSampler(
145 |                 self.args.train_batch_size,
146 |                 world_size=self.args.world_size * self.args.gradient_accumulation_steps,
147 |                 lengths=lengths,
148 |                 group_by_modality=True,
149 |             )
150 |         else:
151 |             return super()._get_train_sampler()
152 | 
153 |     def create_optimizer(self):
154 |         """
155 |         Setup the optimizer.
156 | 
157 |         We provide a reasonable default that works well. If you want to use something else, you can pass a tuple in the
158 |         Trainer's init through `optimizers`, or subclass and override this method in a subclass.
159 |         """
160 |         if is_sagemaker_mp_enabled():
161 |             return super().create_optimizer()
162 | 
163 |         opt_model = self.model
164 | 
165 |         if self.optimizer is None:
166 |             decay_parameters = get_parameter_names(opt_model, ALL_LAYERNORM_LAYERS)
167 |             decay_parameters = [name for name in decay_parameters if "bias" not in name]
168 |             if self.args.mm_in_projector_lr is not None:
169 |                 projector_parameters = [name for name, _ in opt_model.named_parameters() if "mm_in_projector" in name]
170 |                 if self.args.mm_out_projector_lr is not None:
171 |                     projector_parameters.extend([name for name, _ in opt_model.named_parameters() if "mm_out_projector" in name])
172 |                 optimizer_grouped_parameters = [
173 |                     {
174 |                         "params": [
175 |                             p for n, p in opt_model.named_parameters() if (n in decay_parameters and n not in projector_parameters and p.requires_grad)
176 |                         ],
177 |                         "weight_decay": self.args.weight_decay,
178 |                     },
179 |                     {
180 |                         "params": [
181 |                             p for n, p in opt_model.named_parameters() if (n not in decay_parameters and n not in projector_parameters and p.requires_grad)
182 |                         ],
183 |                         "weight_decay": 0.0,
184 |                     },
185 |                     {
186 |                         "params": [
187 |                             p for n, p in opt_model.named_parameters() if (n in decay_parameters and n in projector_parameters and p.requires_grad)
188 |                         ],
189 |                         "weight_decay": self.args.weight_decay,
190 |                         "lr": self.args.mm_in_projector_lr,  # the in-projector lr is applied to every projector parameter group
191 |                     },
192 |                     {
193 |                         "params": [
194 |                             p for n, p in opt_model.named_parameters() if (n not in decay_parameters and n in projector_parameters and p.requires_grad)
195 |                         ],
196 |                         "weight_decay": 0.0,
197 |                         "lr": self.args.mm_in_projector_lr,
198 |                     },
199 |                 ]
200 |             else:
201 |                 optimizer_grouped_parameters = [
202 |                     {
203 |                         "params": [
204 |                             p for n, p in opt_model.named_parameters() if (n in decay_parameters and p.requires_grad)
205 | ], 206 | "weight_decay": self.args.weight_decay, 207 | }, 208 | { 209 | "params": [ 210 | p for n, p in opt_model.named_parameters() if (n not in decay_parameters and p.requires_grad) 211 | ], 212 | "weight_decay": 0.0, 213 | }, 214 | ] 215 | 216 | optimizer_cls, optimizer_kwargs = Trainer.get_optimizer_cls_and_kwargs(self.args) 217 | 218 | self.optimizer = optimizer_cls(optimizer_grouped_parameters, **optimizer_kwargs) 219 | if optimizer_cls.__name__ == "Adam8bit": 220 | import bitsandbytes 221 | manager = bitsandbytes.optim.GlobalOptimManager.get_instance() 222 | 223 | skipped = 0 224 | for module in opt_model.modules(): 225 | if isinstance(module, nn.Embedding): 226 | skipped += sum({p.data_ptr(): p.numel() for p in module.parameters()}.values()) 227 | logger.info(f"skipped {module}: {skipped/2**20}M params") 228 | manager.register_module_override(module, "weight", {"optim_bits": 32}) 229 | logger.debug(f"bitsandbytes: will optimize {module} in fp32") 230 | logger.info(f"skipped: {skipped/2**20}M params") 231 | 232 | return self.optimizer 233 | 234 | def _save_checkpoint(self, model, trial, metrics=None): 235 | if getattr(self.args, 'tune_mm_mlp_adapter', False): 236 | from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR 237 | checkpoint_folder = f"{PREFIX_CHECKPOINT_DIR}-{self.state.global_step}" 238 | 239 | run_dir = self._get_output_dir(trial=trial) 240 | output_dir = os.path.join(run_dir, checkpoint_folder) 241 | 242 | # Only save Adapter 243 | keys_to_match = ['mm_in_projector', 'mm_out_projector'] 244 | if getattr(self.args, "use_im_start_end", False): 245 | keys_to_match.extend(['embed_tokens', 'embed_in']) 246 | 247 | weight_to_save = get_mm_adapter_state_maybe_zero_3(self.model.named_parameters(), keys_to_match) 248 | 249 | if self.args.local_rank == 0 or self.args.local_rank == -1: 250 | self.model.config.save_pretrained(output_dir) 251 | torch.save(weight_to_save, os.path.join(output_dir, f'mm_projector.bin')) 252 | else: 253 | super(SetokimTrainer, self)._save_checkpoint(model, trial, metrics) 254 | 255 | def _save(self, output_dir: Optional[str] = None, state_dict=None): 256 | if getattr(self.args, 'tune_mm_mlp_adapter', False): 257 | pass 258 | else: 259 | super(SetokimTrainer, self)._save(output_dir, state_dict) -------------------------------------------------------------------------------- /src/train/train_mem.py: -------------------------------------------------------------------------------- 1 | 2 | if __name__ == "__main__": 3 | # training setokim 4 | # from train_setokim import train 5 | # train(attn_implementation="flash_attention_2") 6 | 7 | # training setok 8 | from train_setok import train 9 | train(attn_implementation="flash_attention_2") 10 | -------------------------------------------------------------------------------- /src/train/train_setok.py: -------------------------------------------------------------------------------- 1 | # Adopted from https://github.com/lm-sys/FastChat. Below is the original copyright: 2 | # Adopted from tatsu-lab@stanford_alpaca. Below is the original copyright: 3 | # Copyright 2023 Rohan Taori, Ishaan Gulrajani, Tianyi Zhang, Yann Dubois, Xuechen Li 4 | # 5 | # Licensed under the Apache License, Version 2.0 (the "License"); 6 | # you may not use this file except in compliance with the License. 
7 | # You may obtain a copy of the License at
8 | #
9 | #     http://www.apache.org/licenses/LICENSE-2.0
10 | #
11 | # Unless required by applicable law or agreed to in writing, software
12 | # distributed under the License is distributed on an "AS IS" BASIS,
13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 | # See the License for the specific language governing permissions and
15 | # limitations under the License.
16 | 
17 | import os
18 | import copy
19 | from dataclasses import dataclass, field
20 | import json
21 | import logging
22 | import pathlib
23 | from typing import Dict, Optional, Sequence, List
24 | import torch
25 | import transformers
26 | import tokenizers
27 | from torch.utils.data import Dataset
28 | from src.model import SeTok
29 | from src.dataset import *
30 | from setok_trainer import SetokTrainer
31 | from training_utils import *
32 | 
33 | 
34 | 
35 | def rank0_print(*args):
36 |     if local_rank == 0:
37 |         print(*args)
38 | 
39 | 
40 | from packaging import version
41 | IS_TOKENIZER_GREATER_THAN_0_14 = version.parse(tokenizers.__version__) >= version.parse('0.14')
42 | 
43 | 
44 | def safe_save_model_for_hf_trainer(trainer: transformers.Trainer, output_dir: str):
45 |     """Collect the state dict and dump it to disk."""
46 | 
47 |     trainer.model.config.save_pretrained(output_dir)
48 | 
49 |     current_folder = output_dir.split('/')[-1]
50 |     parent_folder = os.path.dirname(output_dir)
51 |     if trainer.args.local_rank == 0 or trainer.args.local_rank == -1:
52 |         if current_folder.startswith('checkpoint-'):
53 |             mm_projector_folder = os.path.join(parent_folder, "mm_projector")
54 |             os.makedirs(mm_projector_folder, exist_ok=True)
55 |             torch.save(trainer.model.state_dict(), os.path.join(mm_projector_folder, f'{current_folder}.bin'))
56 |         else:
57 |             torch.save(trainer.model.state_dict(), os.path.join(output_dir, f'mm_projector.bin'))
58 |     return
59 | 
60 | 
61 | def make_supervised_data_module(tokenizer: transformers.PreTrainedTokenizer,
62 |                                 data_args,
63 |                                 siglipTokenizer: Optional[transformers.PreTrainedTokenizer] = None,
64 |                                 ) -> Dict:
65 |     """Make dataset and collator for supervised fine-tuning."""
66 |     train_dataset = TextImagePairDataset(
67 |         tokenizer=tokenizer,
68 |         data_path=data_args.data_path,
69 |         data_args=data_args,
70 |         constrative_tokenizer=siglipTokenizer
71 |     )
72 |     data_collator = DataCollatorForSupervisedDataset(tokenizer=tokenizer, siglipTokenizer=siglipTokenizer)
73 |     return dict(train_dataset=train_dataset,
74 |                 eval_dataset=None,
75 |                 data_collator=data_collator)
76 | 
77 | 
78 | 
79 | def train(attn_implementation=None):
80 |     global local_rank
81 |     parser = transformers.HfArgumentParser(
82 |         (ModelArguments, DataArguments, TrainingArguments, VisionTowerArguments, VisionGeneratorArguments, ReconstructionLossArguments, ConstrastiveLossArguments))
83 |     model_args, data_args, training_args, vision_tower_args, vision_generator_args, rec_loss_args, constrative_loss_args = parser.parse_args_into_dataclasses()
84 |     local_rank = training_args.local_rank
85 |     compute_dtype = (torch.float16 if training_args.fp16 else (torch.bfloat16 if training_args.bf16 else torch.float32))
86 | 
87 |     print('model_args:', model_args)
88 |     print('data_args: ', data_args)
89 | 
90 |     model = SeTok(tokenizer_config=vision_tower_args,
91 |                   detokenizer_config=vision_generator_args,
92 |                   rec_loss_config=rec_loss_args,
93 |                   contrastive_loss_config=constrative_loss_args)
94 | 
95 |     tokenizer = transformers.AutoTokenizer.from_pretrained(
96 |         model_args.model_name_or_path,
97
| cache_dir=training_args.cache_dir, 98 | model_max_length=training_args.model_max_length, 99 | padding_side="right", 100 | use_fast=False, 101 | ) 102 | 103 | siglipTokenizer = transformers.AutoTokenizer.from_pretrained( 104 | vision_tower_args.vision_tower, 105 | cache_dir=training_args.cache_dir, 106 | padding_side="right", 107 | use_fast=False, 108 | ) 109 | 110 | 111 | data_module = make_supervised_data_module(tokenizer=tokenizer, 112 | data_args=data_args, 113 | siglipTokenizer=siglipTokenizer) 114 | trainer = SetokTrainer(model=model, 115 | tokenizer=tokenizer, 116 | args=training_args, 117 | **data_module) 118 | 119 | if list(pathlib.Path(training_args.output_dir).glob("checkpoint-*")): 120 | trainer.train(resume_from_checkpoint=True) 121 | else: 122 | trainer.train() 123 | trainer.save_state() 124 | model.config.use_cache = True 125 | safe_save_model_for_hf_trainer(trainer=trainer, 126 | output_dir=training_args.output_dir) 127 | 128 | 129 | if __name__ == "__main__": 130 | train() 131 | -------------------------------------------------------------------------------- /src/train/train_xformers.py: -------------------------------------------------------------------------------- 1 | # Make it more memory efficient by monkey patching the LLaMA model with xformers attention. 2 | 3 | # Need to call this before importing transformers. 4 | from src.train.llama_xformers_attn_monkey_patch import ( 5 | replace_llama_attn_with_xformers_attn, 6 | ) 7 | 8 | replace_llama_attn_with_xformers_attn() 9 | 10 | from src.train.train import train 11 | 12 | if __name__ == "__main__": 13 | train() 14 | -------------------------------------------------------------------------------- /src/train/training_utils.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass, field 2 | from typing import Dict, Optional, Sequence, List, Union 3 | import transformers 4 | 5 | 6 | @dataclass 7 | class ModelArguments: 8 | model_name_or_path: Optional[str] = field(default="facebook/opt-125m") 9 | vision_tokenizer: Optional[str] = field(default='setok') 10 | vision_generator: Optional[str] = field(default='setok') 11 | diff_loss: Optional[str] = field(default='diff_loss') 12 | version: Optional[str] = field(default="v0") 13 | freeze_backbone: bool = field(default=False) 14 | mm_use_im_start_end: bool = field(default=True) 15 | mm_use_im_patch_token: bool = field(default=True) 16 | mm_patch_merge_type: Optional[str] = field(default='flat') 17 | 18 | 19 | @dataclass 20 | class VisionTowerArguments: 21 | vision_tower: Optional[str] = field(default='google/siglip-so400m-patch14-384') 22 | pretrain_vision_tokenizer: Optional[str] = field(default='') 23 | unfreeze_mm_vision_tower: bool = field(default=False) 24 | mm_vision_select_feature: Optional[str] = field(default="patch") 25 | mm_vision_select_layer: Optional[int] = field(default=-1) 26 | delay_load: bool = field(default=False) 27 | hidden_dim: Optional[int] = field(default=4096) 28 | token_feat_dim: Optional[int] = field(default=4096) 29 | min_cluster_num: Optional[int] = field(default=64) 30 | threshold: Optional[float] = field(default=0.55) 31 | nheads: Optional[int] = field(default=2) 32 | dim_feedforward: Optional[int] = field(default=4096) 33 | proj_drop: Optional[float] = field(default=0.2) 34 | attn_drop: Optional[float] = field(default=0.0) 35 | inner_cluster_layers: Optional[int] = field(default=2) 36 | intra_cluster_layers: Optional[int] = field(default=2) 37 | 38 | @dataclass 39 | class 
VisionInProjectionArguments:
40 |     pretrain_mm_in_mlp_adapter: Optional[str] = field(default=None)
41 |     mm_in_projector_type: Optional[str] = field(default='mlp')
42 |     mm_hidden_size: Optional[int] = field(default=1052)
43 |     hidden_size: Optional[int] = field(default=4096)
44 | 
45 | @dataclass
46 | class VisionGeneratorArguments:
47 |     pretrain_vision_detokenizer: Optional[str] = field(default='')
48 |     patch_size: Optional[int] = field(default=14)
49 |     out_image_size: Optional[int] = field(default=384)
50 |     decoder_embed_dim: Optional[int] = field(default=4096)
51 |     decoder_nheads: Optional[int] = field(default=8)
52 |     decoder_depth: Optional[int] = field(default=16)
53 |     mlp_ratio: Optional[float] = field(default=4.0)
54 |     feature_mapper_path_or_name: Optional[str] = field(default="bert-base-uncased")
55 |     num_hidden_layers: Optional[int] = field(default=6)
56 |     cross_attention_freq: Optional[int] = field(default=2)
57 |     initializer_range: Optional[float] = field(default=0.02)
58 | 
59 | @dataclass
60 | class VisionOutProjectionArguments:
61 |     pretrain_mm_out_mlp_adapter: Optional[str] = field(default=None)
62 |     mm_out_projector_type: Optional[str] = field(default='mlp')
63 | 
64 | 
65 | @dataclass
66 | class ReconstructionLossArguments:
67 |     disc_in_channels: Optional[int] = field(default=16)
68 |     disc_num_layers: Optional[int] = field(default=2)
69 |     disc_start: Optional[int] = field(default=5000)
70 |     warm_up_end: Optional[int] = field(default=200)
71 | 
72 | @dataclass
73 | class ConstrastiveLossArguments:
74 |     text_encoder: Optional[str] = field(default='google/siglip-so400m-patch14-384')
75 |     contrast_temperature: Optional[float] = field(default=0.07)
76 |     multi_label: Optional[int] = field(default=0)
77 |     share_temperature: bool = field(default=False)
78 |     multi_label_loss_weight: Optional[float] = field(default=1.0)
79 | 
80 | @dataclass
81 | class DiffLossArguments:
82 |     diffloss_w: Optional[int] = field(default=3)
83 |     diffloss_d: Optional[int] = field(default=1024)
84 |     num_sampling_steps: Optional[str] = field(default='100')
85 |     grad_checkpointing: bool = field(default=False)
86 |     diffusion_batch_mul: Optional[int] = field(default=4)
87 |     mask_ratio_min: Optional[float] = field(default=0.7)
88 | 
89 | 
90 | @dataclass
91 | class DataArguments:
92 |     data_path: Union[List[str], str] = field(default=None, metadata={"help": "Path to the training data."})
93 |     dataset_name: Union[List[str], str] = field(default=None)
94 |     data_multiple: Union[List[float], str] = field(default=None, metadata={"help": "Data multiplier for each dataset when mixed.
None means direct concat."}) 95 | lazy_preprocess: bool = False 96 | is_multimodal: bool = False 97 | image_folder: Union[List[str], str] = field(default=None) 98 | image_size: Optional[int] = field(default=448) 99 | image_aspect_ratio: str = 'square' 100 | task_type: Optional[str] = 'instruction' 101 | 102 | 103 | @dataclass 104 | class TrainingArguments(transformers.TrainingArguments): 105 | cache_dir: Optional[str] = field(default=None) 106 | optim: str = field(default="adamw_torch") 107 | remove_unused_columns: bool = field(default=False) 108 | tune_mm_in_mlp_adapter: bool = field(default=False) 109 | tune_mm_out_mlp_adapter: bool = field(default=False) 110 | freeze_mm_in_mlp_adapter: bool = field(default=False) 111 | freeze_mm_out_mlp_adapter: bool = field(default=False) 112 | mpt_attn_impl: Optional[str] = field(default="triton") 113 | model_max_length: int = field( 114 | default=512, 115 | metadata={ 116 | "help": 117 | "Maximum sequence length. Sequences will be right padded (and possibly truncated)." 118 | }, 119 | ) 120 | double_quant: bool = field( 121 | default=True, 122 | metadata={"help": "Compress the quantization statistics through double quantization."} 123 | ) 124 | quant_type: str = field( 125 | default="nf4", 126 | metadata={"help": "Quantization data type to use. Should be one of `fp4` or `nf4`."} 127 | ) 128 | bits: int = field( 129 | default=16, 130 | metadata={"help": "How many bits to use."} 131 | ) 132 | lora_enable: bool = False 133 | lora_r: int = 64 134 | lora_alpha: int = 16 135 | lora_dropout: float = 0.05 136 | lora_weight_path: str = "" 137 | lora_bias: str = "none" 138 | mm_in_projector_lr: Optional[float] = None 139 | mm_out_projector_lr: Optional[float] = None 140 | group_by_modality_length: bool = field(default=False) 141 | 142 | 143 | -------------------------------------------------------------------------------- /src/utils.py: -------------------------------------------------------------------------------- 1 | import datetime 2 | import logging 3 | import logging.handlers 4 | import os 5 | import sys 6 | import math 7 | import random 8 | 9 | import requests 10 | 11 | from src.constants import LOGDIR 12 | 13 | server_error_msg = "**NETWORK ERROR DUE TO HIGH TRAFFIC. PLEASE REGENERATE OR REFRESH THIS PAGE.**" 14 | moderation_msg = "YOUR INPUT VIOLATES OUR CONTENT MODERATION GUIDELINES. PLEASE TRY AGAIN." 
15 | 16 | handler = None 17 | 18 | 19 | def build_logger(logger_name, logger_filename): 20 | global handler 21 | 22 | formatter = logging.Formatter( 23 | fmt="%(asctime)s | %(levelname)s | %(name)s | %(message)s", 24 | datefmt="%Y-%m-%d %H:%M:%S", 25 | ) 26 | 27 | # Set the format of root handlers 28 | if not logging.getLogger().handlers: 29 | logging.basicConfig(level=logging.INFO) 30 | logging.getLogger().handlers[0].setFormatter(formatter) 31 | 32 | # Redirect stdout and stderr to loggers 33 | stdout_logger = logging.getLogger("stdout") 34 | stdout_logger.setLevel(logging.INFO) 35 | sl = StreamToLogger(stdout_logger, logging.INFO) 36 | sys.stdout = sl 37 | 38 | stderr_logger = logging.getLogger("stderr") 39 | stderr_logger.setLevel(logging.ERROR) 40 | sl = StreamToLogger(stderr_logger, logging.ERROR) 41 | sys.stderr = sl 42 | 43 | # Get logger 44 | logger = logging.getLogger(logger_name) 45 | logger.setLevel(logging.INFO) 46 | 47 | # Add a file handler for all loggers 48 | if handler is None: 49 | os.makedirs(LOGDIR, exist_ok=True) 50 | filename = os.path.join(LOGDIR, logger_filename) 51 | handler = logging.handlers.TimedRotatingFileHandler( 52 | filename, when='D', utc=True, encoding='UTF-8') 53 | handler.setFormatter(formatter) 54 | 55 | for name, item in logging.root.manager.loggerDict.items(): 56 | if isinstance(item, logging.Logger): 57 | item.addHandler(handler) 58 | 59 | return logger 60 | 61 | 62 | class StreamToLogger(object): 63 | """ 64 | Fake file-like stream object that redirects writes to a logger instance. 65 | """ 66 | def __init__(self, logger, log_level=logging.INFO): 67 | self.terminal = sys.stdout 68 | self.logger = logger 69 | self.log_level = log_level 70 | self.linebuf = '' 71 | 72 | def __getattr__(self, attr): 73 | return getattr(self.terminal, attr) 74 | 75 | def write(self, buf): 76 | temp_linebuf = self.linebuf + buf 77 | self.linebuf = '' 78 | for line in temp_linebuf.splitlines(True): 79 | # From the io.TextIOWrapper docs: 80 | # On output, if newline is None, any '\n' characters written 81 | # are translated to the system default line separator. 82 | # By default sys.stdout.write() expects '\n' newlines and then 83 | # translates them so this is still cross platform. 84 | if line[-1] == '\n': 85 | self.logger.log(self.log_level, line.rstrip()) 86 | else: 87 | self.linebuf += line 88 | 89 | def flush(self): 90 | if self.linebuf != '': 91 | self.logger.log(self.log_level, self.linebuf.rstrip()) 92 | self.linebuf = '' 93 | 94 | 95 | def disable_torch_init(): 96 | """ 97 | Disable the redundant torch default initialization to accelerate model creation. 98 | """ 99 | import torch 100 | setattr(torch.nn.Linear, "reset_parameters", lambda self: None) 101 | setattr(torch.nn.LayerNorm, "reset_parameters", lambda self: None) 102 | 103 | 104 | def violates_moderation(text): 105 | """ 106 | Check whether the text violates OpenAI moderation API. 
107 |     """
108 |     url = "https://api.openai.com/v1/moderations"
109 |     headers = {"Content-Type": "application/json",
110 |                "Authorization": "Bearer " + os.environ["OPENAI_API_KEY"]}
111 |     text = text.replace("\n", "")
112 |     # Let requests serialize the payload; manual string concatenation breaks on quotes in `text`.
113 |     data = {"input": text}
114 |     try:
115 |         ret = requests.post(url, headers=headers, json=data, timeout=5)
116 |         flagged = ret.json()["results"][0]["flagged"]
117 |     except requests.exceptions.RequestException:
118 |         flagged = False
119 |     except KeyError:
120 |         flagged = False
121 | 
122 |     return flagged
123 | 
124 | 
125 | def pretty_print_semaphore(semaphore):
126 |     if semaphore is None:
127 |         return "None"
128 |     return f"Semaphore(value={semaphore._value}, locked={semaphore.locked()})"
129 | 
130 | 
131 | 
132 | def extend_list(original_list, multiplier):
133 |     # Calculate how many elements to replicate and how many to select randomly
134 |     replicate_elements = math.floor(multiplier)
135 |     random_elements = multiplier - replicate_elements
136 | 
137 |     # Replicate the list
138 |     replicated_list = original_list * replicate_elements
139 | 
140 |     # Calculate how many elements to randomly select
141 |     select_elements = math.ceil(len(original_list) * random_elements)
142 | 
143 |     # Randomly select elements and append to the replicated list
144 |     for _ in range(select_elements):
145 |         random_element = random.choice(original_list)
146 |         replicated_list.append(random_element)
147 | 
148 |     return replicated_list
--------------------------------------------------------------------------------
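
A minimal usage sketch for `extend_list` in src/utils.py (the captions and the 2.5 multiplier are made-up example values, not part of the repository): the list is tiled floor(multiplier) times and then topped up with ceil(len(original_list) * fractional_part) randomly chosen items, which is presumably how the `data_multiple` field of `DataArguments` weights each dataset before mixing.

from src.utils import extend_list

# 2 full copies of the 4 captions (8 items) plus ceil(4 * 0.5) = 2 random extras -> 10 items
captions = ["a photo of a cat", "a red bus", "two dogs playing", "a city at night"]
mixed = extend_list(captions, 2.5)
assert len(mixed) == 10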