├── LICENSE
├── README.md
├── assets
├── framework.jpeg
├── intro.jpeg
└── setok.png
├── cog.yaml
├── pyproject.toml
├── scripts
├── extract_mm_projector.py
├── finetune.sh
├── merge_lora_weights.py
├── pretrain_mm_proj.sh
├── train_setok.sh
├── zero2.json
├── zero3.json
└── zero3_offload.json
└── src
├── __init__.py
├── constants.py
├── conversation.py
├── data_preprocess.py
├── dataset
├── __init__.py
├── base_dataset.py
├── dataset_utils.py
├── editDataset.py
├── instructDataset.py
├── pairDataset.py
└── vqa.py
├── mm_utils.py
├── model
├── __init__.py
├── apply_delta.py
├── builder.py
├── consolidate.py
├── diffusion
│ ├── __init__.py
│ ├── diffusion_utils.py
│ ├── gaussian_diffusion.py
│ └── respace.py
├── language_model
│ └── setokim_llama.py
├── loss
│ ├── __init__.py
│ ├── diffloss.py
│ ├── discriminator.py
│ ├── mse.py
│ ├── multilabel_constrastive.py
│ ├── perceptual.py
│ └── segmentation.py
├── make_delta.py
├── multimodal_encoder
│ ├── builder.py
│ ├── clip_encoder.py
│ ├── eva_encoder.py
│ ├── openclip_encoder.py
│ └── openclip_processor.py
├── multimodal_generator
│ └── builder.py
├── multimodal_projector
│ └── builder.py
├── setok
│ ├── __init__.py
│ ├── clip_encoder.py
│ ├── detokenizer.py
│ ├── model.py
│ ├── module.py
│ ├── tokenizer.py
│ └── utils.py
├── setokim_arch.py
└── utils.py
├── train
├── llama_flash_attn_monkey_patch.py
├── llama_xformers_attn_monkey_patch.py
├── setok_trainer.py
├── setokim_trainer.py
├── train_mem.py
├── train_setok.py
├── train_setokim.py
├── train_xformers.py
└── training_utils.py
└── utils.py
/LICENSE:
--------------------------------------------------------------------------------
1 | Apache License
2 | Version 2.0, January 2004
3 | http://www.apache.org/licenses/
4 |
5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
6 |
7 | 1. Definitions.
8 |
9 | "License" shall mean the terms and conditions for use, reproduction,
10 | and distribution as defined by Sections 1 through 9 of this document.
11 |
12 | "Licensor" shall mean the copyright owner or entity authorized by
13 | the copyright owner that is granting the License.
14 |
15 | "Legal Entity" shall mean the union of the acting entity and all
16 | other entities that control, are controlled by, or are under common
17 | control with that entity. For the purposes of this definition,
18 | "control" means (i) the power, direct or indirect, to cause the
19 | direction or management of such entity, whether by contract or
20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the
21 | outstanding shares, or (iii) beneficial ownership of such entity.
22 |
23 | "You" (or "Your") shall mean an individual or Legal Entity
24 | exercising permissions granted by this License.
25 |
26 | "Source" form shall mean the preferred form for making modifications,
27 | including but not limited to software source code, documentation
28 | source, and configuration files.
29 |
30 | "Object" form shall mean any form resulting from mechanical
31 | transformation or translation of a Source form, including but
32 | not limited to compiled object code, generated documentation,
33 | and conversions to other media types.
34 |
35 | "Work" shall mean the work of authorship, whether in Source or
36 | Object form, made available under the License, as indicated by a
37 | copyright notice that is included in or attached to the work
38 | (an example is provided in the Appendix below).
39 |
40 | "Derivative Works" shall mean any work, whether in Source or Object
41 | form, that is based on (or derived from) the Work and for which the
42 | editorial revisions, annotations, elaborations, or other modifications
43 | represent, as a whole, an original work of authorship. For the purposes
44 | of this License, Derivative Works shall not include works that remain
45 | separable from, or merely link (or bind by name) to the interfaces of,
46 | the Work and Derivative Works thereof.
47 |
48 | "Contribution" shall mean any work of authorship, including
49 | the original version of the Work and any modifications or additions
50 | to that Work or Derivative Works thereof, that is intentionally
51 | submitted to Licensor for inclusion in the Work by the copyright owner
52 | or by an individual or Legal Entity authorized to submit on behalf of
53 | the copyright owner. For the purposes of this definition, "submitted"
54 | means any form of electronic, verbal, or written communication sent
55 | to the Licensor or its representatives, including but not limited to
56 | communication on electronic mailing lists, source code control systems,
57 | and issue tracking systems that are managed by, or on behalf of, the
58 | Licensor for the purpose of discussing and improving the Work, but
59 | excluding communication that is conspicuously marked or otherwise
60 | designated in writing by the copyright owner as "Not a Contribution."
61 |
62 | "Contributor" shall mean Licensor and any individual or Legal Entity
63 | on behalf of whom a Contribution has been received by Licensor and
64 | subsequently incorporated within the Work.
65 |
66 | 2. Grant of Copyright License. Subject to the terms and conditions of
67 | this License, each Contributor hereby grants to You a perpetual,
68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
69 | copyright license to reproduce, prepare Derivative Works of,
70 | publicly display, publicly perform, sublicense, and distribute the
71 | Work and such Derivative Works in Source or Object form.
72 |
73 | 3. Grant of Patent License. Subject to the terms and conditions of
74 | this License, each Contributor hereby grants to You a perpetual,
75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
76 | (except as stated in this section) patent license to make, have made,
77 | use, offer to sell, sell, import, and otherwise transfer the Work,
78 | where such license applies only to those patent claims licensable
79 | by such Contributor that are necessarily infringed by their
80 | Contribution(s) alone or by combination of their Contribution(s)
81 | with the Work to which such Contribution(s) was submitted. If You
82 | institute patent litigation against any entity (including a
83 | cross-claim or counterclaim in a lawsuit) alleging that the Work
84 | or a Contribution incorporated within the Work constitutes direct
85 | or contributory patent infringement, then any patent licenses
86 | granted to You under this License for that Work shall terminate
87 | as of the date such litigation is filed.
88 |
89 | 4. Redistribution. You may reproduce and distribute copies of the
90 | Work or Derivative Works thereof in any medium, with or without
91 | modifications, and in Source or Object form, provided that You
92 | meet the following conditions:
93 |
94 | (a) You must give any other recipients of the Work or
95 | Derivative Works a copy of this License; and
96 |
97 | (b) You must cause any modified files to carry prominent notices
98 | stating that You changed the files; and
99 |
100 | (c) You must retain, in the Source form of any Derivative Works
101 | that You distribute, all copyright, patent, trademark, and
102 | attribution notices from the Source form of the Work,
103 | excluding those notices that do not pertain to any part of
104 | the Derivative Works; and
105 |
106 | (d) If the Work includes a "NOTICE" text file as part of its
107 | distribution, then any Derivative Works that You distribute must
108 | include a readable copy of the attribution notices contained
109 | within such NOTICE file, excluding those notices that do not
110 | pertain to any part of the Derivative Works, in at least one
111 | of the following places: within a NOTICE text file distributed
112 | as part of the Derivative Works; within the Source form or
113 | documentation, if provided along with the Derivative Works; or,
114 | within a display generated by the Derivative Works, if and
115 | wherever such third-party notices normally appear. The contents
116 | of the NOTICE file are for informational purposes only and
117 | do not modify the License. You may add Your own attribution
118 | notices within Derivative Works that You distribute, alongside
119 | or as an addendum to the NOTICE text from the Work, provided
120 | that such additional attribution notices cannot be construed
121 | as modifying the License.
122 |
123 | You may add Your own copyright statement to Your modifications and
124 | may provide additional or different license terms and conditions
125 | for use, reproduction, or distribution of Your modifications, or
126 | for any such Derivative Works as a whole, provided Your use,
127 | reproduction, and distribution of the Work otherwise complies with
128 | the conditions stated in this License.
129 |
130 | 5. Submission of Contributions. Unless You explicitly state otherwise,
131 | any Contribution intentionally submitted for inclusion in the Work
132 | by You to the Licensor shall be under the terms and conditions of
133 | this License, without any additional terms or conditions.
134 | Notwithstanding the above, nothing herein shall supersede or modify
135 | the terms of any separate license agreement you may have executed
136 | with Licensor regarding such Contributions.
137 |
138 | 6. Trademarks. This License does not grant permission to use the trade
139 | names, trademarks, service marks, or product names of the Licensor,
140 | except as required for reasonable and customary use in describing the
141 | origin of the Work and reproducing the content of the NOTICE file.
142 |
143 | 7. Disclaimer of Warranty. Unless required by applicable law or
144 | agreed to in writing, Licensor provides the Work (and each
145 | Contributor provides its Contributions) on an "AS IS" BASIS,
146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
147 | implied, including, without limitation, any warranties or conditions
148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
149 | PARTICULAR PURPOSE. You are solely responsible for determining the
150 | appropriateness of using or redistributing the Work and assume any
151 | risks associated with Your exercise of permissions under this License.
152 |
153 | 8. Limitation of Liability. In no event and under no legal theory,
154 | whether in tort (including negligence), contract, or otherwise,
155 | unless required by applicable law (such as deliberate and grossly
156 | negligent acts) or agreed to in writing, shall any Contributor be
157 | liable to You for damages, including any direct, indirect, special,
158 | incidental, or consequential damages of any character arising as a
159 | result of this License or out of the use or inability to use the
160 | Work (including but not limited to damages for loss of goodwill,
161 | work stoppage, computer failure or malfunction, or any and all
162 | other commercial damages or losses), even if such Contributor
163 | has been advised of the possibility of such damages.
164 |
165 | 9. Accepting Warranty or Additional Liability. While redistributing
166 | the Work or Derivative Works thereof, You may choose to offer,
167 | and charge a fee for, acceptance of support, warranty, indemnity,
168 | or other liability obligations and/or rights consistent with this
169 | License. However, in accepting such obligations, You may act only
170 | on Your own behalf and on Your sole responsibility, not on behalf
171 | of any other Contributor, and only if You agree to indemnify,
172 | defend, and hold each Contributor harmless for any liability
173 | incurred by, or claims asserted against, such Contributor by reason
174 | of your accepting any such warranty or additional liability.
175 |
176 | END OF TERMS AND CONDITIONS
177 |
178 | APPENDIX: How to apply the Apache License to your work.
179 |
180 | To apply the Apache License to your work, attach the following
181 | boilerplate notice, with the fields enclosed by brackets "[]"
182 | replaced with your own identifying information. (Don't include
183 | the brackets!) The text should be enclosed in the appropriate
184 | comment syntax for the file format. We also recommend that a
185 | file or class name and description of purpose be included on the
186 | same "printed page" as the copyright notice for easier
187 | identification within third-party archives.
188 |
189 | Copyright [yyyy] [name of copyright owner]
190 |
191 | Licensed under the Apache License, Version 2.0 (the "License");
192 | you may not use this file except in compliance with the License.
193 | You may obtain a copy of the License at
194 |
195 | http://www.apache.org/licenses/LICENSE-2.0
196 |
197 | Unless required by applicable law or agreed to in writing, software
198 | distributed under the License is distributed on an "AS IS" BASIS,
199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
200 | See the License for the specific language governing permissions and
201 | limitations under the License.
202 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 | # Towards Semantic Equivalence of Tokenization in Multimodal LLM
5 |
6 | Shengqiong Wu
7 | ·
8 | Hao Fei
9 | ·
10 | Xiangtai Li
11 | ·
12 | Jiayi Ji
13 | ·
14 |
15 | Hanwang Zhang
16 | ·
17 | Tat-seng Chua
18 | ·
19 | Shuicheng Yan
20 |
21 |
22 | National University of Singapore · Skywork AI, Singapore · Nanyang Technological University
23 |
24 |
25 |
26 | This work was done during an internship at Skywork AI. Hao Fei is the corresponding author.
27 |
28 |
29 |
30 |
31 |
32 |
33 |
34 |
35 |
36 |
37 | 
38 |
39 |
40 | ### Abstract
41 |
42 | Multimodal Large Language Models (MLLMs) have demonstrated exceptional capabilities in processing vision-language tasks. A crux of MLLMs lies in vision tokenization, which involves efficiently transforming input visual signals into the feature representations most beneficial for LLMs. However, existing vision tokenizers, which are essential for semantic alignment between vision and language, remain problematic: they aggressively fragment the visual input, corrupting visual semantic integrity. To address this, this paper proposes a novel dynamic `Semantic-Equivalent Vision Tokenizer` (**SeTok**), which groups visual features into semantic units via a dynamic clustering algorithm, flexibly determining the number of tokens based on image complexity. The resulting vision tokens effectively preserve semantic integrity and capture both low-frequency and high-frequency visual features.
43 | 
44 |
45 | The proposed MLLM (**Setokim**), equipped with SeTok, demonstrates significantly superior performance across various tasks, as evidenced by our experimental results.
46 |
47 | 
48 |
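To make the idea concrete, here is a minimal, illustrative sketch of content-adaptive token grouping (a simple cosine-similarity threshold over patch features). This is **not** the actual SeTok implementation (see [`src/model/setok/tokenizer.py`](./src/model/setok/tokenizer.py)); the real clustering procedure, features, and hyper-parameters may differ.

```python
# Illustrative sketch only: greedily group patch features into a variable number
# of "semantic" tokens; the threshold and update rule are assumptions, not SeTok's.
import torch
import torch.nn.functional as F


def group_patches(patch_feats: torch.Tensor, sim_threshold: float = 0.8) -> torch.Tensor:
    """Group [N, D] patch features into [K, D] cluster tokens; K adapts to image content."""
    centroids, counts = [], []
    for feat in F.normalize(patch_feats, dim=-1):
        if centroids:
            sims = torch.stack([F.cosine_similarity(feat, c, dim=0) for c in centroids])
            best = int(sims.argmax())
            if sims[best] >= sim_threshold:
                counts[best] += 1
                # running-mean update of the matched centroid
                centroids[best] = centroids[best] + (feat - centroids[best]) / counts[best]
                continue
        centroids.append(feat.clone())
        counts.append(1)
    return torch.stack(centroids)


# e.g., tokens = group_patches(torch.randn(576, 1024))  # more varied content -> more tokens
```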
49 |
50 |
51 |
52 |
53 |
54 |
55 |
56 | ## 📂 Getting Started
57 |
58 | First, prepare the code and set up the environment:
59 | ```
60 | # download the code
61 | git clone https://github.com/ChocoWu/SeTok.git
62 | cd SeTok
63 |
64 | # install
65 | conda create -n setok python=3.10 -y
66 | conda activate setok
67 | pip install --upgrade pip
68 | pip install -e .
69 | ```
70 |
71 | ### 📚 Preparing Data
72 |
73 | To begin, please prepare the dataset. All datasets should be placed under the [`data/`](./data) directory, with the following structure:
74 | ```
75 | data
76 | - ImageNet-1K
77 | - data.json
78 | - images
79 | - xxx.jpg
80 | - OpenImages
81 | - ALLaVA
82 | - GQA
83 | - data.json
84 | - images
85 | - xxx.jpg
86 | - OK-VQA
87 | - ...
88 | - InstructPix2Pix
89 | - Magicbrush
90 | ```
91 | For details on how each dataset is processed, please refer to the following scripts:
92 |
93 | - [`TextImagePairDataset`](./src/dataset/pairDataset.py)
94 | - [`EditingDataset`](./src/dataset/editDataset.py)
95 | - [`InstructionTuningDataset`](./src/dataset/instructDataset.py)
96 |
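As a rough illustration, an instruction-tuning entry in `data.json` follows a LLaVA-style conversation format. The field names below are assumptions inferred from `src/data_preprocess.py` and the dataset classes above, not an official schema:

```python
# Hypothetical example entry (field names are assumptions, not the official schema).
example_entry = {
    "id": "000000123",
    "image": "images/xxx.jpg",  # relative to the dataset's image folder
    "conversations": [
        {"from": "human", "value": "What is shown in this image?"},
        {"from": "gpt", "value": "A dog playing in the park."},
    ],
}
```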
97 |
98 | ### 🚀 Training
99 | Our training recipe involves three stages.
100 |
101 | - **Stage-1: Setok tokenizer training**. We use ImageNet-1K for reconstruction learning and OpenImages for both reconstruction and alignment learning.
102 | ```
103 | # In [train_mem.py], activate Setok training:
104 |
105 | from train_setok import train
106 | train(attn_implementation="flash_attention_2")
107 |
108 | # Set the hyper-parameters in [train_setok.sh].
109 | bash train_setok.sh
110 | ```
111 | Make sure the dataset paths are correctly set in your config file or environment variables.
112 |
113 |
114 | - **Stage-2: Multimodal Pretraining**. In this stage, we focus on enhancing the alignment between text and image. We employ large-scale multimodal data, including ImageNet-1K and a 28M text-image pair dataset, to train our model for conditional image generation and image captioning.
115 | ```
116 | # In [train_mem.py], activate Setokim training:
117 |
118 | from train_setokim import train
119 | train(attn_implementation="flash_attention_2")
120 |
121 | # Set the hyper-parameters in [pretrain_mm_proj.sh].
122 | bash pretrain_mm_proj.sh
123 |
124 | ```
125 |
126 | - **Stage-3: Instruction Tuning**. Building upon the pretrained weights, we further perform multimodal instruction tuning on public datasets covering
127 | multimodal instruction data, fine-grained visual QA, and more.
128 | ```
129 | # Set the hyper-parameters in [finetune.sh].
130 | bash finetune.sh
131 | ```
132 |
133 |
134 |
135 | ## ✨ Citation
136 |
137 | If you use **SeTok** in your project, please kindly cite:
138 |
139 | ```bibtex
140 | @article{wu2024towards,
141 | title={Towards Semantic Equivalence of Tokenization in Multimodal LLM},
142 | author={Wu, Shengqiong and Fei, Hao and Li, Xiangtai and Ji, Jiayi and Zhang, Hanwang and Chua, Tat-Seng and Yan, Shuicheng},
143 | publisher={ICLR},
144 | year={2025}
145 | }
146 | ```
147 |
148 | ## Acknowledgments
149 |
150 | This work is largely built upon [LLaVA](https://github.com/haotian-liu/LLaVA), [GroupViT](https://github.com/NVlabs/GroupViT), [MAR](https://github.com/LTH14/mar), and [Blip-2](https://github.com/salesforce/LAVIS).
151 | Thanks to all the authors for their great work.
--------------------------------------------------------------------------------
/assets/framework.jpeg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ChocoWu/SeTok/55cfc9fa5bdf28955f05bb627df3e300a5693afb/assets/framework.jpeg
--------------------------------------------------------------------------------
/assets/intro.jpeg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ChocoWu/SeTok/55cfc9fa5bdf28955f05bb627df3e300a5693afb/assets/intro.jpeg
--------------------------------------------------------------------------------
/assets/setok.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ChocoWu/SeTok/55cfc9fa5bdf28955f05bb627df3e300a5693afb/assets/setok.png
--------------------------------------------------------------------------------
/cog.yaml:
--------------------------------------------------------------------------------
1 | # Configuration for Cog ⚙️
2 | # Reference: https://github.com/replicate/cog/blob/main/docs/yaml.md
3 |
4 | build:
5 | gpu: true
6 |
7 | python_version: "3.8"
8 |
9 | python_packages:
10 | - "torch==2.2.1"
11 | - "accelerate==0.27.2"
12 | - "bitsandbytes==0.41.0"
13 | - "deepspeed==0.13.6"
14 | - "einops-exts==0.0.4"
15 | - "einops==0.6.1"
16 | - "gradio==3.35.2"
17 | - "gradio_client==0.2.9"
18 | - "httpx==0.24.0"
19 | - "markdown2==2.4.10"
20 | - "numpy==1.24.4"
21 | - "peft==0.13.2"
22 | - "scikit-learn==1.3.2"
23 | - "sentencepiece==0.1.99"
24 | - "shortuuid==1.0.11"
25 | - "timm==0.9.16"
26 | - "tokenizers==0.20.3"
27 | - "torchvision==0.17.1"
28 | - "transformers==4.46.3"
29 | - "diffusers==0.27.2"
30 | - "scipy==1.10.1"
31 | - "wandb==0.15.12"
32 | - "wavedrom==2.0.3.post3"
33 | - "Pygments==2.16.1"
34 | run:
35 | - curl -o /usr/local/bin/pget -L "https://github.com/replicate/pget/releases/download/v0.0.3/pget" && chmod +x /usr/local/bin/pget
36 |
37 | # predict.py defines how predictions are run on your model
38 | predict: "predict.py:Predictor"
39 |
--------------------------------------------------------------------------------
/pyproject.toml:
--------------------------------------------------------------------------------
1 | [build-system]
2 | requires = ["setuptools>=61.0"]
3 | build-backend = "setuptools.build_meta"
4 |
5 | [project]
6 | name = "setok"
7 | version = "0.0.0"
8 | description = "Towards Semantic Equivalence of Tokenization in Multimodal LLM"
9 | readme = "README.md"
10 | requires-python = ">=3.8"
11 | classifiers = [
12 | "Programming Language :: Python :: 3",
13 | "License :: OSI Approved :: Apache Software License",
14 | ]
15 | dependencies = [
16 | "torch==2.2.1", "torchvision==0.17.1",
17 | "diffdist==0.1", "diffusers==0.27.2", "scipy==1.10.1",
18 | "transformers==4.46.3", "tokenizers==0.20.3", "sentencepiece==0.1.99", "shortuuid",
19 | "accelerate==0.27.2", "peft", "bitsandbytes",
20 | "pydantic", "markdown2[all]", "numpy", "scikit-learn==1.3.2",
21 | "requests", "httpx==0.24.0", "uvicorn", "fastapi",
22 | "einops==0.6.1", "einops-exts==0.0.4", "timm==0.9.16",
23 | ]
24 |
25 | [project.optional-dependencies]
26 | train = ["deepspeed==0.13.6", "ninja", "wandb"]
27 | build = ["build", "twine"]
28 |
29 | [project.urls]
30 | "Homepage" = "https://llava-vl.github.io"
31 | "Bug Tracker" = "https://github.com/haotian-liu/LLaVA/issues"
32 |
33 | [tool.setuptools.packages.find]
34 | exclude = ["assets*", "scripts*", "tests*"]
35 |
36 | [tool.wheel]
37 | exclude = ["assets*", "scripts*", "tests*"]
38 |
--------------------------------------------------------------------------------
/scripts/extract_mm_projector.py:
--------------------------------------------------------------------------------
1 | """
2 | This is just a utility that I use to extract the projector for quantized models.
3 | It is NOT necessary at all to train, or run inference/serve demos.
4 | Use this script ONLY if you fully understand its implications.
5 | """
6 |
7 |
8 | import os
9 | import argparse
10 | import torch
11 | import json
12 | from collections import defaultdict
13 |
14 |
15 | def parse_args():
16 | parser = argparse.ArgumentParser(description='Extract MMProjector weights')
17 | parser.add_argument('--model-path', type=str, help='model folder')
18 | parser.add_argument('--output', type=str, help='output file')
19 | args = parser.parse_args()
20 | return args
21 |
22 |
23 | if __name__ == '__main__':
24 | args = parse_args()
25 |
26 | keys_to_match = ['mm_in_projector', 'mm_out_projector']
27 | ckpt_to_key = defaultdict(list)
28 | try:
29 | model_indices = json.load(open(os.path.join(args.model_path, 'pytorch_model.bin.index.json')))
30 | for k, v in model_indices['weight_map'].items():
31 | if any(key_match in k for key_match in keys_to_match):
32 | ckpt_to_key[v].append(k)
33 | except FileNotFoundError:
34 | # Smaller models or model checkpoints saved by DeepSpeed.
35 | v = 'pytorch_model.bin'
36 | for k in torch.load(os.path.join(args.model_path, v), map_location='cpu').keys():
37 | if any(key_match in k for key_match in keys_to_match):
38 | ckpt_to_key[v].append(k)
39 |
40 | loaded_weights = {}
41 |
42 | for ckpt_name, weight_keys in ckpt_to_key.items():
43 | ckpt = torch.load(os.path.join(args.model_path, ckpt_name), map_location='cpu')
44 | for k in weight_keys:
45 | loaded_weights[k] = ckpt[k]
46 |
47 | torch.save(loaded_weights, args.output)
48 |
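# Example usage (paths are placeholders):
#   python scripts/extract_mm_projector.py --model-path ./checkpoints/setokim-7b --output mm_projector.bin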
--------------------------------------------------------------------------------
/scripts/finetune.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 |
4 | IMAGE_FOLDER=(
5 | "./data/ALLaVA-4V/"
6 | "./data/llava-150k/images"
7 | "./data/okvqa/train2014"
8 | "./data/okvqa/train2014"
9 | "./data/okvqa/train2014"
10 | "./data/gqa/images"
11 | )
12 |
13 | DATA_PATH=(
14 | "./data/ALLaVA-4V/allava_laion/ALLaVA-Instruct-LAION-4V_preprocessed.json"
15 | "./data/llava-150k/pandagpt4_visual_instruction_data.json"
16 | "./data/vqa2"
17 | "./data/okvqa"
18 | "./data/okvqa/aokvqa_v1p0_train.json"
19 | "./data/gqa/train_balanced_questions.json"
20 | ""
21 | )
22 |
23 | DATA_MULTIPLE=(
24 | 1
25 | 1
26 | 1
27 | 1
28 | 1
29 | 1
30 | )
31 |
32 | DATASET_NAME=(
33 | "ALLaVA-Instruct-LAION-4V"
34 | "LLaVA150K"
35 | "VQAv2"
36 | "OKVQA"
37 | "AOKVQA"
38 | "GQA"
39 | )
40 |
41 |
42 | IMAGE_FOLDER="${IMAGE_FOLDER[@]}"
43 | DATA_PATH="${DATA_PATH[@]}"
44 | DATASET_NAME="${DATASET_NAME[*]}"
45 | DATA_MULTIPLE="${DATA_MULTIPLE[@]}"
46 |
47 |
48 | deepspeed train_mem.py \
49 | --lora_enable True --lora_r 128 --lora_alpha 256 --mm_projector_lr 2e-5 \
50 | --deepspeed ./scripts/zero2.json \
51 | --model_name_or_path ./pretrained_ckpt/vicuna-7b-v1.5 \
52 | --version v1 \
53 | --data_path $DATA_PATH \
54 | --dataset_name $DATASET_NAME \
55 | --image_folder $IMAGE_FOLDER \
56 | --data_multiple $DATA_MULTIPLE \
57 | --vision_tokenizer setok \
58 | --vision_tower ./pretrained_ckpt/siglip-so400m-patch14-384 \
59 | --pretrain_vision_tokenizer ./checkpoints/ \
60 | --pretrain_vision_detokenizer ./checkpoints/ \
61 | --mm_in_projector_type mlp2x_gelu \
62 | --tune_mm_in_mlp_adapter False \
63 | --pretrain_mm_in_mlp_adapter mm_projector.bin \
64 | --mm_out_projector_type mlp2x_gelu \
65 | --tune_mm_out_mlp_adapter True \
66 | --pretrain_mm_out_mlp_adapter mm_projector.bin \
67 | --mm_vision_select_layer -1 \
68 | --mm_use_im_start_end True \
69 | --mm_use_im_patch_token False \
70 | --feature_mapper_path_or_name ./pretrained_ckpt/bert-base-uncased \
71 | --bf16 True \
72 | --output_dir ./checkpoints/ \
73 | --num_train_epochs 1 \
74 | --per_device_train_batch_size 32 \
75 | --per_device_eval_batch_size 4 \
76 | --gradient_accumulation_steps 1 \
77 | --evaluation_strategy "no" \
78 | --save_strategy "steps" \
79 | --save_steps 4000 \
80 | --save_total_limit 1 \
81 | --learning_rate 1e-3 \
82 | --weight_decay 0. \
83 | --warmup_ratio 0.03 \
84 | --lr_scheduler_type "cosine" \
85 | --logging_steps 1 \
86 | --tf32 True \
87 | --model_max_length 2048 \
88 | --gradient_checkpointing True \
89 | --dataloader_num_workers 4 \
90 | --lazy_preprocess True \
91 | --report_to tensorboard
92 |
--------------------------------------------------------------------------------
/scripts/merge_lora_weights.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | from src.model.builder import load_pretrained_model
3 | from src.mm_utils import get_model_name_from_path
4 |
5 |
6 | def merge_lora(args):
7 | model_name = get_model_name_from_path(args.model_path)
8 | tokenizer, model, image_processor, context_len = load_pretrained_model(args.model_path, args.model_base, model_name, device_map='cpu')
9 |
10 | model.save_pretrained(args.save_model_path)
11 | tokenizer.save_pretrained(args.save_model_path)
12 |
13 |
14 | if __name__ == "__main__":
15 | parser = argparse.ArgumentParser()
16 | parser.add_argument("--model-path", type=str, required=True)
17 | parser.add_argument("--model-base", type=str, required=True)
18 | parser.add_argument("--save-model-path", type=str, required=True)
19 |
20 | args = parser.parse_args()
21 |
22 | merge_lora(args)
23 |
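# Example usage (paths are placeholders):
#   python scripts/merge_lora_weights.py --model-path ./checkpoints/setokim-7b-lora \
#       --model-base ./pretrained_ckpt/vicuna-7b-v1.5 --save-model-path ./checkpoints/setokim-7b-merged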
--------------------------------------------------------------------------------
/scripts/pretrain_mm_proj.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 |
4 | IMAGE_FOLDER=(
5 | "./data/LLaVA-Instruct-150K/images"
6 | )
7 |
8 | DATA_PATH=(
9 | "./data/LLaVA-Instruct-150K/llava_instruct_150k.json"
10 | )
11 |
12 | DATA_MULTIPLE=(
13 | 1
14 | )
15 |
16 | DATASET_NAME=(
17 | "LLaVA150K"
18 | )
19 |
20 |
21 | IMAGE_FOLDER="${IMAGE_FOLDER[@]}"
22 | DATA_PATH="${DATA_PATH[@]}"
23 | DATASET_NAME="${DATASET_NAME[*]}"
24 | DATA_MULTIPLE="${DATA_MULTIPLE[@]}"
25 |
26 |
27 | accelerate launch train_mem.py \
28 | --deepspeed ./scripts/zero2.json \
29 | --model_name_or_path ./pretrained_ckpt/vicuna-7b-v1.5 \
30 | --version plain \
31 | --data_path $DATA_PATH \
32 | --dataset_name $DATASET_NAME \
33 | --image_folder $IMAGE_FOLDER \
34 | --data_multiple $DATA_MULTIPLE \
35 | --vision_tokenizer setok \
36 | --vision_tower ./pretrained_ckpt/siglip-so400m-patch14-384 \
37 | --pretrain_vision_tokenizer ./checkpoints/ \
38 | --pretrain_vision_detokenizer ./checkpoints/ \
39 | --mm_in_projector_type mlp2x_gelu \
40 | --tune_mm_in_mlp_adapter True \
41 | --mm_out_projector_type mlp2x_gelu \
42 | --tune_mm_out_mlp_adapter True \
43 | --mm_vision_select_layer -1 \
44 | --mm_use_im_start_end True \
45 | --mm_use_im_patch_token False \
46 | --feature_mapper_path_or_name ./pretrained_ckpt/bert-base-uncased \
47 | --bf16 True \
48 | --output_dir ./checkpoints/ \
49 | --num_train_epochs 1 \
50 | --per_device_train_batch_size 32 \
51 | --per_device_eval_batch_size 4 \
52 | --gradient_accumulation_steps 1 \
53 | --evaluation_strategy "no" \
54 | --save_strategy "steps" \
55 | --save_steps 4000 \
56 | --save_total_limit 1 \
57 | --learning_rate 1e-3 \
58 | --weight_decay 0. \
59 | --warmup_ratio 0.03 \
60 | --lr_scheduler_type "cosine" \
61 | --logging_steps 1 \
62 | --tf32 True \
63 | --model_max_length 2048 \
64 | --gradient_checkpointing True \
65 | --dataloader_num_workers 4 \
66 | --lazy_preprocess True \
67 | --report_to tensorboard
68 |
--------------------------------------------------------------------------------
/scripts/train_setok.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 |
4 |
5 |
6 | IMAGE_FOLDER=(
7 | "./data/cc3m/images" # 5240031
8 | )
9 |
10 | DATA_PATH=(
11 | "./data/cc3m/cc3m.json"
12 | )
13 |
14 | DATASET_NAME=(
15 | "cc3m"
16 | )
17 |
18 | DATA_MULTIPLE=(
19 | 1
20 | )
21 |
22 | echo $IMAGE_FOLDER
23 | echo $DATA_PATH
24 | echo $DATASET_NAME
25 |
26 | IMAGE_FOLDER="${IMAGE_FOLDER[*]}"
27 | DATA_PATH="${DATA_PATH[*]}"
28 | DATASET_NAME="${DATASET_NAME[*]}"
29 | DATA_MULTIPLE="${DATA_MULTIPLE[*]}"
30 |
31 |
32 | echo $IMAGE_FOLDER
33 | echo $DATA_PATH
34 | echo $DATASET_NAME
35 |
36 |
37 |
38 | deepspeed train_mem.py \
39 | --deepspeed ./scripts/zero2.json \
40 | --lora_enable False \
41 | --data_path $DATA_PATH \
42 | --image_folder $IMAGE_FOLDER \
43 | --data_multiple $DATA_MULTIPLE \
44 | --dataset_name $DATASET_NAME \
45 | --image_size $IMAGE_SIZE \
46 | --vision_tower ./pretrained_ckpt/siglip-so400m-patch14-384 \
47 | --feature_mapper_path_or_name ./pretrained_ckpt/bert-base-uncased \
48 | --bf16 False \
49 | --output_dir ./checkpoints/ \
50 | --num_train_epochs 1 \
51 | --per_device_train_batch_size 24 \
52 | --per_device_eval_batch_size 4 \
53 | --gradient_accumulation_steps 1 \
54 | --evaluation_strategy "no" \
55 | --save_strategy "steps" \
56 | --save_steps 24000 \
57 | --save_total_limit 1 \
58 | --learning_rate 1e-3 \
59 | --weight_decay 0. \
60 | --warmup_ratio 0.03 \
61 | --lr_scheduler_type "cosine" \
62 | --logging_steps 1 \
63 | --tf32 True \
64 | --fp16 False \
65 | --model_max_length 77 \
66 | --gradient_checkpointing True \
67 | --dataloader_num_workers 4 \
68 | --lazy_preprocess True \
69 | --report_to tensorboard \
70 |
71 |
72 |
73 |
--------------------------------------------------------------------------------
/scripts/zero2.json:
--------------------------------------------------------------------------------
1 | {
2 | "fp16": {
3 | "enabled": "auto",
4 | "loss_scale": 0,
5 | "loss_scale_window": 1000,
6 | "initial_scale_power": 16,
7 | "hysteresis": 2,
8 | "min_loss_scale": 1
9 | },
10 | "bf16": {
11 | "enabled": "auto"
12 | },
13 | "train_micro_batch_size_per_gpu": "auto",
14 | "train_batch_size": "auto",
15 | "gradient_accumulation_steps": "auto",
16 | "zero_optimization": {
17 | "stage": 2,
18 | "overlap_comm": true,
19 | "contiguous_gradients": true,
20 | "sub_group_size": 1e9,
21 | "reduce_bucket_size": "auto"
22 | }
23 | }
--------------------------------------------------------------------------------
/scripts/zero3.json:
--------------------------------------------------------------------------------
1 | {
2 | "fp16": {
3 | "enabled": "auto",
4 | "loss_scale": 0,
5 | "loss_scale_window": 1000,
6 | "initial_scale_power": 16,
7 | "hysteresis": 2,
8 | "min_loss_scale": 1
9 | },
10 | "bf16": {
11 | "enabled": "auto"
12 | },
13 | "train_micro_batch_size_per_gpu": "auto",
14 | "train_batch_size": "auto",
15 | "gradient_accumulation_steps": "auto",
16 | "zero_optimization": {
17 | "stage": 3,
18 | "overlap_comm": true,
19 | "contiguous_gradients": true,
20 | "sub_group_size": 1e9,
21 | "reduce_bucket_size": "auto",
22 | "stage3_prefetch_bucket_size": "auto",
23 | "stage3_param_persistence_threshold": "auto",
24 | "stage3_max_live_parameters": 1e9,
25 | "stage3_max_reuse_distance": 1e9,
26 | "stage3_gather_16bit_weights_on_model_save": true
27 | }
28 | }
--------------------------------------------------------------------------------
/scripts/zero3_offload.json:
--------------------------------------------------------------------------------
1 | {
2 | "fp16": {
3 | "enabled": "auto",
4 | "loss_scale": 0,
5 | "loss_scale_window": 1000,
6 | "initial_scale_power": 16,
7 | "hysteresis": 2,
8 | "min_loss_scale": 1
9 | },
10 | "bf16": {
11 | "enabled": "auto"
12 | },
13 | "optimizer": {
14 | "type": "AdamW",
15 | "params": {
16 | "lr": "auto",
17 | "betas": "auto",
18 | "eps": "auto",
19 | "weight_decay": "auto"
20 | }
21 | },
22 | "scheduler": {
23 | "type": "WarmupLR",
24 | "params": {
25 | "warmup_min_lr": "auto",
26 | "warmup_max_lr": "auto",
27 | "warmup_num_steps": "auto"
28 | }
29 | },
30 | "zero_optimization": {
31 | "stage": 3,
32 | "offload_optimizer": {
33 | "device": "cpu",
34 | "pin_memory": true
35 | },
36 | "offload_param": {
37 | "device": "cpu",
38 | "pin_memory": true
39 | },
40 | "overlap_comm": true,
41 | "contiguous_gradients": true,
42 | "sub_group_size": 1e9,
43 | "reduce_bucket_size": "auto",
44 | "stage3_prefetch_bucket_size": "auto",
45 | "stage3_param_persistence_threshold": "auto",
46 | "stage3_max_live_parameters": 1e9,
47 | "stage3_max_reuse_distance": 1e9,
48 | "gather_16bit_weights_on_model_save": true
49 | },
50 | "gradient_accumulation_steps": "auto",
51 | "gradient_clipping": "auto",
52 | "train_batch_size": "auto",
53 | "train_micro_batch_size_per_gpu": "auto",
54 | "steps_per_print": 1e5,
55 | "wall_clock_breakdown": false
56 | }
--------------------------------------------------------------------------------
/src/__init__.py:
--------------------------------------------------------------------------------
1 | from .model.language_model.setokim_llama import SetokimLlamaForCausalLM
2 |
--------------------------------------------------------------------------------
/src/constants.py:
--------------------------------------------------------------------------------
1 | CONTROLLER_HEART_BEAT_EXPIRATION = 30
2 | WORKER_HEART_BEAT_INTERVAL = 15
3 |
4 | LOGDIR = "."
5 |
6 | # Model Constants
7 | IGNORE_INDEX = -100
8 | IMAGE_TOKEN_INDEX = -200
9 | DEFAULT_IMAGE_TOKEN = "<image>"
10 | DEFAULT_IMAGE_PATCH_TOKEN = "<im_patch>"
11 | DEFAULT_IM_START_TOKEN = "<im_start>"
12 | DEFAULT_IM_END_TOKEN = "<im_end>"
13 | IMAGE_PLACEHOLDER = "<image-placeholder>"
14 |
15 | TARGET_TOKEN_INDEX = -300
16 | DEFAULT_TARGET_TOKEN = "<target>"
--------------------------------------------------------------------------------
/src/data_preprocess.py:
--------------------------------------------------------------------------------
1 | import os
2 | import json
3 | from concurrent.futures import ThreadPoolExecutor
4 | from joblib import Parallel, delayed
5 |
6 | from tqdm import tqdm
7 |
8 |
9 | def check_image_exists(data, image_folder):
10 | """ Helper function to check if an image exists and return the data accordingly. """
11 | if os.path.exists(os.path.join(image_folder, data['image'])):
12 | return (True, data)
13 | else:
14 | return (False, data['image'])
15 |
16 |
17 | def preprocess(params):
18 | data_path, image_folder, save_path = params
19 |
20 | # Load data from JSON file
21 | with open(data_path, 'r') as f:
22 | datas = json.load(f)
23 |
24 |     # Check whether each image exists, in parallel (via joblib)
25 | new_datas = []
26 | unexisted_images = []
27 | with Parallel(n_jobs=50) as parallel:
28 | results = parallel(delayed(check_image_exists)(data, image_folder) for data in tqdm(datas, desc="Checking images"))
29 |
30 |
31 | # Separate the results into new data and unexisted images
32 | for result in results:
33 | exists, data = result
34 | if exists:
35 | new_datas.append(data)
36 | else:
37 | unexisted_images.append(data)
38 |
39 | # Save the filtered data back to a JSON file
40 | with open(save_path, 'w') as f:
41 | json.dump(new_datas, f, indent=4)
42 |
43 |     # Print out the missing images
44 |     print(f'Missing images: {unexisted_images}')
45 |
46 |
47 |
48 | if __name__ == "__main__":
49 |     data_path = './ALLaVA-4V/allava_laion/ALLaVA-Instruct-LAION-4V.json'
50 |     image_folder = './ALLaVA-4V'
51 |     save_path = './ALLaVA-4V/allava_laion/ALLaVA-Instruct-LAION-4V_preprocessed.json'
52 |     preprocess((data_path, image_folder, save_path))
53 |
54 |     # Ad-hoc check: inspect the keys of a pretrained mm_projector checkpoint.
55 |     import torch
56 |     pretrain_mm_mlp_adapter = './checkpoints/vicuna-v1.5-7b-convnext-pretrain/mm_projector.bin'
57 |     mm_projector_weights = torch.load(pretrain_mm_mlp_adapter, map_location='cpu')
58 |     def get_w(weights, keyword):
59 |         return {k.split(keyword + '.')[1]: v for k, v in weights.items() if keyword in k}
60 |
61 |     res = get_w(mm_projector_weights, 'mm_projector')
--------------------------------------------------------------------------------
/src/dataset/__init__.py:
--------------------------------------------------------------------------------
1 | from .editDataset import InstructPix2Pix_Dataset, MagicBrush_Dataset, EditingDataset
2 | from .pairDataset import TextImagePairDataset
3 | from .instructDataset import InstructionTuningDataset
4 | from .base_dataset import DataCollatorForSupervisedDataset
--------------------------------------------------------------------------------
/src/dataset/dataset_utils.py:
--------------------------------------------------------------------------------
1 |
2 | import math
3 | import random
4 | from PIL import Image
5 |
6 |
7 | def extend_list(original_list, multiplier):
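    # Example: extend_list(lst, 2.5) returns lst repeated twice, followed by
    # ceil(0.5 * len(lst)) elements sampled (with replacement) from lst.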
8 | # Calculate how many elements to replicate and how many to select randomly
9 | replicate_elements = math.floor(multiplier)
10 | random_elements = multiplier - replicate_elements
11 |
12 | # Replicate the list
13 | replicated_list = original_list * replicate_elements
14 |
15 | # Calculate how many elements to randomly select
16 | select_elements = math.ceil(len(original_list) * random_elements)
17 |
18 | # Randomly select elements and append to the replicated list
19 | for _ in range(select_elements):
20 | random_element = random.choice(original_list)
21 | replicated_list.append(random_element)
22 |
23 | return replicated_list
24 |
25 |
26 | def expand2square(pil_img, background_color):
27 | width, height = pil_img.size
28 | if width == height:
29 | return pil_img
30 | elif width > height:
31 | result = Image.new(pil_img.mode, (width, width), background_color)
32 | result.paste(pil_img, (0, (width - height) // 2))
33 | return result
34 | else:
35 | result = Image.new(pil_img.mode, (height, height), background_color)
36 | result.paste(pil_img, ((height - width) // 2, 0))
37 | return result
--------------------------------------------------------------------------------
/src/dataset/editDataset.py:
--------------------------------------------------------------------------------
1 | import pdb
2 | import token
3 |
4 | from datasets import load_from_disk
5 | import io
6 | import numpy as np
7 | from PIL import Image
8 | import random
9 | import torch
10 | from torchvision import transforms
11 | from torch.utils.data import Dataset, ConcatDataset
12 | from .base_dataset import *
13 | from .dataset_utils import expand2square
14 |
15 |
16 | def convert_to_np(image, resolution):
17 | image = image.convert("RGB")
18 | image = image.resize((resolution, resolution), resample=Image.Resampling.BICUBIC)
19 | return np.array(image).transpose(2, 0, 1)
20 |
21 |
22 | def load_img_for_generator(image, resolution):
23 | # image = Image.open(path).convert("RGB")
24 | # w, h = image.size
25 | # print(f"loaded input image of size ({w}, {h}) from {path}")
26 | # w, h = map(lambda x: x - x % 32, (w, h)) # resize to integer multiple of 32
27 |     image = image.resize((resolution, resolution), resample=Image.Resampling.LANCZOS)
28 | image = np.array(image).astype(np.float32) / 255.0
29 | image = image.transpose(2, 0, 1)
30 | image = torch.from_numpy(image)
31 | return 2.*image - 1.
32 |
33 |
34 | def get_random_response():
35 | image_editing_responses = {
36 | "simple": [
37 | "Here you go.",
38 | "All set.",
39 | "Done.",
40 | "Here it is.",
41 | "Finished.",
42 | "Done. Let me know if it works!"
43 | ],
44 | "polite_professional": [
45 | "The image has been edited as requested. Please take a look.",
46 | "Here is the updated version of the image you asked for.",
47 | "Attached is the revised image. Let me know if everything looks good.",
48 | "I've completed the edits—feel free to review.",
49 | "Please find the edited image below. Let me know if you'd like any revisions."
50 | ],
51 | "casual_friendly": [
52 | "All done! Hope you like it.",
53 | "Tada 🎨 Let me know what you think!",
54 | "Voila! Here's your image.",
55 | "Here's the new version—check it out!",
56 | "Done and dusted 😎",
57 | "Boom! Updated and ready."
58 | ],
59 | "open_to_feedback": [
60 | "Let me know if you'd like to adjust anything else.",
61 | "Happy to make further edits if needed!",
62 | "If you need a different version, just say the word.",
63 | "Want to tweak anything? I've got you.",
64 | "Tell me if something needs changing!"
65 | ],
66 | "image_generation_context": [
67 | "Here is the image based on your description.",
68 | "The generated image is ready. Let me know if it matches your vision.",
69 | "Here's what I came up with—does this align with what you had in mind?",
70 | "Based on your prompt, this is the result. Happy to revise!"
71 | ]
72 | }
73 |
74 | all_responses = sum(image_editing_responses.values(), [])
75 |
76 | random_reply = random.choice(all_responses)
77 | return random_reply
78 |
79 |
80 |
81 | class EditingDataset(Dataset):
82 | def __init__(self, data_path, tokenizer, data_args) -> None:
83 | super().__init__()
84 |
85 | instructPix2Pix_dataset = InstructPix2Pix_Dataset(data_path[0], tokenizer=tokenizer, data_args=data_args)
86 |         magicBrush_dataset = MagicBrush_Dataset(data_path=data_path[1], tokenizer=tokenizer, data_args=data_args)
87 |         self.datasets = ConcatDataset([instructPix2Pix_dataset, magicBrush_dataset])
88 |
89 | def __len__(self):
90 | return self.datasets.__len__()
91 |
92 | def __getitem__(self, item):
93 | return self.datasets.__getitem__(item)
94 |
95 |
96 |
97 | # InstructPix2Pix dataset
98 | class InstructPix2Pix_Dataset(LazySupervisedDataset):
99 | '''
100 | according to InstructPix2Pix, the dataset can be used to train models to follow edit instructions.
101 | Edit instructions are available in the 'edit_prompt'. 'original_image' can be used with the 'edit_prompt' and 'edited_image' denotes the image after applying the 'edit_prompt' on the 'original_image'.
102 | "original_image" + "edited_image" + "edit_prompt"
103 | '''
104 | def __init__(self,
105 | data_path,
106 | tokenizer,
107 | data_args,
108 | ):
109 | super().__init__(data_args=data_args, data_path=data_path, tokenizer=tokenizer)
110 |
111 | # InstructPix2Pix Dataset path
112 | self.list_data_dict = load_from_disk(data_path)
113 | # 224, 256
114 | self.resolution_for_comp = data_args.image_size
115 | self.resolution_for_gen = data_args.resolution_sd
116 |
117 | # tokenizer
118 | self.tokenizer = tokenizer
119 |
120 |
121 | def __len__(self,):
122 | return len(self.list_data_dict)
123 |
124 | def __getitem__(self, i):
125 | # # {'original_image': , 'edited_image': , 'edit_prompt': 'make the leaves yellow'}
126 | sources = self.list_data_dict[i]
127 | if isinstance(i, int):
128 | sources = [sources]
129 | assert len(sources) == 1, "Don't know why it is wrapped to a list" # FIXME
130 | if 'image' in sources[0]:
131 | original_image_file = self.list_data_dict[i]['original_image']
132 | # image_folder = self.data_args.image_folder
133 | processor = self.data_args.image_processor
134 | # image = Image.open(os.path.join(image_folder, image_file)).convert('RGB')
135 | original_image = Image.open(io.BytesIO(original_image_file['bytes'])).convert('RGB')
136 | edited_image_file = self.list_data_dict[i]['edited_image']
137 | edited_image = Image.open(io.BytesIO(edited_image_file['bytes'])).convert('RGB')
138 | if self.data_args.image_aspect_ratio == 'pad':
139 | def expand2square(pil_img, background_color):
140 | width, height = pil_img.size
141 | if width == height:
142 | return pil_img
143 | elif width > height:
144 | result = Image.new(pil_img.mode, (width, width), background_color)
145 | result.paste(pil_img, (0, (width - height) // 2))
146 | return result
147 | else:
148 | result = Image.new(pil_img.mode, (height, height), background_color)
149 | result.paste(pil_img, ((height - width) // 2, 0))
150 | return result
151 | original_image = expand2square(original_image, tuple(int(x*255) for x in processor.image_mean))
152 | comp_image = processor.preprocess(original_image, return_tensors='pt')['pixel_values'][0]
153 | gen_image = load_img_for_generator(edited_image, self.resolution_for_gen)
154 | else:
155 | comp_image = processor.preprocess(original_image, return_tensors='pt')['pixel_values'][0]
156 | gen_image = load_img_for_generator(edited_image, self.resolution_for_gen)
157 |
158 | _source = {
159 | "id": i,
160 | "conversations": [
161 | {"from": "human", "value": "\n"+self.list_data_dict[i]["edit_prompt"]},
162 | {"from": "gpt", "value": "\n"+ get_random_response()},
163 | ]
164 | }
165 | sources = preprocess_multimodal(
166 | [_source],
167 | self.data_args,
168 | target_num=gen_image.shape[0])
169 | else:
170 | sources = [_source]
171 | data_dict = preprocess(
172 | sources,
173 | self.tokenizer,
174 | has_image=('image' in self.list_data_dict[i]))
175 | if isinstance(i, int):
176 | data_dict = dict(input_ids=data_dict["input_ids"][0], labels=data_dict["labels"][0])
177 |
178 | # image exist in the data
179 | if 'image' in self.list_data_dict[i]:
180 | data_dict['comp_image'] = comp_image
181 | data_dict['gen_image'] = gen_image
182 | elif self.data_args.is_multimodal:
183 | # image does not exist in the data, but the model is multimodal
184 | crop_size = self.data_args.image_processor.crop_size
185 | data_dict['comp_image'] = torch.zeros(3, crop_size['height'], crop_size['width'])
186 | data_dict['gen_image'] = torch.zeros(3, crop_size['height'], crop_size['width'])
187 |
188 | return data_dict
189 |
190 |
191 |
192 | # MagicBrush dataset
193 | class MagicBrush_Dataset(LazySupervisedDataset):
194 | '''
195 | according to MagicBrush, the dataset can be used to train models to follow edit instructions.
196 | Edit instructions are available in the 'instruction'. 'source_img' can be used with the 'instruction' and 'target_img' denotes the image after applying the 'instruction' on the 'source_img'.
197 | "source_img" + "target_img" + "instruction"
198 | Dataset({features: ['img_id', 'turn_index', 'source_img', 'mask_img', 'instruction', 'target_img'], num_rows: 8807})
199 | '''
200 | def __init__(self,
201 | data_path,
202 | tokenizer,
203 | data_args,
204 | ):
205 | super().__init__(data_path=data_path, tokenizer=tokenizer, data_args=data_args)
206 | # MagicBrush Dataset path
207 | # InstructPix2Pix Dataset path
208 | self.list_data_dict = load_from_disk(data_path)
209 | # 224, 256
210 | self.resolution_for_comp = data_args.image_size
211 | self.resolution_for_gen = data_args.resolution_sd
212 |
213 | # tokenizer
214 | self.tokenizer = tokenizer
215 |
216 | def __len__(self,):
217 | return len(self.list_data_dict)
218 |
219 | def __getitem__(self, i):
220 | # {'source_img': , 'target_img': , 'instruction': 'let the asparagus be replaced with sausages'}
221 |
222 | sources = self.list_data_dict[i]
223 | if isinstance(i, int):
224 | sources = [sources]
225 | assert len(sources) == 1, "Don't know why it is wrapped to a list" # FIXME
226 | if 'image' in sources[0]:
227 | original_image_file = self.list_data_dict[i]['source_img']
228 | # image_folder = self.data_args.image_folder
229 | processor = self.data_args.image_processor
230 | # image = Image.open(os.path.join(image_folder, image_file)).convert('RGB')
231 | original_image = Image.open(io.BytesIO(original_image_file['bytes'])).convert('RGB')
232 | edited_image_file = self.list_data_dict[i]['target_img']
233 | edited_image = Image.open(io.BytesIO(edited_image_file['bytes'])).convert('RGB')
234 | if self.data_args.image_aspect_ratio == 'pad':
235 | def expand2square(pil_img, background_color):
236 | width, height = pil_img.size
237 | if width == height:
238 | return pil_img
239 | elif width > height:
240 | result = Image.new(pil_img.mode, (width, width), background_color)
241 | result.paste(pil_img, (0, (width - height) // 2))
242 | return result
243 | else:
244 | result = Image.new(pil_img.mode, (height, height), background_color)
245 | result.paste(pil_img, ((height - width) // 2, 0))
246 | return result
247 | original_image = expand2square(original_image, tuple(int(x*255) for x in processor.image_mean))
248 | comp_image = processor.preprocess(original_image, return_tensors='pt')['pixel_values'][0]
249 | gen_image = load_img_for_generator(edited_image, self.resolution_for_gen)
250 | else:
251 | comp_image = processor.preprocess(original_image, return_tensors='pt')['pixel_values'][0]
252 | gen_image = load_img_for_generator(edited_image, self.resolution_for_gen)
253 |
254 | _source = {
255 | "id": i,
256 | "conversations": [
257 | {"from": "human", "value": "\n"+self.list_data_dict[i]["instruction"]},
258 | {"from": "gpt", "value": "\n"+ get_random_response()},
259 | ]
260 | }
261 | sources = preprocess_multimodal(
262 | [_source],
263 | self.data_args,
264 | target_num=gen_image.shape[0])
265 | else:
266 | sources = [_source]
267 | data_dict = preprocess(
268 | sources,
269 | self.tokenizer,
270 | has_image=('image' in self.list_data_dict[i]))
271 | if isinstance(i, int):
272 | data_dict = dict(input_ids=data_dict["input_ids"][0], labels=data_dict["labels"][0])
273 |
274 | # image exist in the data
275 | if 'image' in self.list_data_dict[i]:
276 | data_dict['comp_image'] = comp_image
277 | data_dict['gen_image'] = gen_image
278 | elif self.data_args.is_multimodal:
279 | # image does not exist in the data, but the model is multimodal
280 | crop_size = self.data_args.image_processor.crop_size
281 | data_dict['comp_image'] = torch.zeros(3, crop_size['height'], crop_size['width'])
282 | data_dict['gen_image'] = torch.zeros(3, crop_size['height'], crop_size['width'])
283 |
284 | return data_dict
--------------------------------------------------------------------------------
/src/dataset/vqa.py:
--------------------------------------------------------------------------------
1 | __author__ = 'aagrawal'
2 | __version__ = '0.9'
3 |
4 | # Interface for accessing the VQA dataset.
5 |
6 | # This code is based on the code written by Tsung-Yi Lin for MSCOCO Python API available at the following link:
7 | # (https://github.com/pdollar/coco/blob/master/PythonAPI/pycocotools/coco.py).
8 |
9 | # The following functions are defined:
10 | # VQA - VQA class that loads VQA annotation file and prepares data structures.
11 | # getQuesIds - Get question ids that satisfy given filter conditions.
12 | # getImgIds - Get image ids that satisfy given filter conditions.
13 | # loadQA - Load questions and answers with the specified question ids.
14 | # showQA - Display the specified questions and answers.
15 | # loadRes - Load result file and create result object.
16 |
17 | # Help on each function can be accessed by: "help(COCO.function)"
18 |
19 | import json
20 | import datetime
21 | import copy
22 |
23 | class VQA:
24 | def __init__(self, annotation_file=None, question_file=None):
25 | """
26 | Constructor of VQA helper class for reading and visualizing questions and answers.
27 | :param annotation_file (str): location of VQA annotation file
28 | :return:
29 | """
30 | # load dataset
31 | self.dataset = {}
32 | self.questions = {}
33 | self.qa = {}
34 | self.qqa = {}
35 | self.imgToQA = {}
36 | if not annotation_file == None and not question_file == None:
37 | print('loading VQA annotations and questions into memory...')
38 | time_t = datetime.datetime.utcnow()
39 | dataset = json.load(open(annotation_file, 'r'))
40 | questions = json.load(open(question_file, 'r'))
41 | print(datetime.datetime.utcnow() - time_t)
42 | self.dataset = dataset
43 | self.questions = questions
44 | self.createIndex()
45 |
46 | def createIndex(self):
47 | # create index
48 | print('creating index...')
49 | imgToQA = {ann['image_id']: [] for ann in self.dataset['annotations']}
50 | qa = {ann['question_id']: [] for ann in self.dataset['annotations']}
51 | qqa = {ann['question_id']: [] for ann in self.dataset['annotations']}
52 | for ann in self.dataset['annotations']:
53 | imgToQA[ann['image_id']] += [ann]
54 | qa[ann['question_id']] = ann
55 | for ques in self.questions['questions']:
56 | qqa[ques['question_id']] = ques
57 | print('index created!')
58 |
59 | # create class members
60 | self.qa = qa
61 | self.qqa = qqa
62 | self.imgToQA = imgToQA
63 |
64 | def info(self):
65 | """
66 | Print information about the VQA annotation file.
67 | :return:
68 | """
69 |         for key, value in self.dataset['info'].items():
70 | print('%s: %s'%(key, value))
71 |
72 | def getQuesIds(self, imgIds=[], quesTypes=[], ansTypes=[]):
73 | """
74 | Get question ids that satisfy given filter conditions. default skips that filter
75 | :param imgIds (int array) : get question ids for given imgs
76 | quesTypes (str array) : get question ids for given question types
77 | ansTypes (str array) : get question ids for given answer types
78 | :return: ids (int array) : integer array of question ids
79 | """
80 | imgIds = imgIds if type(imgIds) == list else [imgIds]
81 | quesTypes = quesTypes if type(quesTypes) == list else [quesTypes]
82 | ansTypes = ansTypes if type(ansTypes) == list else [ansTypes]
83 |
84 | if len(imgIds) == len(quesTypes) == len(ansTypes) == 0:
85 | anns = self.dataset['annotations']
86 | else:
87 | if not len(imgIds) == 0:
88 | anns = sum([self.imgToQA[imgId] for imgId in imgIds if imgId in self.imgToQA],[])
89 | else:
90 | anns = self.dataset['annotations']
91 | anns = anns if len(quesTypes) == 0 else [ann for ann in anns if ann['question_type'] in quesTypes]
92 | anns = anns if len(ansTypes) == 0 else [ann for ann in anns if ann['answer_type'] in ansTypes]
93 | ids = [ann['question_id'] for ann in anns]
94 | return ids
95 |
96 | def getImgIds(self, quesIds=[], quesTypes=[], ansTypes=[]):
97 | """
98 | Get image ids that satisfy given filter conditions. default skips that filter
99 | :param quesIds (int array) : get image ids for given question ids
100 | quesTypes (str array) : get image ids for given question types
101 | ansTypes (str array) : get image ids for given answer types
102 | :return: ids (int array) : integer array of image ids
103 | """
104 | quesIds = quesIds if type(quesIds) == list else [quesIds]
105 | quesTypes = quesTypes if type(quesTypes) == list else [quesTypes]
106 | ansTypes = ansTypes if type(ansTypes) == list else [ansTypes]
107 |
108 | if len(quesIds) == len(quesTypes) == len(ansTypes) == 0:
109 | anns = self.dataset['annotations']
110 | else:
111 | if not len(quesIds) == 0:
112 | anns = sum([self.qa[quesId] for quesId in quesIds if quesId in self.qa],[])
113 | else:
114 | anns = self.dataset['annotations']
115 | anns = anns if len(quesTypes) == 0 else [ann for ann in anns if ann['question_type'] in quesTypes]
116 | anns = anns if len(ansTypes) == 0 else [ann for ann in anns if ann['answer_type'] in ansTypes]
117 | ids = [ann['image_id'] for ann in anns]
118 | return ids
119 |
120 | def loadQA(self, ids=[]):
121 | """
122 | Load questions and answers with the specified question ids.
123 | :param ids (int array) : integer ids specifying question ids
124 | :return: qa (object array) : loaded qa objects
125 | """
126 | if type(ids) == list:
127 | return [self.qa[id] for id in ids]
128 | elif type(ids) == int:
129 | return [self.qa[ids]]
130 |
131 | def showQA(self, anns):
132 | """
133 | Display the specified annotations.
134 | :param anns (array of object): annotations to display
135 | :return: None
136 | """
137 | if len(anns) == 0:
138 | return 0
139 | for ann in anns:
140 | quesId = ann['question_id']
141 | print("Question: %s" %(self.qqa[quesId]['question']))
142 | for ans in ann['answers']:
143 | print("Answer %d: %s" %(ans['answer_id'], ans['answer']))
144 |
145 | def loadRes(self, resFile, quesFile):
146 | """
147 | Load result file and return a result object.
148 | :param resFile (str) : file name of result file
149 | :return: res (obj) : result api object
150 | """
151 | res = VQA()
152 | res.questions = json.load(open(quesFile))
153 | res.dataset['info'] = copy.deepcopy(self.questions['info'])
154 | res.dataset['task_type'] = copy.deepcopy(self.questions['task_type'])
155 | res.dataset['data_type'] = copy.deepcopy(self.questions['data_type'])
156 | res.dataset['data_subtype'] = copy.deepcopy(self.questions['data_subtype'])
157 | res.dataset['license'] = copy.deepcopy(self.questions['license'])
158 |
159 | print('Loading and preparing results... ')
160 | time_t = datetime.datetime.utcnow()
161 | anns = json.load(open(resFile))
162 | assert type(anns) == list, 'results is not an array of objects'
163 | annsQuesIds = [ann['question_id'] for ann in anns]
164 | assert set(annsQuesIds) == set(self.getQuesIds()), \
165 |         'Results do not correspond to current VQA set. Either the results do not have predictions for all question ids in annotation file or there is at least one question id that does not belong to the question ids in the annotation file.'
166 | for ann in anns:
167 | quesId = ann['question_id']
168 | if res.dataset['task_type'] == 'Multiple Choice':
169 | assert ann['answer'] in self.qqa[quesId]['multiple_choices'], 'predicted answer is not one of the multiple choices'
170 | qaAnn = self.qa[quesId]
171 | ann['image_id'] = qaAnn['image_id']
172 | ann['question_type'] = qaAnn['question_type']
173 | ann['answer_type'] = qaAnn['answer_type']
174 | print('DONE (t=%0.2fs)'%((datetime.datetime.utcnow() - time_t).total_seconds()))
175 |
176 | res.dataset['annotations'] = anns
177 | res.createIndex()
178 | return res
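179 | 
180 | 
181 | if __name__ == "__main__":
182 |     # Hedged usage sketch (not part of the original file). It assumes the standard
183 |     # VQA evaluation API, i.e. a VQA(annotation_file, question_file) constructor over
184 |     # the two JSON files indexed by createIndex(); all paths below are placeholders.
185 |     vqa = VQA("annotations.json", "questions.json")
186 |     ques_ids = vqa.getQuesIds(quesTypes=["what color"])     # filter by question type
187 |     vqa.showQA(vqa.loadQA(ques_ids[:3]))                    # print a few QA pairs
188 |     res = vqa.loadRes("results.json", "questions.json")     # wrap model predictions
189 |     print(len(res.getQuesIds()))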
--------------------------------------------------------------------------------
/src/mm_utils.py:
--------------------------------------------------------------------------------
1 | from PIL import Image
2 | from io import BytesIO
3 | import base64
4 | import torch
5 | import math
6 | import ast
7 |
8 | from transformers import StoppingCriteria
9 | from src.constants import TARGET_TOKEN_INDEX, IMAGE_TOKEN_INDEX
10 |
11 |
12 | def select_best_resolution(original_size, possible_resolutions):
13 | """
14 | Selects the best resolution from a list of possible resolutions based on the original size.
15 |
16 | Args:
17 | original_size (tuple): The original size of the image in the format (width, height).
18 | possible_resolutions (list): A list of possible resolutions in the format [(width1, height1), (width2, height2), ...].
19 |
20 | Returns:
21 | tuple: The best fit resolution in the format (width, height).
22 | """
23 | original_width, original_height = original_size
24 | best_fit = None
25 | max_effective_resolution = 0
26 | min_wasted_resolution = float('inf')
27 |
28 | for width, height in possible_resolutions:
29 | scale = min(width / original_width, height / original_height)
30 | downscaled_width, downscaled_height = int(original_width * scale), int(original_height * scale)
31 | effective_resolution = min(downscaled_width * downscaled_height, original_width * original_height)
32 | wasted_resolution = (width * height) - effective_resolution
33 |
34 | if effective_resolution > max_effective_resolution or (effective_resolution == max_effective_resolution and wasted_resolution < min_wasted_resolution):
35 | max_effective_resolution = effective_resolution
36 | min_wasted_resolution = wasted_resolution
37 | best_fit = (width, height)
38 |
39 | return best_fit
40 |
41 |
42 | def resize_and_pad_image(image, target_resolution):
43 | """
44 | Resize and pad an image to a target resolution while maintaining aspect ratio.
45 |
46 | Args:
47 | image (PIL.Image.Image): The input image.
48 | target_resolution (tuple): The target resolution (width, height) of the image.
49 |
50 | Returns:
51 | PIL.Image.Image: The resized and padded image.
52 | """
53 | original_width, original_height = image.size
54 | target_width, target_height = target_resolution
55 |
56 | scale_w = target_width / original_width
57 | scale_h = target_height / original_height
58 |
59 | if scale_w < scale_h:
60 | new_width = target_width
61 | new_height = min(math.ceil(original_height * scale_w), target_height)
62 | else:
63 | new_height = target_height
64 | new_width = min(math.ceil(original_width * scale_h), target_width)
65 |
66 | # Resize the image
67 | resized_image = image.resize((new_width, new_height))
68 |
69 | new_image = Image.new('RGB', (target_width, target_height), (0, 0, 0))
70 | paste_x = (target_width - new_width) // 2
71 | paste_y = (target_height - new_height) // 2
72 | new_image.paste(resized_image, (paste_x, paste_y))
73 |
74 | return new_image
75 |
76 |
77 | def divide_to_patches(image, patch_size):
78 | """
79 | Divides an image into patches of a specified size.
80 |
81 | Args:
82 | image (PIL.Image.Image): The input image.
83 | patch_size (int): The size of each patch.
84 |
85 | Returns:
86 | list: A list of PIL.Image.Image objects representing the patches.
87 | """
88 | patches = []
89 | width, height = image.size
90 | for i in range(0, height, patch_size):
91 | for j in range(0, width, patch_size):
92 | box = (j, i, j + patch_size, i + patch_size)
93 | patch = image.crop(box)
94 | patches.append(patch)
95 |
96 | return patches
97 |
98 |
99 | def get_anyres_image_grid_shape(image_size, grid_pinpoints, patch_size):
100 | """
101 | Calculate the shape of the image patch grid after the preprocessing for images of any resolution.
102 |
103 | Args:
104 | image_size (tuple): The size of the input image in the format (width, height).
105 | grid_pinpoints (str): A string representation of a list of possible resolutions.
106 | patch_size (int): The size of each image patch.
107 |
108 | Returns:
109 | tuple: The shape of the image patch grid in the format (width, height).
110 | """
111 | if type(grid_pinpoints) is list:
112 | possible_resolutions = grid_pinpoints
113 | else:
114 | possible_resolutions = ast.literal_eval(grid_pinpoints)
115 | width, height = select_best_resolution(image_size, possible_resolutions)
116 | return width // patch_size, height // patch_size
117 |
118 |
119 | def process_anyres_image(image, processor, grid_pinpoints):
120 | """
121 | Process an image with variable resolutions.
122 |
123 | Args:
124 | image (PIL.Image.Image): The input image to be processed.
125 | processor: The image processor object.
126 | grid_pinpoints (str): A string representation of a list of possible resolutions.
127 |
128 | Returns:
129 | torch.Tensor: A tensor containing the processed image patches.
130 | """
131 | if type(grid_pinpoints) is list:
132 | possible_resolutions = grid_pinpoints
133 | else:
134 | possible_resolutions = ast.literal_eval(grid_pinpoints)
135 | best_resolution = select_best_resolution(image.size, possible_resolutions)
136 | image_padded = resize_and_pad_image(image, best_resolution)
137 |
138 | patches = divide_to_patches(image_padded, processor.crop_size['height'])
139 |
140 | image_original_resize = image.resize((processor.size['shortest_edge'], processor.size['shortest_edge']))
141 |
142 | image_patches = [image_original_resize] + patches
143 | image_patches = [processor.preprocess(image_patch, return_tensors='pt')['pixel_values'][0]
144 | for image_patch in image_patches]
145 | return torch.stack(image_patches, dim=0)
146 |
147 |
148 | def load_image_from_base64(image):
149 | return Image.open(BytesIO(base64.b64decode(image)))
150 |
151 |
152 | def expand2square(pil_img, background_color):
153 | width, height = pil_img.size
154 | if width == height:
155 | return pil_img
156 | elif width > height:
157 | result = Image.new(pil_img.mode, (width, width), background_color)
158 | result.paste(pil_img, (0, (width - height) // 2))
159 | return result
160 | else:
161 | result = Image.new(pil_img.mode, (height, height), background_color)
162 | result.paste(pil_img, ((height - width) // 2, 0))
163 | return result
164 |
165 |
166 | def process_images(images, image_processor, model_cfg):
167 | image_aspect_ratio = getattr(model_cfg, "image_aspect_ratio", None)
168 | new_images = []
169 | if image_aspect_ratio == 'pad':
170 | for image in images:
171 | image = expand2square(image, tuple(int(x*255) for x in image_processor.image_mean))
172 | image = image_processor.preprocess(image, return_tensors='pt')['pixel_values'][0]
173 | new_images.append(image)
174 | elif image_aspect_ratio == "anyres":
175 | for image in images:
176 | image = process_anyres_image(image, image_processor, model_cfg.image_grid_pinpoints)
177 | new_images.append(image)
178 | else:
179 | return image_processor(images, return_tensors='pt')['pixel_values']
180 | if all(x.shape == new_images[0].shape for x in new_images):
181 | new_images = torch.stack(new_images, dim=0)
182 | return new_images
183 |
184 |
185 | def tokenizer_image_token(prompt, tokenizer, image_token_index=IMAGE_TOKEN_INDEX, return_tensors=None):
186 |     prompt_chunks = [tokenizer(chunk).input_ids for chunk in prompt.split('<image>')]
187 |
188 | def insert_separator(X, sep):
189 | return [ele for sublist in zip(X, [sep]*len(X)) for ele in sublist][:-1]
190 |
191 | input_ids = []
192 | offset = 0
193 | if len(prompt_chunks) > 0 and len(prompt_chunks[0]) > 0 and prompt_chunks[0][0] == tokenizer.bos_token_id:
194 | offset = 1
195 | input_ids.append(prompt_chunks[0][0])
196 |
197 | for x in insert_separator(prompt_chunks, [image_token_index] * (offset + 1)):
198 | input_ids.extend(x[offset:])
199 |
200 | if return_tensors is not None:
201 | if return_tensors == 'pt':
202 | return torch.tensor(input_ids, dtype=torch.long)
203 | raise ValueError(f'Unsupported tensor type: {return_tensors}')
204 | return input_ids
205 |
206 |
207 | def tokenizer_multiple_token(prompt, tokenizer, target_token_index=TARGET_TOKEN_INDEX, return_tensors=None):
208 |
209 | input_ids = []
210 |     target_chunks = prompt.split("<target>")
211 | for target_idx, target_ck in enumerate(target_chunks):
212 | _inputs = tokenizer_image_token(target_ck, tokenizer, IMAGE_TOKEN_INDEX, return_tensors=return_tensors)
213 | input_ids.extend(_inputs)
214 |         if target_idx < len(target_chunks) - 1:
215 |             input_ids.append(target_token_index)
216 |
217 | if return_tensors is not None:
218 | if return_tensors == "pt":
219 | return torch.tensor(input_ids, dtype=torch.long)
220 | raise ValueError(f"Unsupported tensor type: {return_tensors}")
221 |
222 | return input_ids
223 |
224 | def get_model_name_from_path(model_path):
225 | model_path = model_path.strip("/")
226 | model_paths = model_path.split("/")
227 | if model_paths[-1].startswith('checkpoint-'):
228 | return model_paths[-2] + "_" + model_paths[-1]
229 | else:
230 | return model_paths[-1]
231 |
232 | class KeywordsStoppingCriteria(StoppingCriteria):
233 | def __init__(self, keywords, tokenizer, input_ids):
234 | self.keywords = keywords
235 | self.keyword_ids = []
236 | self.max_keyword_len = 0
237 | for keyword in keywords:
238 | cur_keyword_ids = tokenizer(keyword).input_ids
239 | if len(cur_keyword_ids) > 1 and cur_keyword_ids[0] == tokenizer.bos_token_id:
240 | cur_keyword_ids = cur_keyword_ids[1:]
241 | if len(cur_keyword_ids) > self.max_keyword_len:
242 | self.max_keyword_len = len(cur_keyword_ids)
243 | self.keyword_ids.append(torch.tensor(cur_keyword_ids))
244 | self.tokenizer = tokenizer
245 | self.start_len = input_ids.shape[1]
246 |
247 | def call_for_batch(self, output_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
248 | offset = min(output_ids.shape[1] - self.start_len, self.max_keyword_len)
249 | self.keyword_ids = [keyword_id.to(output_ids.device) for keyword_id in self.keyword_ids]
250 | for keyword_id in self.keyword_ids:
251 | truncated_output_ids = output_ids[0, -keyword_id.shape[0]:]
252 | if torch.equal(truncated_output_ids, keyword_id):
253 | return True
254 | outputs = self.tokenizer.batch_decode(output_ids[:, -offset:], skip_special_tokens=True)[0]
255 | for keyword in self.keywords:
256 | if keyword in outputs:
257 | return True
258 | return False
259 |
260 | def __call__(self, output_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
261 | outputs = []
262 | for i in range(output_ids.shape[0]):
263 | outputs.append(self.call_for_batch(output_ids[i].unsqueeze(0), scores))
264 | return all(outputs)
265 |
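266 | 
267 | if __name__ == "__main__":
268 |     # Hedged usage sketch (not part of the original file): it exercises the pure
269 |     # image helpers above on an in-memory image, so no checkpoint or processor is
270 |     # needed; the checkpoint path passed to get_model_name_from_path is a placeholder.
271 |     img = Image.new("RGB", (640, 480), (255, 0, 0))
272 |     print(select_best_resolution(img.size, [(336, 336), (672, 336), (336, 672)]))
273 |     square = expand2square(img, (0, 0, 0))                        # pad to 640x640
274 |     patches = divide_to_patches(square.resize((672, 672)), 336)   # 2x2 grid -> 4 patches
275 |     print(square.size, len(patches))
276 |     print(get_model_name_from_path("/ckpts/setokim-7b/checkpoint-2000"))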
--------------------------------------------------------------------------------
/src/model/__init__.py:
--------------------------------------------------------------------------------
1 | try:
2 | from .language_model.setokim_llama import SetokimLlamaForCausalLM, SetokimConfig
3 | from .setok.model import SeTok
4 | except:
5 | pass
6 |
--------------------------------------------------------------------------------
/src/model/apply_delta.py:
--------------------------------------------------------------------------------
1 | """
2 | Usage:
3 | python3 -m src.model.apply_delta --base ~/model_weights/llama-7b --target ~/model_weights/vicuna-7b --delta lmsys/vicuna-7b-delta
4 | """
5 | import argparse
6 |
7 | import torch
8 | from tqdm import tqdm
9 | from transformers import AutoTokenizer, AutoModelForCausalLM
10 | from src import SetokimLlamaForCausalLM
11 |
12 |
13 | def apply_delta(base_model_path, target_model_path, delta_path):
14 | print("Loading base model")
15 | base = AutoModelForCausalLM.from_pretrained(
16 | base_model_path, torch_dtype=torch.float16, low_cpu_mem_usage=True)
17 |
18 | print("Loading delta")
19 | delta = SetokimLlamaForCausalLM.from_pretrained(delta_path, torch_dtype=torch.float16, low_cpu_mem_usage=True)
20 | delta_tokenizer = AutoTokenizer.from_pretrained(delta_path)
21 |
22 | print("Applying delta")
23 | for name, param in tqdm(delta.state_dict().items(), desc="Applying delta"):
24 | if name not in base.state_dict():
25 | assert name in ['model.mm_projector.weight', 'model.mm_projector.bias'], f'{name} not in base model'
26 | continue
27 | if param.data.shape == base.state_dict()[name].shape:
28 | param.data += base.state_dict()[name]
29 | else:
30 | assert name in ['model.embed_tokens.weight', 'lm_head.weight'], \
31 | f'{name} dimension mismatch: {param.data.shape} vs {base.state_dict()[name].shape}'
32 | bparam = base.state_dict()[name]
33 | param.data[:bparam.shape[0], :bparam.shape[1]] += bparam
34 |
35 | print("Saving target model")
36 | delta.save_pretrained(target_model_path)
37 | delta_tokenizer.save_pretrained(target_model_path)
38 |
39 |
40 | if __name__ == "__main__":
41 | parser = argparse.ArgumentParser()
42 | parser.add_argument("--base-model-path", type=str, required=True)
43 | parser.add_argument("--target-model-path", type=str, required=True)
44 | parser.add_argument("--delta-path", type=str, required=True)
45 |
46 | args = parser.parse_args()
47 |
48 | apply_delta(args.base_model_path, args.target_model_path, args.delta_path)
49 |
--------------------------------------------------------------------------------
/src/model/builder.py:
--------------------------------------------------------------------------------
1 | # Copyright 2023 Haotian Liu
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 |
16 | import os
17 | import warnings
18 | import shutil
19 |
20 | from transformers import AutoTokenizer, AutoModelForCausalLM, AutoConfig, BitsAndBytesConfig
21 | import torch
22 | from src.model import *
23 | from src.constants import DEFAULT_IMAGE_PATCH_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN
24 |
25 |
26 | def load_pretrained_model(model_path, model_base, model_name, load_8bit=False, load_4bit=False, device_map="auto", device="cuda", use_flash_attn=False, **kwargs):
27 | kwargs = {"device_map": device_map, **kwargs}
28 |
29 | if device != "cuda":
30 | kwargs['device_map'] = {"": device}
31 |
32 | if load_8bit:
33 | kwargs['load_in_8bit'] = True
34 | elif load_4bit:
35 | kwargs['load_in_4bit'] = True
36 | kwargs['quantization_config'] = BitsAndBytesConfig(
37 | load_in_4bit=True,
38 | bnb_4bit_compute_dtype=torch.float16,
39 | bnb_4bit_use_double_quant=True,
40 | bnb_4bit_quant_type='nf4'
41 | )
42 | else:
43 | kwargs['torch_dtype'] = torch.float16
44 |
45 | if use_flash_attn:
46 | kwargs['attn_implementation'] = 'flash_attention_2'
47 |
48 | if 'setokim' in model_name.lower():
49 | # Load Setokim model
50 | if 'lora' in model_name.lower() and model_base is None:
51 | warnings.warn('There is `lora` in model name but no `model_base` is provided. If you are loading a LoRA model, please provide the `model_base` argument. Detailed instruction: https://github.com/haotian-liu/Setokim#launch-a-model-worker-lora-weights-unmerged.')
52 | if 'lora' in model_name.lower() and model_base is not None:
53 | from src.model.language_model.setokim_llama import SetokimConfig
54 | lora_cfg_pretrained = SetokimConfig.from_pretrained(model_path)
55 | tokenizer = AutoTokenizer.from_pretrained(model_base, use_fast=False)
56 | print('Loading Setokim from base model...')
57 | model = SetokimLlamaForCausalLM.from_pretrained(model_base, low_cpu_mem_usage=True, config=lora_cfg_pretrained, **kwargs)
58 |             token_num, token_dim = model.lm_head.out_features, model.lm_head.in_features
59 |             if model.lm_head.weight.shape[0] != token_num:
60 |                 model.lm_head.weight = torch.nn.Parameter(torch.empty(token_num, token_dim, device=model.device, dtype=model.dtype))
61 |                 model.model.embed_tokens.weight = torch.nn.Parameter(torch.empty(token_num, token_dim, device=model.device, dtype=model.dtype))
62 |
63 | print('Loading additional Setokim weights...')
64 | if os.path.exists(os.path.join(model_path, 'non_lora_trainables.bin')):
65 | non_lora_trainables = torch.load(os.path.join(model_path, 'non_lora_trainables.bin'), map_location='cpu')
66 | else:
67 | # this is probably from HF Hub
68 | from huggingface_hub import hf_hub_download
69 | def load_from_hf(repo_id, filename, subfolder=None):
70 | cache_file = hf_hub_download(
71 | repo_id=repo_id,
72 | filename=filename,
73 | subfolder=subfolder)
74 | return torch.load(cache_file, map_location='cpu')
75 | non_lora_trainables = load_from_hf(model_path, 'non_lora_trainables.bin')
76 | non_lora_trainables = {(k[11:] if k.startswith('base_model.') else k): v for k, v in non_lora_trainables.items()}
77 | if any(k.startswith('model.model.') for k in non_lora_trainables):
78 | non_lora_trainables = {(k[6:] if k.startswith('model.') else k): v for k, v in non_lora_trainables.items()}
79 | model.load_state_dict(non_lora_trainables, strict=False)
80 |
81 | from peft import PeftModel
82 | print('Loading LoRA weights...')
83 | model = PeftModel.from_pretrained(model, model_path)
84 | print('Merging LoRA weights...')
85 | model = model.merge_and_unload()
86 | print('Model is loaded...')
87 | elif model_base is not None:
88 | # this may be mm projector only
89 | print('Loading Setokim from base model...')
90 | tokenizer = AutoTokenizer.from_pretrained(model_base, use_fast=False)
91 | cfg_pretrained = AutoConfig.from_pretrained(model_path)
92 | model = SetokimLlamaForCausalLM.from_pretrained(model_base, low_cpu_mem_usage=True, config=cfg_pretrained, **kwargs)
93 |
94 | mm_projector_weights = torch.load(os.path.join(model_path, 'mm_projector.bin'), map_location='cpu')
95 | mm_projector_weights = {k: v.to(torch.float16) for k, v in mm_projector_weights.items()}
96 | model.load_state_dict(mm_projector_weights, strict=False)
97 | else:
98 | tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=False)
99 | model = SetokimLlamaForCausalLM.from_pretrained(
100 | model_path,
101 | low_cpu_mem_usage=True,
102 | **kwargs
103 | )
104 | else:
105 | # Load language model
106 | if model_base is not None:
107 | # PEFT model
108 | from peft import PeftModel
109 | tokenizer = AutoTokenizer.from_pretrained(model_base, use_fast=False)
110 | model = AutoModelForCausalLM.from_pretrained(model_base, low_cpu_mem_usage=True, **kwargs)
111 | print(f"Loading LoRA weights from {model_path}")
112 | model = PeftModel.from_pretrained(model, model_path)
113 | print(f"Merging weights")
114 | model = model.merge_and_unload()
115 | print('Convert to FP16...')
116 | model.to(torch.float16)
117 | else:
118 | use_fast = False
119 | tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=False)
120 | model = AutoModelForCausalLM.from_pretrained(model_path, low_cpu_mem_usage=True, **kwargs)
121 |
122 | image_processor = None
123 |
124 |     if 'setokim' in model_name.lower():
125 | mm_use_im_start_end = getattr(model.config, "mm_use_im_start_end", False)
126 | mm_use_im_patch_token = getattr(model.config, "mm_use_im_patch_token", True)
127 | if mm_use_im_patch_token:
128 | tokenizer.add_tokens([DEFAULT_IMAGE_PATCH_TOKEN], special_tokens=True)
129 | if mm_use_im_start_end:
130 | tokenizer.add_tokens([DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN], special_tokens=True)
131 | model.resize_token_embeddings(len(tokenizer))
132 |
133 | vision_tower = model.get_vision_tower()
134 | if not vision_tower.is_loaded:
135 | vision_tower.load_model(device_map=device_map)
136 | if device_map != 'auto':
137 | vision_tower.to(device=device_map, dtype=torch.float16)
138 | image_processor = vision_tower.image_processor
139 |
140 | if hasattr(model.config, "max_sequence_length"):
141 | context_len = model.config.max_sequence_length
142 | else:
143 | context_len = 2048
144 |
145 | return tokenizer, model, image_processor, context_len
146 |
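147 | 
148 | if __name__ == "__main__":
149 |     # Hedged usage sketch (not part of the original file). The checkpoint path is a
150 |     # placeholder; for LoRA or projector-only checkpoints, also pass `model_base` as
151 |     # handled in the branches above.
152 |     from src.mm_utils import get_model_name_from_path
153 |     model_path = "path/to/setokim-7b"
154 |     tokenizer, model, image_processor, context_len = load_pretrained_model(
155 |         model_path=model_path,
156 |         model_base=None,
157 |         model_name=get_model_name_from_path(model_path),
158 |     )
159 |     print(type(model).__name__, context_len)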
--------------------------------------------------------------------------------
/src/model/consolidate.py:
--------------------------------------------------------------------------------
1 | """
2 | Usage:
3 | python3 -m src.model.consolidate --src ~/model_weights/llava-7b --dst ~/model_weights/llava-7b_consolidate
4 | """
5 | import argparse
6 |
7 | import torch
8 | from transformers import AutoTokenizer, AutoModelForCausalLM
9 | from src.model import *
10 | from src.model.utils import auto_upgrade
11 |
12 |
13 | def consolidate_ckpt(src_path, dst_path):
14 | print("Loading model")
15 | auto_upgrade(src_path)
16 | src_model = AutoModelForCausalLM.from_pretrained(src_path, torch_dtype=torch.float16, low_cpu_mem_usage=True)
17 | src_tokenizer = AutoTokenizer.from_pretrained(src_path, use_fast=False)
18 | src_model.save_pretrained(dst_path)
19 | src_tokenizer.save_pretrained(dst_path)
20 |
21 |
22 | if __name__ == "__main__":
23 | parser = argparse.ArgumentParser()
24 | parser.add_argument("--src", type=str, required=True)
25 | parser.add_argument("--dst", type=str, required=True)
26 |
27 | args = parser.parse_args()
28 |
29 | consolidate_ckpt(args.src, args.dst)
30 |
--------------------------------------------------------------------------------
/src/model/diffusion/__init__.py:
--------------------------------------------------------------------------------
1 | # Adopted from DiT, which is modified from OpenAI's diffusion repos
2 | # DiT: https://github.com/facebookresearch/DiT/diffusion
3 | # GLIDE: https://github.com/openai/glide-text2im/blob/main/glide_text2im/gaussian_diffusion.py
4 | # ADM: https://github.com/openai/guided-diffusion/blob/main/guided_diffusion
5 | # IDDPM: https://github.com/openai/improved-diffusion/blob/main/improved_diffusion/gaussian_diffusion.py
6 |
7 | from . import gaussian_diffusion as gd
8 | from .respace import SpacedDiffusion, space_timesteps
9 |
10 |
11 | def create_diffusion(
12 | timestep_respacing,
13 | noise_schedule="linear",
14 | use_kl=False,
15 | sigma_small=False,
16 | predict_xstart=False,
17 | learn_sigma=True,
18 | rescale_learned_sigmas=False,
19 | diffusion_steps=1000
20 | ):
21 | betas = gd.get_named_beta_schedule(noise_schedule, diffusion_steps)
22 | if use_kl:
23 | loss_type = gd.LossType.RESCALED_KL
24 | elif rescale_learned_sigmas:
25 | loss_type = gd.LossType.RESCALED_MSE
26 | else:
27 | loss_type = gd.LossType.MSE
28 | if timestep_respacing is None or timestep_respacing == "":
29 | timestep_respacing = [diffusion_steps]
30 | return SpacedDiffusion(
31 | use_timesteps=space_timesteps(diffusion_steps, timestep_respacing),
32 | betas=betas,
33 | model_mean_type=(
34 | gd.ModelMeanType.EPSILON if not predict_xstart else gd.ModelMeanType.START_X
35 | ),
36 | model_var_type=(
37 | (
38 | gd.ModelVarType.FIXED_LARGE
39 | if not sigma_small
40 | else gd.ModelVarType.FIXED_SMALL
41 | )
42 | if not learn_sigma
43 | else gd.ModelVarType.LEARNED_RANGE
44 | ),
45 | loss_type=loss_type
46 | # rescale_timesteps=rescale_timesteps,
47 | )
48 |
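49 | 
50 | if __name__ == "__main__":
51 |     # Hedged usage sketch (not part of the original file): the two configurations
52 |     # mirror how DiffLoss (src/model/loss/diffloss.py) calls this factory, once for
53 |     # training and once for respaced sampling. Because this package uses relative
54 |     # imports, run it via an import of src.model.diffusion rather than as a script.
55 |     train_diffusion = create_diffusion(timestep_respacing="", noise_schedule="cosine")
56 |     gen_diffusion = create_diffusion(timestep_respacing="ddim50", noise_schedule="cosine")
57 |     print(train_diffusion.num_timesteps, gen_diffusion.num_timesteps)   # 1000 50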
--------------------------------------------------------------------------------
/src/model/diffusion/diffusion_utils.py:
--------------------------------------------------------------------------------
1 | # Modified from OpenAI's diffusion repos
2 | # GLIDE: https://github.com/openai/glide-text2im/blob/main/glide_text2im/gaussian_diffusion.py
3 | # ADM: https://github.com/openai/guided-diffusion/blob/main/guided_diffusion
4 | # IDDPM: https://github.com/openai/improved-diffusion/blob/main/improved_diffusion/gaussian_diffusion.py
5 |
6 | import torch as th
7 | import numpy as np
8 |
9 |
10 | def normal_kl(mean1, logvar1, mean2, logvar2):
11 | """
12 | Compute the KL divergence between two gaussians.
13 | Shapes are automatically broadcasted, so batches can be compared to
14 | scalars, among other use cases.
15 | """
16 | tensor = None
17 | for obj in (mean1, logvar1, mean2, logvar2):
18 | if isinstance(obj, th.Tensor):
19 | tensor = obj
20 | break
21 | assert tensor is not None, "at least one argument must be a Tensor"
22 |
23 | # Force variances to be Tensors. Broadcasting helps convert scalars to
24 | # Tensors, but it does not work for th.exp().
25 | logvar1, logvar2 = [
26 | x if isinstance(x, th.Tensor) else th.tensor(x).to(tensor)
27 | for x in (logvar1, logvar2)
28 | ]
29 |
30 | return 0.5 * (
31 | -1.0
32 | + logvar2
33 | - logvar1
34 | + th.exp(logvar1 - logvar2)
35 | + ((mean1 - mean2) ** 2) * th.exp(-logvar2)
36 | )
37 |
38 |
39 | def approx_standard_normal_cdf(x):
40 | """
41 | A fast approximation of the cumulative distribution function of the
42 | standard normal.
43 | """
44 | return 0.5 * (1.0 + th.tanh(np.sqrt(2.0 / np.pi) * (x + 0.044715 * th.pow(x, 3))))
45 |
46 |
47 | def discretized_gaussian_log_likelihood(x, *, means, log_scales):
48 | """
49 | Compute the log-likelihood of a Gaussian distribution discretizing to a
50 | given image.
51 | :param x: the target images. It is assumed that this was uint8 values,
52 | rescaled to the range [-1, 1].
53 | :param means: the Gaussian mean Tensor.
54 | :param log_scales: the Gaussian log stddev Tensor.
55 | :return: a tensor like x of log probabilities (in nats).
56 | """
57 | assert x.shape == means.shape == log_scales.shape
58 | centered_x = x - means
59 | inv_stdv = th.exp(-log_scales)
60 | plus_in = inv_stdv * (centered_x + 1.0 / 255.0)
61 | cdf_plus = approx_standard_normal_cdf(plus_in)
62 | min_in = inv_stdv * (centered_x - 1.0 / 255.0)
63 | cdf_min = approx_standard_normal_cdf(min_in)
64 | log_cdf_plus = th.log(cdf_plus.clamp(min=1e-12))
65 | log_one_minus_cdf_min = th.log((1.0 - cdf_min).clamp(min=1e-12))
66 | cdf_delta = cdf_plus - cdf_min
67 | log_probs = th.where(
68 | x < -0.999,
69 | log_cdf_plus,
70 | th.where(x > 0.999, log_one_minus_cdf_min, th.log(cdf_delta.clamp(min=1e-12))),
71 | )
72 | assert log_probs.shape == x.shape
73 | return log_probs
74 |
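75 | 
76 | if __name__ == "__main__":
77 |     # Hedged sanity check (not part of the original file): the KL between two
78 |     # identical Gaussians is zero and the approximate standard-normal CDF is 0.5
79 |     # at x = 0; shapes broadcast exactly as the docstrings above describe.
80 |     mean, logvar = th.zeros(4), th.zeros(4)
81 |     print(normal_kl(mean, logvar, mean, logvar))        # tensor of zeros
82 |     print(approx_standard_normal_cdf(th.zeros(1)))      # ~0.5
83 |     x = th.zeros(2, 3)
84 |     ll = discretized_gaussian_log_likelihood(x, means=th.zeros(2, 3), log_scales=th.zeros(2, 3))
85 |     print(ll.shape)                                     # torch.Size([2, 3])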
--------------------------------------------------------------------------------
/src/model/diffusion/respace.py:
--------------------------------------------------------------------------------
1 | # Modified from OpenAI's diffusion repos
2 | # GLIDE: https://github.com/openai/glide-text2im/blob/main/glide_text2im/gaussian_diffusion.py
3 | # ADM: https://github.com/openai/guided-diffusion/blob/main/guided_diffusion
4 | # IDDPM: https://github.com/openai/improved-diffusion/blob/main/improved_diffusion/gaussian_diffusion.py
5 |
6 | import numpy as np
7 | import torch as th
8 |
9 | from .gaussian_diffusion import GaussianDiffusion
10 |
11 |
12 | def space_timesteps(num_timesteps, section_counts):
13 | """
14 | Create a list of timesteps to use from an original diffusion process,
15 | given the number of timesteps we want to take from equally-sized portions
16 | of the original process.
17 | For example, if there's 300 timesteps and the section counts are [10,15,20]
18 | then the first 100 timesteps are strided to be 10 timesteps, the second 100
19 | are strided to be 15 timesteps, and the final 100 are strided to be 20.
20 | If the stride is a string starting with "ddim", then the fixed striding
21 | from the DDIM paper is used, and only one section is allowed.
22 | :param num_timesteps: the number of diffusion steps in the original
23 | process to divide up.
24 | :param section_counts: either a list of numbers, or a string containing
25 | comma-separated numbers, indicating the step count
26 | per section. As a special case, use "ddimN" where N
27 | is a number of steps to use the striding from the
28 | DDIM paper.
29 | :return: a set of diffusion steps from the original process to use.
30 | """
31 | if isinstance(section_counts, str):
32 | if section_counts.startswith("ddim"):
33 | desired_count = int(section_counts[len("ddim") :])
34 | for i in range(1, num_timesteps):
35 | if len(range(0, num_timesteps, i)) == desired_count:
36 | return set(range(0, num_timesteps, i))
37 | raise ValueError(
38 | f"cannot create exactly {num_timesteps} steps with an integer stride"
39 | )
40 | section_counts = [int(x) for x in section_counts.split(",")]
41 | size_per = num_timesteps // len(section_counts)
42 | extra = num_timesteps % len(section_counts)
43 | start_idx = 0
44 | all_steps = []
45 | for i, section_count in enumerate(section_counts):
46 | size = size_per + (1 if i < extra else 0)
47 | if size < section_count:
48 | raise ValueError(
49 | f"cannot divide section of {size} steps into {section_count}"
50 | )
51 | if section_count <= 1:
52 | frac_stride = 1
53 | else:
54 | frac_stride = (size - 1) / (section_count - 1)
55 | cur_idx = 0.0
56 | taken_steps = []
57 | for _ in range(section_count):
58 | taken_steps.append(start_idx + round(cur_idx))
59 | cur_idx += frac_stride
60 | all_steps += taken_steps
61 | start_idx += size
62 | return set(all_steps)
63 |
64 |
65 | class SpacedDiffusion(GaussianDiffusion):
66 | """
67 | A diffusion process which can skip steps in a base diffusion process.
68 | :param use_timesteps: a collection (sequence or set) of timesteps from the
69 | original diffusion process to retain.
70 | :param kwargs: the kwargs to create the base diffusion process.
71 | """
72 |
73 | def __init__(self, use_timesteps, **kwargs):
74 | self.use_timesteps = set(use_timesteps)
75 | self.timestep_map = []
76 | self.original_num_steps = len(kwargs["betas"])
77 |
78 | base_diffusion = GaussianDiffusion(**kwargs) # pylint: disable=missing-kwoa
79 | last_alpha_cumprod = 1.0
80 | new_betas = []
81 | for i, alpha_cumprod in enumerate(base_diffusion.alphas_cumprod):
82 | if i in self.use_timesteps:
83 | new_betas.append(1 - alpha_cumprod / last_alpha_cumprod)
84 | last_alpha_cumprod = alpha_cumprod
85 | self.timestep_map.append(i)
86 | kwargs["betas"] = np.array(new_betas)
87 | super().__init__(**kwargs)
88 |
89 | def p_mean_variance(
90 | self, model, *args, **kwargs
91 | ): # pylint: disable=signature-differs
92 | return super().p_mean_variance(self._wrap_model(model), *args, **kwargs)
93 |
94 | def training_losses(
95 | self, model, *args, **kwargs
96 | ): # pylint: disable=signature-differs
97 | return super().training_losses(self._wrap_model(model), *args, **kwargs)
98 |
99 | def condition_mean(self, cond_fn, *args, **kwargs):
100 | return super().condition_mean(self._wrap_model(cond_fn), *args, **kwargs)
101 |
102 | def condition_score(self, cond_fn, *args, **kwargs):
103 | return super().condition_score(self._wrap_model(cond_fn), *args, **kwargs)
104 |
105 | def _wrap_model(self, model):
106 | if isinstance(model, _WrappedModel):
107 | return model
108 | return _WrappedModel(
109 | model, self.timestep_map, self.original_num_steps
110 | )
111 |
112 | def _scale_timesteps(self, t):
113 | # Scaling is done by the wrapped model.
114 | return t
115 |
116 |
117 | class _WrappedModel:
118 | def __init__(self, model, timestep_map, original_num_steps):
119 | self.model = model
120 | self.timestep_map = timestep_map
121 | # self.rescale_timesteps = rescale_timesteps
122 | self.original_num_steps = original_num_steps
123 |
124 | def __call__(self, x, ts, **kwargs):
125 | map_tensor = th.tensor(self.timestep_map, device=ts.device, dtype=ts.dtype)
126 | new_ts = map_tensor[ts]
127 | # if self.rescale_timesteps:
128 | # new_ts = new_ts.float() * (1000.0 / self.original_num_steps)
129 | return self.model(x, new_ts, **kwargs)
130 |
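131 | 
132 | if __name__ == "__main__":
133 |     # Hedged usage sketch (not part of the original file), reproducing the example
134 |     # from the space_timesteps docstring plus the "ddimN" shorthand. Run it via an
135 |     # import of the package (this file itself uses relative imports).
136 |     steps = space_timesteps(300, [10, 15, 20])
137 |     print(len(steps))                                         # 45 retained timesteps
138 |     ddim_steps = space_timesteps(1000, "ddim100")
139 |     print(len(ddim_steps), min(ddim_steps), max(ddim_steps))  # 100 0 990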
--------------------------------------------------------------------------------
/src/model/loss/__init__.py:
--------------------------------------------------------------------------------
1 |
2 | from .diffloss import DiffLoss
3 | from .mse import WeightedMSELoss
4 | from .perceptual import LPIPS
5 | from .discriminator import GANLoss
6 | from .multilabel_constrastive import MultilabelContrastiveLoss
--------------------------------------------------------------------------------
/src/model/loss/diffloss.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | from torch.utils.checkpoint import checkpoint
4 | import math
5 |
6 | from ..diffusion import create_diffusion
7 |
8 |
9 | class DiffLoss(nn.Module):
10 | """Diffusion Loss"""
11 | def __init__(self, target_channels, z_channels, depth, width, num_sampling_steps, grad_checkpointing=False):
12 | super(DiffLoss, self).__init__()
13 | self.in_channels = target_channels
14 | self.net = SimpleMLPAdaLN(
15 | in_channels=target_channels,
16 | model_channels=width,
17 | out_channels=target_channels * 2, # for vlb loss
18 | z_channels=z_channels,
19 | num_res_blocks=depth,
20 | grad_checkpointing=grad_checkpointing
21 | )
22 |
23 | self.train_diffusion = create_diffusion(timestep_respacing="", noise_schedule="cosine")
24 | self.gen_diffusion = create_diffusion(timestep_respacing=num_sampling_steps, noise_schedule="cosine")
25 |
26 | def forward(self, target, z, mask=None):
27 | t = torch.randint(0, self.train_diffusion.num_timesteps, (target.shape[0],), device=target.device)
28 | model_kwargs = dict(c=z)
29 | loss_dict = self.train_diffusion.training_losses(self.net, target, t, model_kwargs)
30 | loss = loss_dict["loss"]
31 | if mask is not None:
32 | loss = (loss * mask).sum() / mask.sum()
33 | return loss.mean()
34 |
35 | def sample(self, z, temperature=1.0, cfg=1.0):
36 | # diffusion loss sampling
37 | if not cfg == 1.0:
38 | noise = torch.randn(z.shape[0] // 2, self.in_channels).cuda()
39 | noise = torch.cat([noise, noise], dim=0)
40 | model_kwargs = dict(c=z, cfg_scale=cfg)
41 | sample_fn = self.net.forward_with_cfg
42 | else:
43 | noise = torch.randn(z.shape[0], self.in_channels).cuda()
44 | model_kwargs = dict(c=z)
45 | sample_fn = self.net.forward
46 |
47 | sampled_token_latent = self.gen_diffusion.p_sample_loop(
48 | sample_fn, noise.shape, noise, clip_denoised=False, model_kwargs=model_kwargs, progress=False,
49 | temperature=temperature
50 | )
51 |
52 | return sampled_token_latent
53 |
54 |
55 | def modulate(x, shift, scale):
56 | return x * (1 + scale) + shift
57 |
58 |
59 | class TimestepEmbedder(nn.Module):
60 | """
61 | Embeds scalar timesteps into vector representations.
62 | """
63 | def __init__(self, hidden_size, frequency_embedding_size=256):
64 | super().__init__()
65 | self.mlp = nn.Sequential(
66 | nn.Linear(frequency_embedding_size, hidden_size, bias=True),
67 | nn.SiLU(),
68 | nn.Linear(hidden_size, hidden_size, bias=True),
69 | )
70 | self.frequency_embedding_size = frequency_embedding_size
71 |
72 | @staticmethod
73 | def timestep_embedding(t, dim, max_period=10000):
74 | """
75 | Create sinusoidal timestep embeddings.
76 | :param t: a 1-D Tensor of N indices, one per batch element.
77 | These may be fractional.
78 | :param dim: the dimension of the output.
79 | :param max_period: controls the minimum frequency of the embeddings.
80 | :return: an (N, D) Tensor of positional embeddings.
81 | """
82 | # https://github.com/openai/glide-text2im/blob/main/glide_text2im/nn.py
83 | half = dim // 2
84 | freqs = torch.exp(
85 | -math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half
86 | ).to(device=t.device)
87 | args = t[:, None].float() * freqs[None]
88 | embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
89 | if dim % 2:
90 | embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
91 | return embedding
92 |
93 | def forward(self, t):
94 | t_freq = self.timestep_embedding(t, self.frequency_embedding_size)
95 | t_emb = self.mlp(t_freq)
96 | return t_emb
97 |
98 |
99 | class ResBlock(nn.Module):
100 | """
101 | A residual block that can optionally change the number of channels.
102 | :param channels: the number of input channels.
103 | """
104 |
105 | def __init__(
106 | self,
107 | channels
108 | ):
109 | super().__init__()
110 | self.channels = channels
111 |
112 | self.in_ln = nn.LayerNorm(channels, eps=1e-6)
113 | self.mlp = nn.Sequential(
114 | nn.Linear(channels, channels, bias=True),
115 | nn.SiLU(),
116 | nn.Linear(channels, channels, bias=True),
117 | )
118 |
119 | self.adaLN_modulation = nn.Sequential(
120 | nn.SiLU(),
121 | nn.Linear(channels, 3 * channels, bias=True)
122 | )
123 |
124 | def forward(self, x, y):
125 | shift_mlp, scale_mlp, gate_mlp = self.adaLN_modulation(y).chunk(3, dim=-1)
126 | h = modulate(self.in_ln(x), shift_mlp, scale_mlp)
127 | h = self.mlp(h)
128 | return x + gate_mlp * h
129 |
130 |
131 | class FinalLayer(nn.Module):
132 | """
133 | The final layer adopted from DiT.
134 | """
135 | def __init__(self, model_channels, out_channels):
136 | super().__init__()
137 | self.norm_final = nn.LayerNorm(model_channels, elementwise_affine=False, eps=1e-6)
138 | self.linear = nn.Linear(model_channels, out_channels, bias=True)
139 | self.adaLN_modulation = nn.Sequential(
140 | nn.SiLU(),
141 | nn.Linear(model_channels, 2 * model_channels, bias=True)
142 | )
143 |
144 | def forward(self, x, c):
145 | shift, scale = self.adaLN_modulation(c).chunk(2, dim=-1)
146 | x = modulate(self.norm_final(x), shift, scale)
147 | x = self.linear(x)
148 | return x
149 |
150 |
151 | class SimpleMLPAdaLN(nn.Module):
152 | """
153 | The MLP for Diffusion Loss.
154 | :param in_channels: channels in the input Tensor.
155 | :param model_channels: base channel count for the model.
156 | :param out_channels: channels in the output Tensor.
157 | :param z_channels: channels in the condition.
158 | :param num_res_blocks: number of residual blocks per downsample.
159 | """
160 |
161 | def __init__(
162 | self,
163 | in_channels,
164 | model_channels,
165 | out_channels,
166 | z_channels,
167 | num_res_blocks,
168 | grad_checkpointing=False
169 | ):
170 | super().__init__()
171 |
172 | self.in_channels = in_channels
173 | self.model_channels = model_channels
174 | self.out_channels = out_channels
175 | self.num_res_blocks = num_res_blocks
176 | self.grad_checkpointing = grad_checkpointing
177 |
178 | self.time_embed = TimestepEmbedder(model_channels)
179 | self.cond_embed = nn.Linear(z_channels, model_channels)
180 |
181 | self.input_proj = nn.Linear(in_channels, model_channels)
182 |
183 | res_blocks = []
184 | for i in range(num_res_blocks):
185 | res_blocks.append(ResBlock(
186 | model_channels,
187 | ))
188 |
189 | self.res_blocks = nn.ModuleList(res_blocks)
190 | self.final_layer = FinalLayer(model_channels, out_channels)
191 |
192 | self.initialize_weights()
193 |
194 | def initialize_weights(self):
195 | def _basic_init(module):
196 | if isinstance(module, nn.Linear):
197 | torch.nn.init.xavier_uniform_(module.weight)
198 | if module.bias is not None:
199 | nn.init.constant_(module.bias, 0)
200 | self.apply(_basic_init)
201 |
202 | # Initialize timestep embedding MLP
203 | nn.init.normal_(self.time_embed.mlp[0].weight, std=0.02)
204 | nn.init.normal_(self.time_embed.mlp[2].weight, std=0.02)
205 |
206 | # Zero-out adaLN modulation layers
207 | for block in self.res_blocks:
208 | nn.init.constant_(block.adaLN_modulation[-1].weight, 0)
209 | nn.init.constant_(block.adaLN_modulation[-1].bias, 0)
210 |
211 | # Zero-out output layers
212 | nn.init.constant_(self.final_layer.adaLN_modulation[-1].weight, 0)
213 | nn.init.constant_(self.final_layer.adaLN_modulation[-1].bias, 0)
214 | nn.init.constant_(self.final_layer.linear.weight, 0)
215 | nn.init.constant_(self.final_layer.linear.bias, 0)
216 |
217 | def forward(self, x, t, c):
218 | """
219 | Apply the model to an input batch.
220 | :param x: an [N x C] Tensor of inputs.
221 | :param t: a 1-D batch of timesteps.
222 | :param c: conditioning from AR transformer.
223 | :return: an [N x C] Tensor of outputs.
224 | """
225 | x = self.input_proj(x)
226 | t = self.time_embed(t)
227 | c = self.cond_embed(c)
228 |
229 | y = t + c
230 |
231 | if self.grad_checkpointing and not torch.jit.is_scripting():
232 | for block in self.res_blocks:
233 | x = checkpoint(block, x, y)
234 | else:
235 | for block in self.res_blocks:
236 | x = block(x, y)
237 |
238 | return self.final_layer(x, y)
239 |
240 | def forward_with_cfg(self, x, t, c, cfg_scale):
241 | half = x[: len(x) // 2]
242 | combined = torch.cat([half, half], dim=0)
243 | model_out = self.forward(combined, t, c)
244 | eps, rest = model_out[:, :self.in_channels], model_out[:, self.in_channels:]
245 | cond_eps, uncond_eps = torch.split(eps, len(eps) // 2, dim=0)
246 | half_eps = uncond_eps + cfg_scale * (cond_eps - uncond_eps)
247 | eps = torch.cat([half_eps, half_eps], dim=0)
248 | return torch.cat([eps, rest], dim=1)
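249 | 
250 | 
251 | if __name__ == "__main__":
252 |     # Hedged shape check (not part of the original file); all sizes are arbitrary toy
253 |     # values, not a training configuration. DiffLoss wraps this MLP with
254 |     # create_diffusion() for the actual objective; run via an import of the package,
255 |     # since this module uses relative imports.
256 |     net = SimpleMLPAdaLN(in_channels=16, model_channels=64, out_channels=32,
257 |                          z_channels=24, num_res_blocks=2)
258 |     x = torch.randn(4, 16)               # noisy token latents
259 |     t = torch.randint(0, 1000, (4,))     # diffusion timesteps
260 |     c = torch.randn(4, 24)               # conditioning vector
261 |     print(net(x, t, c).shape)            # torch.Size([4, 32])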
--------------------------------------------------------------------------------
/src/model/loss/mse.py:
--------------------------------------------------------------------------------
1 | from torch import nn
2 |
3 | class WeightedMSELoss(nn.Module):
4 | def __init__(self, weight=1.0):
5 | super().__init__()
6 | self.mse = nn.MSELoss(reduction='none')
7 | self.weight = weight
8 |
9 | def forward(self, pred, target, loss_mask=None, loss_weight=None):
10 | mse_loss = self.mse(pred, target)
11 | if loss_mask is not None:
12 | mse_loss = (mse_loss * loss_mask)
13 | weight = loss_mask.sum([-2, -1])
14 | weight += 1
15 | mse_loss = mse_loss.sum([-2, -1])
16 | mse_loss = (mse_loss / weight)
17 | else:
18 | mse_loss = mse_loss.mean([-3, -2, -1])
19 | return mse_loss.mean() * self.weight
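20 | 
21 | 
22 | if __name__ == "__main__":
23 |     # Hedged usage sketch (not part of the original file): toy (batch, channel,
24 |     # height, width) tensors; the optional mask restricts the loss to valid pixels.
25 |     import torch
26 |     pred, target = torch.randn(2, 3, 8, 8), torch.randn(2, 3, 8, 8)
27 |     mask = torch.ones(2, 3, 8, 8)
28 |     loss_fn = WeightedMSELoss(weight=1.0)
29 |     print(loss_fn(pred, target), loss_fn(pred, target, loss_mask=mask))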
--------------------------------------------------------------------------------
/src/model/loss/multilabel_constrastive.py:
--------------------------------------------------------------------------------
1 |
2 | import torch
3 | import torch.nn as nn
4 | import torch.nn.functional as F
5 | from einops import rearrange, repeat
6 | import torch.distributed as dist
7 | import numpy as np
8 | from typing import Optional, Dict, List
9 | from timm.loss import SoftTargetCrossEntropy
10 | import diffdist.functional as diff_dist
11 |
12 |
13 |
14 | def dist_collect(x):
15 | """ collect all tensor from all GPUs
16 | args:
17 | x: shape (mini_batch, ...)
18 | returns:
19 | shape (mini_batch * num_gpu, ...)
20 | """
21 | x = x.contiguous()
22 | out_list = [torch.zeros_like(x, device=x.device, dtype=x.dtype).contiguous() for _ in range(dist.get_world_size())]
23 | out_list = diff_dist.all_gather(out_list, x)
24 | return torch.cat(out_list, dim=0).contiguous()
25 |
26 |
27 | class MultilabelContrastiveLoss(nn.Module):
28 | def __init__(self,
29 | text_encoder: nn.Module,
30 | contrast_temperature: Optional[float]=0.07,
31 | multi_label: Optional[int]=0,
32 | share_temperature: Optional[bool]=False,
33 | multi_label_loss_weight: Optional[float]=1.0,
34 | **kwargs) -> None:
35 |         super().__init__()
36 |
37 | self.text_encoder = text_encoder
38 | self.contrast_temperature = contrast_temperature
39 | self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / contrast_temperature))
40 | self.cross_entropy = nn.CrossEntropyLoss()
41 | self.soft_cross_entropy = SoftTargetCrossEntropy()
42 |
43 | self.multi_label = multi_label
44 | self.share_temperature = share_temperature
45 | if self.with_multi_label and not self.share_temperature:
46 | self.multi_label_logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / contrast_temperature))
47 | self.multi_label_loss_weight = multi_label_loss_weight
48 |
49 | @property
50 | def with_multi_label(self):
51 | return self.multi_label > 0
52 |
53 |
54 | def loss(self, image_x, text_x):
55 |
56 | batch_size = image_x.shape[0]
57 | # get label globally
58 | labels = torch.arange(batch_size, dtype=torch.long, device=image_x.device) + batch_size * dist.get_rank()
59 |
60 | # [B, C]
61 | image_x = F.normalize(image_x, dim=-1)
62 | text_x = F.normalize(text_x, dim=-1)
63 |
64 | logits_per_img = image_x @ dist_collect(text_x).t()
65 | logits_per_text = text_x @ dist_collect(image_x).t()
66 |
67 | logit_scale = torch.clamp(self.logit_scale.exp(), max=100)
68 | loss_img = self.cross_entropy(logits_per_img * logit_scale, labels)
69 | loss_text = self.cross_entropy(logits_per_text * logit_scale, labels)
70 |
71 | loss = 0.5 * (loss_img + loss_text)
72 |
73 | return loss
74 |
75 | def multi_label_loss(self, image_feat, text_feat):
76 | """
77 |
78 | Args:
79 | image_feat (torch.Tensor): shape [B, L1, C]
80 | text_feat (torch.Tensor): shape [B, L2, C]
81 |
82 | Returns:
83 |
84 | """
85 | # [B, L1, C], L1 = 1
86 | image_feat = F.normalize(image_feat, dim=-1)
87 | # [B, L2, C]
88 | text_feat = F.normalize(text_feat, dim=-1)
89 |
90 | # [B, L1, L2]
91 | dist_per_img = image_feat @ rearrange(text_feat, 'b l c -> b c l')
92 | # [B, L2, L1]
93 | dist_per_text = text_feat @ rearrange(image_feat, 'b l c -> b c l')
94 |
95 | if self.share_temperature:
96 | logit_scale = torch.clamp(self.logit_scale.exp(), max=100)
97 | else:
98 | logit_scale = torch.clamp(self.multi_label_logit_scale.exp(), max=100)
99 |
100 | batch = image_feat.shape[0]
101 | img_len = image_feat.shape[1]
102 | text_len = text_feat.shape[1]
103 | # [B, L1, L2]
104 | pos_labels_batch_img = rearrange(torch.ones_like(dist_per_text) / dist_per_text.size(1), 'b l2 l1 -> b l1 l2')
105 | # [B, L2, L1]
106 | pos_labels_batch_text = rearrange(torch.ones_like(dist_per_img) / dist_per_img.size(1), 'b l1 l2 -> b l2 l1')
107 |
108 | image_x = rearrange(image_feat, 'b l c -> (b l) c')
109 | text_x = rearrange(text_feat, 'b l c -> (b l) c')
110 |
111 | logits_per_img = image_x @ dist_collect(text_x).t()
112 | logits_per_text = text_x @ dist_collect(image_x).t()
113 |
114 | # get label globally
115 | # [B, L1, B, L2, W]
116 | labels_per_img = F.one_hot(
117 | torch.ones(batch, img_len, batch, text_len, dtype=torch.long, device=image_x.device) * dist.get_rank(),
118 | num_classes=dist.get_world_size()).to(image_x.dtype)
119 | labels_per_img *= rearrange(pos_labels_batch_img, 'b l1 l2 -> b l1 1 l2 1') * repeat(
120 | torch.eye(batch, dtype=image_x.dtype, device=image_x.device), 'b1 b2 -> b1 1 b2 1 1')
121 | # [BxL1, WxBxL2]
122 | labels_per_img = rearrange(labels_per_img, 'b1 l1 b2 l2 w -> (b1 l1) (w b2 l2)')
123 | # [B, L2, B, L1, W]
124 | labels_per_text = F.one_hot(
125 | torch.ones(batch, text_len, batch, img_len, dtype=torch.long, device=text_x.device) * dist.get_rank(),
126 | num_classes=dist.get_world_size()).to(text_x.dtype)
127 | labels_per_text *= rearrange(pos_labels_batch_text, 'b l2 l1 -> b l2 1 l1 1') * repeat(
128 | torch.eye(batch, dtype=text_x.dtype, device=image_x.device), 'b2 b1 -> b2 1 b1 1 1')
129 | # [BxL2, WxBxL1]
130 | labels_per_text = rearrange(labels_per_text, 'b2 l2 b1 l1 w -> (b2 l2) (w b1 l1)')
131 |
132 | loss_img = self.soft_cross_entropy(logits_per_img * logit_scale, labels_per_img)
133 | loss_text = self.soft_cross_entropy(logits_per_text * logit_scale, labels_per_text)
134 |
135 | loss = 0.5 * (loss_img + loss_text)
136 |
137 | return loss
138 |
139 |
140 | def forward(self, image_x, text_x):
141 | losses = self.loss(image_x, text_x)
142 | text_outs = self.text_encoder(text_x)
143 | losses_dict = dict(loss=losses.detach().item())
144 | if self.with_multi_label:
145 | image_multi_label_x = image_x.unsqueeze(1)
146 | text_multi_label_x = text_outs.unsqueeze(1)
147 | multi_label_loss = self.multi_label_loss(image_multi_label_x,
148 | text_multi_label_x) * self.multi_label_loss_weight
149 | losses += multi_label_loss
150 | losses_dict.update({
151 | "multi_label_loss": multi_label_loss.detach().item()
152 | })
153 |         return losses, losses_dict
--------------------------------------------------------------------------------
/src/model/loss/perceptual.py:
--------------------------------------------------------------------------------
1 | """Stripped version of https://github.com/richzhang/PerceptualSimilarity/tree/master/models"""
2 |
3 | import torch
4 | import torch.nn as nn
5 | from torchvision import models
6 | from collections import namedtuple
7 | import os, hashlib
8 | import requests
9 | from tqdm import tqdm
10 | # from taming.util import get_ckpt_path
11 |
12 | URL_MAP = {
13 | "vgg_lpips": "https://heibox.uni-heidelberg.de/f/607503859c864bc1b30b/?dl=1"
14 | }
15 |
16 | CKPT_MAP = {
17 | "vgg_lpips": "vgg.pth"
18 | }
19 |
20 | MD5_MAP = {
21 | "vgg_lpips": "d507d7349b931f0638a25a48a722f98a"
22 | }
23 |
24 |
25 | def download(url, local_path, chunk_size=1024):
26 | os.makedirs(os.path.split(local_path)[0], exist_ok=True)
27 | with requests.get(url, stream=True) as r:
28 | total_size = int(r.headers.get("content-length", 0))
29 | with tqdm(total=total_size, unit="B", unit_scale=True) as pbar:
30 | with open(local_path, "wb") as f:
31 | for data in r.iter_content(chunk_size=chunk_size):
32 | if data:
33 | f.write(data)
34 | pbar.update(chunk_size)
35 |
36 |
37 | def md5_hash(path):
38 | with open(path, "rb") as f:
39 | content = f.read()
40 | return hashlib.md5(content).hexdigest()
41 |
42 |
43 |
44 | def get_ckpt_path(name, root, check=False):
45 | assert name in URL_MAP
46 | path = os.path.join(root, CKPT_MAP[name])
47 | if not os.path.exists(path) or (check and not md5_hash(path) == MD5_MAP[name]):
48 | print("Downloading {} model from {} to {}".format(name, URL_MAP[name], path))
49 | download(URL_MAP[name], path)
50 | md5 = md5_hash(path)
51 | assert md5 == MD5_MAP[name], md5
52 | return path
53 |
54 |
55 | class LPIPS(nn.Module):
56 | # Learned perceptual metric
57 | def __init__(self, use_dropout=True):
58 | super().__init__()
59 | self.scaling_layer = ScalingLayer()
60 |         self.chns = [64, 128, 256, 512, 512]  # vgg16 features
61 | self.net = vgg16(pretrained=True, requires_grad=False)
62 | self.lin0 = NetLinLayer(self.chns[0], use_dropout=use_dropout)
63 | self.lin1 = NetLinLayer(self.chns[1], use_dropout=use_dropout)
64 | self.lin2 = NetLinLayer(self.chns[2], use_dropout=use_dropout)
65 | self.lin3 = NetLinLayer(self.chns[3], use_dropout=use_dropout)
66 | self.lin4 = NetLinLayer(self.chns[4], use_dropout=use_dropout)
67 | self.load_from_pretrained()
68 | for param in self.parameters():
69 | param.requires_grad = False
70 |
71 | def load_from_pretrained(self, name="vgg_lpips"):
72 | ckpt = get_ckpt_path(name, "taming/modules/autoencoder/lpips")
73 | self.load_state_dict(torch.load(ckpt, map_location=torch.device("cpu")), strict=False)
74 | print("loaded pretrained LPIPS loss from {}".format(ckpt))
75 |
76 | @classmethod
77 | def from_pretrained(cls, name="vgg_lpips"):
78 | if name != "vgg_lpips":
79 | raise NotImplementedError
80 | model = cls()
81 |         ckpt = get_ckpt_path(name, "taming/modules/autoencoder/lpips")
82 | model.load_state_dict(torch.load(ckpt, map_location=torch.device("cpu")), strict=False)
83 | return model
84 |
85 | def forward(self, input, target):
86 | in0_input, in1_input = (self.scaling_layer(input), self.scaling_layer(target))
87 | outs0, outs1 = self.net(in0_input), self.net(in1_input)
88 | feats0, feats1, diffs = {}, {}, {}
89 | lins = [self.lin0, self.lin1, self.lin2, self.lin3, self.lin4]
90 | for kk in range(len(self.chns)):
91 | feats0[kk], feats1[kk] = normalize_tensor(outs0[kk]), normalize_tensor(outs1[kk])
92 | diffs[kk] = (feats0[kk] - feats1[kk]) ** 2
93 |
94 | res = [spatial_average(lins[kk].model(diffs[kk]), keepdim=True) for kk in range(len(self.chns))]
95 | val = res[0]
96 | for l in range(1, len(self.chns)):
97 | val += res[l]
98 | return val
99 |
100 |
101 | class ScalingLayer(nn.Module):
102 | def __init__(self):
103 | super(ScalingLayer, self).__init__()
104 | self.register_buffer('shift', torch.Tensor([-.030, -.088, -.188])[None, :, None, None])
105 | self.register_buffer('scale', torch.Tensor([.458, .448, .450])[None, :, None, None])
106 |
107 | def forward(self, inp):
108 | return (inp - self.shift) / self.scale
109 |
110 |
111 | class NetLinLayer(nn.Module):
112 | """ A single linear layer which does a 1x1 conv """
113 | def __init__(self, chn_in, chn_out=1, use_dropout=False):
114 | super(NetLinLayer, self).__init__()
115 | layers = [nn.Dropout(), ] if (use_dropout) else []
116 | layers += [nn.Conv2d(chn_in, chn_out, 1, stride=1, padding=0, bias=False), ]
117 | self.model = nn.Sequential(*layers)
118 |
119 |
120 | class vgg16(torch.nn.Module):
121 | def __init__(self, requires_grad=False, pretrained=True):
122 | super(vgg16, self).__init__()
123 | vgg_pretrained_features = models.vgg16(pretrained=pretrained).features
124 | self.slice1 = torch.nn.Sequential()
125 | self.slice2 = torch.nn.Sequential()
126 | self.slice3 = torch.nn.Sequential()
127 | self.slice4 = torch.nn.Sequential()
128 | self.slice5 = torch.nn.Sequential()
129 | self.N_slices = 5
130 | for x in range(4):
131 | self.slice1.add_module(str(x), vgg_pretrained_features[x])
132 | for x in range(4, 9):
133 | self.slice2.add_module(str(x), vgg_pretrained_features[x])
134 | for x in range(9, 16):
135 | self.slice3.add_module(str(x), vgg_pretrained_features[x])
136 | for x in range(16, 23):
137 | self.slice4.add_module(str(x), vgg_pretrained_features[x])
138 | for x in range(23, 30):
139 | self.slice5.add_module(str(x), vgg_pretrained_features[x])
140 | if not requires_grad:
141 | for param in self.parameters():
142 | param.requires_grad = False
143 |
144 | def forward(self, X):
145 | h = self.slice1(X)
146 | h_relu1_2 = h
147 | h = self.slice2(h)
148 | h_relu2_2 = h
149 | h = self.slice3(h)
150 | h_relu3_3 = h
151 | h = self.slice4(h)
152 | h_relu4_3 = h
153 | h = self.slice5(h)
154 | h_relu5_3 = h
155 | vgg_outputs = namedtuple("VggOutputs", ['relu1_2', 'relu2_2', 'relu3_3', 'relu4_3', 'relu5_3'])
156 | out = vgg_outputs(h_relu1_2, h_relu2_2, h_relu3_3, h_relu4_3, h_relu5_3)
157 | return out
158 |
159 |
160 | def normalize_tensor(x,eps=1e-10):
161 | norm_factor = torch.sqrt(torch.sum(x**2,dim=1,keepdim=True))
162 | return x/(norm_factor+eps)
163 |
164 |
165 | def spatial_average(x, keepdim=True):
166 | return x.mean([2,3],keepdim=keepdim)
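167 | 
168 | 
169 | if __name__ == "__main__":
170 |     # Hedged usage sketch (not part of the original file). Building LPIPS downloads
171 |     # the torchvision VGG16 weights and the vgg_lpips checkpoint above, so the first
172 |     # run needs network access. Inputs are expected in [-1, 1].
173 |     lpips = LPIPS().eval()
174 |     a = torch.rand(1, 3, 64, 64) * 2 - 1
175 |     b = torch.rand(1, 3, 64, 64) * 2 - 1
176 |     with torch.no_grad():
177 |         print(lpips(a, b).shape)   # (1, 1, 1, 1): one perceptual distance per image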
--------------------------------------------------------------------------------
/src/model/loss/segmentation.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 |
5 | ALPHA = 0.8
6 | GAMMA = 2
7 |
8 |
9 | class BCELoss(nn.Module):
10 | def forward(self, prediction, target):
11 | loss = F.binary_cross_entropy_with_logits(prediction,target)
12 | return loss, {}
13 |
14 |
15 | class BCELossWithQuant(nn.Module):
16 | def __init__(self, codebook_weight=1.):
17 | super().__init__()
18 | self.codebook_weight = codebook_weight
19 |
20 | def forward(self, qloss, target, prediction, split):
21 | bce_loss = F.binary_cross_entropy_with_logits(prediction,target)
22 | loss = bce_loss + self.codebook_weight*qloss
23 | return loss, {"{}/total_loss".format(split): loss.clone().detach().mean(),
24 | "{}/bce_loss".format(split): bce_loss.detach().mean(),
25 | "{}/quant_loss".format(split): qloss.detach().mean()
26 | }
27 |
28 | """Stripped version of https://github.com/NExT-ChatV/NExT-Chat/blob/main/mllm/models/sam/sam_loss.py"""
29 |
30 | class FocalLoss(nn.Module):
31 |
32 | def __init__(self, weight=None, size_average=True):
33 | super().__init__()
34 |
35 | def forward(self, inputs, targets, alpha=ALPHA, gamma=GAMMA, smooth=1):
36 | # inputs = F.sigmoid(inputs)
37 | # inputs = torch.clamp(inputs, min=0, max=1)
38 | #flatten label and prediction tensors
39 | inputs = inputs.view(-1)
40 | targets = targets.view(-1)
41 | BCE = F.binary_cross_entropy_with_logits(inputs, targets, reduction='mean')
42 | BCE_EXP = torch.exp(-BCE)
43 | focal_loss = alpha * (1 - BCE_EXP)**gamma * BCE
44 |
45 | return focal_loss
46 |
47 |
48 | class DiceLoss(nn.Module):
49 |
50 | def __init__(self, weight=None, size_average=True):
51 | super().__init__()
52 |
53 | def forward(self, inputs, targets, smooth=1):
54 | inputs = F.sigmoid(inputs)
55 | inputs = torch.clamp(inputs, min=0, max=1)
56 | #flatten label and prediction tensors
57 | inputs = inputs.view(-1)
58 | targets = targets.view(-1)
59 |
60 | intersection = (inputs * targets).sum()
61 | dice = (2. * intersection + smooth) / (inputs.sum() + targets.sum() + smooth)
62 |
63 | return 1 - dice
64 |
65 |
66 | def calc_iou(pred_mask: torch.Tensor, gt_mask: torch.Tensor):
67 | pred_mask = (pred_mask >= 0.5)
68 | intersection = torch.sum(torch.mul(pred_mask, gt_mask), dim=(1, 2))
69 | union = torch.sum(pred_mask, dim=(1, 2)) + torch.sum(gt_mask, dim=(1, 2)) - intersection
70 | epsilon = 1e-7
71 | batch_iou = intersection / (union + epsilon)
72 |
73 | # batch_iou = batch_iou.unsqueeze(1)
74 | return batch_iou
75 |
76 |
77 | class SamLoss(nn.Module):
78 | def __init__(self):
79 | super().__init__()
80 | self.focal_loss = FocalLoss()
81 | self.dice_loss = DiceLoss()
82 |
83 | def forward(self, pred_masks, gt_masks, iou_predictions, device):
84 | loss_focal = 0.
85 | loss_dice = 0.
86 | loss_iou = 0.
87 | num_masks = sum(len(pred_mask) for pred_mask in pred_masks)
88 | for pred_mask, gt_mask, iou_prediction in zip(pred_masks, gt_masks, iou_predictions):
89 | gt_mask = gt_mask.to(device)
90 | batch_iou = calc_iou(pred_mask, gt_mask)
91 | loss_focal += self.focal_loss(pred_mask, gt_mask, num_masks)
92 | loss_dice += self.dice_loss(pred_mask, gt_mask, num_masks)
93 | loss_iou += F.mse_loss(iou_prediction, batch_iou, reduction='sum') / num_masks
94 |
95 | loss_total = 20. * loss_focal + loss_dice + loss_iou
96 | return loss_total
--------------------------------------------------------------------------------
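
For orientation, a minimal usage sketch of SamLoss with assumed shapes: one image with three predicted mask logits of size 256x256, matching binary ground-truth masks, and a per-mask IoU prediction. All tensor values below are illustrative:

import torch

criterion = SamLoss()
pred_masks = [torch.randn(3, 256, 256)]                      # mask logits for one image
gt_masks = [torch.randint(0, 2, (3, 256, 256)).float()]      # binary ground-truth masks
iou_predictions = [torch.rand(3)]                            # predicted IoU per mask
loss = criterion(pred_masks, gt_masks, iou_predictions, device="cpu")
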
/src/model/make_delta.py:
--------------------------------------------------------------------------------
1 | """
2 | Usage:
3 | python3 -m src.model.make_delta --base ~/model_weights/llama-7b --target ~/model_weights/llava-7b --delta ~/model_weights/llava-7b-delta --hub-repo-id liuhaotian/llava-7b-delta
4 | """
5 | import argparse
6 |
7 | import torch
8 | from tqdm import tqdm
9 | from transformers import AutoTokenizer, AutoModelForCausalLM
10 | from src.model.utils import auto_upgrade
11 |
12 |
13 | def make_delta(base_model_path, target_model_path, delta_path, hub_repo_id):
14 | print("Loading base model")
15 | base = AutoModelForCausalLM.from_pretrained(
16 | base_model_path, torch_dtype=torch.float16, low_cpu_mem_usage=True)
17 |
18 | print("Loading target model")
19 | auto_upgrade(target_model_path)
20 | target = AutoModelForCausalLM.from_pretrained(target_model_path, torch_dtype=torch.float16, low_cpu_mem_usage=True)
21 |
22 | print("Calculating delta")
23 | for name, param in tqdm(target.state_dict().items(), desc="Calculating delta"):
24 | if name not in base.state_dict():
25 | assert name in ['model.mm_projector.weight', 'model.mm_projector.bias'], f'{name} not in base model'
26 | continue
27 | if param.data.shape == base.state_dict()[name].shape:
28 | param.data -= base.state_dict()[name]
29 | else:
30 | assert name in ['model.embed_tokens.weight', 'lm_head.weight'], f'{name} dimension mismatch: {param.data.shape} vs {base.state_dict()[name].shape}'
31 | bparam = base.state_dict()[name]
32 | param.data[:bparam.shape[0], :bparam.shape[1]] -= bparam
33 |
34 | print("Saving delta")
35 | if hub_repo_id:
36 | kwargs = {"push_to_hub": True, "repo_id": hub_repo_id}
37 | else:
38 | kwargs = {}
39 | target.save_pretrained(delta_path, **kwargs)
40 | target_tokenizer = AutoTokenizer.from_pretrained(target_model_path)
41 | target_tokenizer.save_pretrained(delta_path, **kwargs)
42 |
43 |
44 | if __name__ == "__main__":
45 | parser = argparse.ArgumentParser()
46 | parser.add_argument("--base-model-path", type=str, required=True)
47 | parser.add_argument("--target-model-path", type=str, required=True)
48 | parser.add_argument("--delta-path", type=str, required=True)
49 | parser.add_argument("--hub-repo-id", type=str, default=None)
50 | args = parser.parse_args()
51 |
52 | make_delta(args.base_model_path, args.target_model_path, args.delta_path, args.hub_repo_id)
53 |
--------------------------------------------------------------------------------
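
make_delta stores target minus base (with special handling for projector-only weights and resized embeddings), so recovering the target is the element-wise inverse. A minimal sketch of that inverse over plain state dicts; apply_delta.py in this repo presumably implements the full checkpoint-level version, and the function below is only illustrative:

def apply_delta_sketch(base_sd, delta_sd):
    """Illustrative inverse of make_delta: target = base + delta."""
    target_sd = {}
    for name, delta in delta_sd.items():
        if name not in base_sd:                              # e.g. mm_projector weights only exist in the target
            target_sd[name] = delta
            continue
        bparam = base_sd[name]
        if delta.shape == bparam.shape:
            target_sd[name] = delta + bparam
        else:                                                # resized embed_tokens / lm_head rows
            out = delta.clone()
            out[:bparam.shape[0], :bparam.shape[1]] += bparam
            target_sd[name] = out
    return target_sd
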
/src/model/multimodal_encoder/builder.py:
--------------------------------------------------------------------------------
1 | import os
2 | from dataclasses import asdict, is_dataclass
3 | from .clip_encoder import CLIPVisionTower
4 | from ..setok.tokenizer import SetokTokenizer
5 |
6 | def build_vision_tower(vision_tower_cfg, **kwargs):
7 | vision_tower = getattr(vision_tower_cfg, 'vision_tokenizer', getattr(vision_tower_cfg, 'vision_tower', None))
8 | # is_absolute_path_exists = os.path.exists(vision_tower)
9 | # if is_absolute_path_exists and (vision_tower.startswith("openai") or vision_tower.startswith("laion") or "ShareGPT4V" in vision_tower or "vit" in vision_tower):
10 | # return CLIPVisionTower(vision_tower, args=vision_tower_cfg, **kwargs)
11 | if is_dataclass(vision_tower_cfg):
12 | vision_tower_cfg = asdict(vision_tower_cfg)
13 | elif isinstance(vision_tower_cfg, dict):
14 | vision_tower_cfg = vision_tower_cfg
15 | else:
16 | vision_tower_cfg = vars(vision_tower_cfg)
17 |
18 |
19 | if 'siglip' in vision_tower:
20 | return SetokTokenizer(**vision_tower_cfg, **kwargs)
21 |
22 | raise ValueError(f'Unknown vision tower: {vision_tower}')
23 |
--------------------------------------------------------------------------------
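
build_vision_tower accepts a dataclass, a plain dict, or any attribute-style config, and currently only routes SigLIP-style names to SetokTokenizer. A minimal call sketch with an illustrative config object; constructing the tokenizer will download and load the SigLIP weights, and the field values below are assumptions:

from types import SimpleNamespace

cfg = SimpleNamespace(vision_tower='google/siglip-so400m-patch14-384', hidden_dim=1152)
tower = build_vision_tower(cfg)                              # -> SetokTokenizer built from the config fields
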
/src/model/multimodal_encoder/clip_encoder.py:
--------------------------------------------------------------------------------
1 | from typing import Optional
2 | import torch
3 | import torch.nn as nn
4 | from transformers import CLIPVisionModel, CLIPImageProcessor, CLIPVisionConfig
5 |
6 |
7 | class CLIPVisionTower(nn.Module):
8 | def __init__(self, vision_tower: Optional[str],
9 |                 unfreeze_mm_vision_tower: Optional[bool] = False,
10 | mm_vision_select_feature: Optional[str] = 'patch',
11 | mm_vision_select_layer: Optional[int] = -2,
12 | delay_load=False):
13 | super().__init__()
14 |
15 | self.is_loaded = False
16 |
17 | self.vision_tower_name = vision_tower
18 | self.select_layer = mm_vision_select_layer
19 | self.select_feature = mm_vision_select_feature
20 |
21 | if not delay_load:
22 | self.load_model()
23 | elif unfreeze_mm_vision_tower:
24 | self.load_model()
25 | else:
26 | self.cfg_only = CLIPVisionConfig.from_pretrained(self.vision_tower_name)
27 |
28 | def load_model(self, device_map=None):
29 | if self.is_loaded:
30 | print('{} is already loaded, `load_model` called again, skipping.'.format(self.vision_tower_name))
31 | return
32 |
33 | self.image_processor = CLIPImageProcessor.from_pretrained(self.vision_tower_name)
34 | self.vision_tower = CLIPVisionModel.from_pretrained(self.vision_tower_name, device_map=device_map)
35 | self.vision_tower.requires_grad_(False)
36 |
37 | self.is_loaded = True
38 |
39 | def feature_select(self, image_forward_outs):
40 | image_features = image_forward_outs.hidden_states[self.select_layer]
41 | if self.select_feature == 'patch':
42 | image_features = image_features[:, 1:]
43 | elif self.select_feature == 'cls_patch':
44 | image_features = image_features
45 | else:
46 | raise ValueError(f'Unexpected select feature: {self.select_feature}')
47 | return image_features
48 |
49 | @torch.no_grad()
50 | def forward(self, images):
51 | if type(images) is list:
52 | image_features = []
53 | for image in images:
54 | image_forward_out = self.vision_tower(image.to(device=self.device, dtype=self.dtype).unsqueeze(0), output_hidden_states=True)
55 | image_feature = self.feature_select(image_forward_out).to(image.dtype)
56 | image_features.append(image_feature)
57 | else:
58 | image_forward_outs = self.vision_tower(images.to(device=self.device, dtype=self.dtype), output_hidden_states=True)
59 | image_features = self.feature_select(image_forward_outs).to(images.dtype)
60 |
61 | return image_features
62 |
63 | @property
64 | def dummy_feature(self):
65 | return torch.zeros(1, self.hidden_size, device=self.device, dtype=self.dtype)
66 |
67 | @property
68 | def dtype(self):
69 | return self.vision_tower.dtype
70 |
71 | @property
72 | def device(self):
73 | return self.vision_tower.device
74 |
75 | @property
76 | def config(self):
77 | if self.is_loaded:
78 | return self.vision_tower.config
79 | else:
80 | return self.cfg_only
81 |
82 | @property
83 | def hidden_size(self):
84 | return self.config.hidden_size
85 |
86 | @property
87 | def num_patches_per_side(self):
88 | return self.config.image_size // self.config.patch_size
89 |
90 | @property
91 | def num_patches(self):
92 | return (self.config.image_size // self.config.patch_size) ** 2
93 |
--------------------------------------------------------------------------------
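
A minimal forward sketch for CLIPVisionTower, assuming an OpenAI CLIP checkpoint; with the default 'patch' selection the tower returns one feature per image patch. The image path is illustrative:

from PIL import Image

tower = CLIPVisionTower('openai/clip-vit-large-patch14-336')
image = Image.open('example.jpg')                            # illustrative path
pixels = tower.image_processor(images=image, return_tensors='pt')['pixel_values']
feats = tower(pixels)                                        # (1, 576, 1024): 24x24 patches, hidden size 1024
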
/src/model/multimodal_encoder/openclip_encoder.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 | import os
5 | import json
6 | import logging
7 | import deepspeed
8 | from pathlib import Path
9 | from open_clip.factory import load_state_dict, get_model_config
10 | from open_clip.model import CLIPVisionCfg, CLIPTextCfg, _build_vision_tower, convert_to_custom_text_state_dict, resize_pos_embed
11 | from typing import Dict, Optional
12 | from transformers.deepspeed import deepspeed_config, is_deepspeed_zero3_enabled
13 | # from transformers import CLIPImageProcessor
14 | from .openclip_processor import OpenCLIPImageProcessor
15 |
16 |
17 | class OpenCLIPVisionTower(nn.Module):
18 | def __init__(self, vision_tower, args, delay_load=False):
19 | super().__init__()
20 |
21 | self.is_loaded = False
22 | self.vision_tower_name = vision_tower
23 | self.select_stage = args.mm_vision_select_layer if args.mm_vision_select_layer > -5 else -2 # the output stage to select for extracted features
24 | self.vision_config = json.load(open(os.path.join(vision_tower,'open_clip_config.json'), 'r'))
25 | self.is_optimize = getattr(args, 'optimize_vision_tower_aux', False)
26 |
27 | if not delay_load:
28 | self.load_model()
29 |
30 | def load_model(self):
31 | self.image_processor = OpenCLIPImageProcessor.from_pretrained('/public/models/clip/clip-vit-large-patch14')
32 | ckpt_path = os.path.join(self.vision_tower_name, 'open_clip_pytorch_model.bin')
33 | if 'convnext' in self.vision_tower_name:
34 | if 'large' in self.vision_tower_name and 'd_320' in self.vision_tower_name:
35 | self.model_type = 'convnext_large_d_320'
36 | self.model_channel = [192, 384, 768, 1536] # stage 0-3
37 | elif 'base' in self.vision_tower_name and 'w_320' in self.vision_tower_name:
38 | self.model_type = 'convnext_base_w_320'
39 | self.model_channel = [128, 256, 512, 1024]
40 | elif 'xxlarge' in self.vision_tower_name:
41 | self.model_type = 'convnext_xxlarge'
42 | self.model_channel = [384, 768, 1536, 3072]
43 |
44 | clip_model = CLIP(**get_model_config(self.model_type))
45 | clip_model.visual.trunk.norm_pre = None
46 | clip_model.visual.trunk.head = None
47 | clip_model.visual.head = None
48 | print(f'Loading pretrained weights ({self.model_type}).')
49 | load_checkpoint(clip_model, ckpt_path, strict=False)
50 |
51 | self.is_loaded = True
52 | # decompose stem and stages blocks in vision tower
53 | self.vision_stem = clip_model.visual.trunk.stem
54 | self.vision_stages = clip_model.visual.trunk.stages
55 | self.vision_stem.requires_grad_(False)
56 | self.vision_stages.requires_grad_(False)
57 |
58 | def forward(self, images):
59 | if type(images) is list:
60 | image_features = []
61 | for image in images:
62 | image_feature = self.backbone(image.to(device=self.device, dtype=self.dtype).unsqueeze(0))
63 | image_features.append(image_feature)
64 | else:
65 | image_features = self.backbone(images.to(device=self.device, dtype=self.dtype))
66 |
67 | return image_features
68 |
69 | def backbone(self, images):
70 | # print('images: ', images.shape)
71 | if not self.is_optimize:
72 | with torch.no_grad():
73 | results = self.basic_forward(images)
74 | else:
75 | results = self.basic_forward(images)
76 | # 448- torch.Size([1, 384, 56, 56]), torch.Size([1, 768, 28, 28]), torch.Size([1, 1536, 14, 14]), torch.Size([1, 3072, 7, 7])]
77 | # 672- torch.Size([16, 384, 168, 168]), torch.Size([16, 768, 84, 84]), torch.Size([16, 1536, 42, 42]), torch.Size([16, 3072, 21, 21])
78 | # where hidden_size = sum(model_channel)
79 | # print('results: ', [results[_stage].shape for _stage in results])
80 |
81 | # target_size = (results['stage_0'].shape[-2], results['stage_0'].shape[-1])
82 | # result_cat = []
83 | # for _stage in results:
84 | # if _stage == 'stage_0':
85 | # result_cat.append(results[_stage].contiguous())
86 | # else:
87 | # result_cat.append(F.interpolate(results[_stage].float().contiguous() ,
88 | # size=target_size,
89 | # mode='bilinear',
90 | # align_corners=False).to(dtype=results[_stage].dtype))
91 | # result_cat = torch.cat(result_cat, dim=1)
92 | select_stage = f'stage_{4+self.select_stage}'
93 | result_cat = results[select_stage]
94 | # print("result_cat: ", result_cat.shape)
95 |
96 | return result_cat.contiguous()
97 |
98 | def basic_forward(self, images):
99 | results = {}
100 | x = self.vision_stem(images)
101 | for _idx in range(len(self.vision_stages)):
102 | x = self.vision_stages[_idx](x)
103 | results[f'stage_{_idx}'] = x
104 | return results
105 |
106 | @property
107 | def dummy_feature(self):
108 | return torch.zeros(1, self.hidden_size, device=self.device, dtype=self.dtype)
109 |
110 | @property
111 | def dtype(self):
112 | return self.vision_stem[0].weight.dtype
113 |
114 | @property
115 | def device(self):
116 | return self.vision_stem[0].weight.device
117 |
118 | @property
119 | def config(self):
120 | return self.vision_config
121 |
122 | @property
123 | def hidden_size(self):
124 | return self.model_channel[self.select_stage]
125 | # return sum(self.model_channel)
126 |
127 | # modified from open_clip's load_checkpoint to support loading under DeepSpeed ZeRO stage 3
128 | def load_checkpoint(model, checkpoint_path, strict=True):
129 | if Path(checkpoint_path).suffix in ('.npz', '.npy'):
130 | from open_clip.big_vision import load_big_vision_weights
131 | load_big_vision_weights(model, checkpoint_path)
132 | return {}
133 |
134 | state_dict = load_state_dict(checkpoint_path)
135 | # detect old format and make compatible with new format
136 | if 'positional_embedding' in state_dict and not hasattr(model, 'positional_embedding'):
137 | state_dict = convert_to_custom_text_state_dict(state_dict)
138 | # If loading a non-SigLIP model for SigLIP training. See https://github.com/mlfoundations/open_clip/issues/712
139 | # if 'logit_bias' not in state_dict and model.logit_bias is not None:
140 | # state_dict["logit_bias"] = torch.zeros_like(state_dict["logit_scale"])
141 | # Certain text transformers no longer expect position_ids after transformers==4.31
142 | position_id_key = 'text.transformer.embeddings.position_ids'
143 | if position_id_key in state_dict and not hasattr(model, position_id_key):
144 | del state_dict[position_id_key]
145 | resize_pos_embed(state_dict, model)
146 | # resize_text_pos_embed(state_dict, model)
147 | #incompatible_keys = model.load_state_dict(state_dict, strict=strict)
148 | if is_deepspeed_zero3_enabled():
149 |
150 | error_msgs = []
151 |
152 | def load(module: nn.Module, state_dict, prefix=""):
153 | metadata = None
154 |
155 | local_metadata = {} if metadata is None else metadata.get(prefix[:-1], {})
156 | args = (state_dict, prefix, local_metadata, True, [], [], error_msgs)
157 | # Parameters of module and children will start with prefix. We can exit early if there are none in this
158 | # state_dict
159 | if len([key for key in state_dict if key.startswith(prefix)]) > 0:
160 | if is_deepspeed_zero3_enabled():
161 | # In sharded models, each shard has only part of the full state_dict, so only gather
162 | # parameters that are in the current state_dict.
163 | named_parameters = dict(module.named_parameters(prefix=prefix[:-1], recurse=False))
164 | params_to_gather = [named_parameters[k] for k in state_dict.keys() if k in named_parameters]
165 | if len(params_to_gather) > 0:
166 | # because zero3 puts placeholders in model params, this context
167 | # manager gathers (unpartitions) the params of the current layer, then loads from
168 | # the state dict and then re-partitions them again
169 | with deepspeed.zero.GatheredParameters(params_to_gather, modifier_rank=0):
170 | if torch.distributed.get_rank() == 0:
171 | module._load_from_state_dict(*args)
172 | else:
173 | module._load_from_state_dict(*args)
174 |
175 | for name, child in module._modules.items():
176 | if child is not None:
177 | load(child, state_dict, prefix + name + ".")
178 |
179 | load(model, state_dict)
180 | incompatible_keys = []
181 | else:
182 | incompatible_keys = model.load_state_dict(state_dict, strict=strict)
183 | logging.info(f"incompatible_keys.missing_keys: {incompatible_keys.missing_keys}")
184 | return incompatible_keys
185 |
186 | class CLIP(nn.Module):
187 | output_dict: torch.jit.Final[bool]
188 |
189 | def __init__(
190 | self,
191 | embed_dim: int,
192 | vision_cfg: CLIPVisionCfg,
193 | text_cfg: CLIPTextCfg,
194 | quick_gelu: bool = False,
195 | cast_dtype: Optional[torch.dtype] = None,
196 | output_dict: bool = False,
197 | ):
198 | super().__init__()
199 | self.output_dict = output_dict
200 |
201 | self.visual = _build_vision_tower(embed_dim, vision_cfg, quick_gelu, cast_dtype)
--------------------------------------------------------------------------------
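
The backbone returns a single ConvNeXt stage rather than the concatenation of all stages (the concatenation path is commented out above); the stage is picked from mm_vision_select_layer. A small sketch of the mapping, assuming the convnext_xxlarge channel widths listed in load_model:

select_stage = -2                                            # default mm_vision_select_layer
stage_key = f'stage_{4 + select_stage}'                      # -> 'stage_2'
model_channel = [384, 768, 1536, 3072]                       # convnext_xxlarge stages 0-3
hidden_size = model_channel[select_stage]                    # -> 1536 channels per spatial location
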
/src/model/multimodal_encoder/openclip_processor.py:
--------------------------------------------------------------------------------
1 | from transformers import CLIPImageProcessor
2 | from transformers.image_processing_utils import BatchFeature, get_size_dict
3 | from transformers.image_transforms import get_resize_output_image_size
4 |
5 | import torch
6 | import torch.nn.functional as F
7 |
8 | import numpy as np
9 |
10 |
11 | class OpenCLIPImageProcessor(CLIPImageProcessor):
12 |
13 | def __init__(self, **kwargs):
14 | super().__init__(**kwargs)
15 |
16 | def preprocess(self, images, **kwargs):
17 | if not isinstance(images, np.ndarray):
18 | return super().preprocess(images=images, **kwargs)
19 |
20 | do_resize = kwargs.get('do_resize', self.do_resize)
21 | size = kwargs.get('size', self.size)
22 | size = get_size_dict(size, param_name="size", default_to_square=False)
23 | do_center_crop = kwargs.get('do_center_crop', self.do_center_crop)
24 | crop_size = kwargs.get('crop_size', self.crop_size)
25 | crop_size = get_size_dict(crop_size, param_name="crop_size", default_to_square=True)
26 | do_rescale = kwargs.get('do_rescale', self.do_rescale)
27 | rescale_factor = kwargs.get('rescale_factor', self.rescale_factor)
28 | do_normalize = kwargs.get('do_normalize', self.do_normalize)
29 | image_mean = kwargs.get('image_mean', self.image_mean)
30 | image_std = kwargs.get('image_std', self.image_std)
31 | return_tensors = kwargs.get('return_tensors', None)
32 |
33 | def resize(images, output_size):
34 | images = images.permute((0, 3, 1, 2))
35 | images = F.interpolate(images, size=output_size, mode='bicubic')
36 | images = images.permute((0, 2, 3, 1))
37 | return images
38 |
39 | def center_crop(images, crop_size):
40 | crop_width, crop_height = crop_size["width"], crop_size["height"]
41 |             img_height, img_width = images.shape[1:3]  # images are laid out as (N, H, W, C)
42 |             top = (img_height - crop_height) // 2
43 |             left = (img_width - crop_width) // 2
44 |             images = images[:, top:top+crop_height, left:left+crop_width]
45 | return images
46 |
47 | def rescale(images, rescale_factor):
48 | images = images * rescale_factor
49 | return images
50 |
51 | def normalize(images, mean, std):
52 | mean = torch.tensor(mean)
53 | std = torch.tensor(std)
54 | images = (images - mean) / std
55 | return images
56 |
57 | images = torch.from_numpy(images).float()
58 |
59 | if do_resize:
60 | output_size = get_resize_output_image_size(images[0], size=size["shortest_edge"], default_to_square=False)
61 | images = resize(images, output_size)
62 |
63 | if do_center_crop:
64 | images = center_crop(images, crop_size)
65 |
66 | if do_rescale:
67 | images = rescale(images, rescale_factor)
68 |
69 | if do_normalize:
70 | images = normalize(images, image_mean, image_std)
71 |
72 | images = images.permute((0, 3, 1, 2))
73 | data = {"pixel_values": images}
74 | return BatchFeature(data=data, tensor_type=return_tensors)
--------------------------------------------------------------------------------
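
OpenCLIPImageProcessor defers to the stock CLIPImageProcessor for PIL inputs and only takes the tensorized fast path for numpy batches shaped (N, H, W, C). A minimal sketch with illustrative sizes and the usual OpenAI CLIP normalization statistics:

import numpy as np

processor = OpenCLIPImageProcessor(
    do_resize=True, size={"shortest_edge": 224},
    do_center_crop=True, crop_size={"height": 224, "width": 224},
    do_rescale=True, rescale_factor=1 / 255,
    do_normalize=True,
    image_mean=[0.48145466, 0.4578275, 0.40821073],
    image_std=[0.26862954, 0.26130258, 0.27577711],
)
batch = (np.random.rand(2, 320, 480, 3) * 255).astype(np.float32)
out = processor.preprocess(batch, return_tensors='pt')
print(out['pixel_values'].shape)                             # expected: torch.Size([2, 3, 224, 224])
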
/src/model/multimodal_generator/builder.py:
--------------------------------------------------------------------------------
1 | from dataclasses import asdict, is_dataclass
2 | from ..setok import SetokDeTokenizer
3 |
4 | def build_vision_generator(image_generator_cfg, **kwargs):
5 | if is_dataclass(image_generator_cfg):
6 | image_generator_cfg = asdict(image_generator_cfg)
7 | elif isinstance(image_generator_cfg, dict):
8 | image_generator_cfg = image_generator_cfg
9 | else:
10 | image_generator_cfg = vars(image_generator_cfg)
11 |
12 | return SetokDeTokenizer(**image_generator_cfg, **kwargs)
--------------------------------------------------------------------------------
/src/model/multimodal_projector/builder.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import re
4 |
5 |
6 | class IdentityMap(nn.Module):
7 | def __init__(self):
8 | super().__init__()
9 |
10 | def forward(self, x, *args, **kwargs):
11 | return x
12 |
13 | @property
14 | def config(self):
15 | return {"mm_projector_type": 'identity'}
16 |
17 |
18 | class SimpleResBlock(nn.Module):
19 | def __init__(self, channels):
20 | super().__init__()
21 | self.pre_norm = nn.LayerNorm(channels)
22 |
23 | self.proj = nn.Sequential(
24 | nn.Linear(channels, channels),
25 | nn.GELU(),
26 | nn.Linear(channels, channels)
27 | )
28 | def forward(self, x):
29 | x = self.pre_norm(x)
30 | return x + self.proj(x)
31 |
32 |
33 | def build_vision_projector(projector_type='linear', mm_hidden_size=4096, hidden_size=3078, delay_load=False, **kwargs):
34 | # projector_type = getattr(config, 'mm_projector_type', 'linear')
35 |
36 | if projector_type == 'linear':
37 | return nn.Linear(mm_hidden_size, hidden_size)
38 |
39 | use_norm = False
40 | if "_Norm" in projector_type:
41 | use_norm = True
42 | projector_type = projector_type.replace("_Norm", "")
43 |
44 |
45 | mlp_gelu_match = re.match(r'^mlp(\d+)x_gelu$', projector_type)
46 | if mlp_gelu_match:
47 | mlp_depth = int(mlp_gelu_match.group(1))
48 | if use_norm:
49 | modules = [
50 | nn.Linear(mm_hidden_size, hidden_size),
51 | nn.LayerNorm(hidden_size),
52 | ]
53 | else:
54 | modules = [nn.Linear(mm_hidden_size, hidden_size)]
55 | # modules = [nn.Linear(config.mm_hidden_size, config.hidden_size)]
56 | for _ in range(1, mlp_depth):
57 | modules.append(nn.GELU())
58 | modules.append(nn.Linear(hidden_size, hidden_size))
59 | return nn.Sequential(*modules)
60 |
61 | if projector_type == 'identity':
62 | return IdentityMap()
63 |
64 | raise ValueError(f'Unknown projector type: {projector_type}')
65 |
--------------------------------------------------------------------------------
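
build_vision_projector understands 'linear', 'identity', and 'mlpNx_gelu', where a '_Norm' suffix inserts a LayerNorm after the first linear layer. A minimal sketch with illustrative sizes:

import torch

proj = build_vision_projector('mlp2x_gelu_Norm', mm_hidden_size=1152, hidden_size=4096)
# -> Linear(1152, 4096) -> LayerNorm(4096) -> GELU -> Linear(4096, 4096)
tokens = torch.randn(2, 64, 1152)                            # 64 visual tokens per image
print(proj(tokens).shape)                                    # torch.Size([2, 64, 4096])
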
/src/model/setok/__init__.py:
--------------------------------------------------------------------------------
1 | from .tokenizer import SetokTokenizer
2 | from .detokenizer import SetokDeTokenizer
3 | from .model import SeTok
--------------------------------------------------------------------------------
/src/model/setok/clip_encoder.py:
--------------------------------------------------------------------------------
1 | from typing import Optional
2 | import torch
3 | import torch.nn as nn
4 |
5 | from transformers import AutoModel, AutoProcessor, AutoConfig
6 |
7 |
8 | class CLIPVisionTower(nn.Module):
9 | def __init__(self, vision_tower: Optional[str],
10 |                 unfreeze_mm_vision_tower: Optional[bool] = False,
11 | mm_vision_select_feature: Optional[str] = 'patch',
12 | mm_vision_select_layer: Optional[int] = -2,
13 | delay_load=False):
14 | super().__init__()
15 |
16 | self.is_loaded = False
17 |
18 | self.vision_tower_name = vision_tower
19 | self.select_layer = mm_vision_select_layer
20 | self.select_feature = mm_vision_select_feature
21 |
22 | if not delay_load:
23 | self.load_model()
24 | elif unfreeze_mm_vision_tower:
25 | self.load_model()
26 | else:
27 | self.cfg_only = AutoConfig.from_pretrained(self.vision_tower_name)
28 |
29 | def load_model(self, device_map=None):
30 | if self.is_loaded:
31 | print('{} is already loaded, `load_model` called again, skipping.'.format(self.vision_tower_name))
32 | return
33 |
34 | self.image_processor = AutoProcessor.from_pretrained(self.vision_tower_name)
35 | self.vision_tower = AutoModel.from_pretrained(self.vision_tower_name, device_map=device_map)
36 | self.vision_tower.requires_grad_(False)
37 |
38 | self.is_loaded = True
39 |
40 | def feature_select(self, image_forward_outs):
41 | image_features = image_forward_outs.hidden_states[self.select_layer]
42 | if self.select_feature == 'patch':
43 | image_features = image_features[:, 1:]
44 | elif self.select_feature == 'cls_patch':
45 | image_features = image_features
46 | else:
47 | raise ValueError(f'Unexpected select feature: {self.select_feature}')
48 | return image_features
49 |
50 | @torch.no_grad()
51 | def forward(self, images):
52 | if type(images) is list:
53 | image_features = []
54 | for image in images:
55 | image_forward_out = self.vision_tower(image.to(device=self.device, dtype=self.dtype).unsqueeze(0), output_hidden_states=True)
56 | image_feature = self.feature_select(image_forward_out).to(image.dtype)
57 | image_features.append(image_feature)
58 | else:
59 | image_forward_outs = self.vision_tower(images.to(device=self.device, dtype=self.dtype), output_hidden_states=True)
60 | image_features = self.feature_select(image_forward_outs).to(images.dtype)
61 |
62 | return image_features
63 |
64 | @property
65 | def dummy_feature(self):
66 | return torch.zeros(1, self.hidden_size, device=self.device, dtype=self.dtype)
67 |
68 | @property
69 | def dtype(self):
70 | return self.vision_tower.dtype
71 |
72 | @property
73 | def device(self):
74 | return self.vision_tower.device
75 |
76 | @property
77 | def config(self):
78 | if self.is_loaded:
79 | return self.vision_tower.config
80 | else:
81 | return self.cfg_only
82 |
83 | @property
84 | def hidden_size(self):
85 | return self.config.hidden_size
86 |
87 | @property
88 | def num_patches_per_side(self):
89 | return self.config.image_size // self.config.patch_size
90 |
91 | @property
92 | def num_patches(self):
93 | return (self.config.image_size // self.config.patch_size) ** 2
94 |
--------------------------------------------------------------------------------
/src/model/setok/detokenizer.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 | from typing import Union, Dict, List, Tuple, Optional
5 | from einops import rearrange, repeat
6 | from timm.models.vision_transformer import Block
7 | from torch.utils.checkpoint import checkpoint
8 | from transformers.models.bert import BertConfig
9 | from .module import BertModel, PositionalEncoding2D
10 | from diffusers.models.autoencoders.vae import Decoder
11 |
12 |
13 |
14 | class SetokDeTokenizer(nn.Module):
15 | def __init__(self,
16 | token_feat_dim: Optional[int] = 4096,
17 | hidden_dim: Optional[int] = 4096,
18 | patch_size: Optional[int]=14,
19 | image_size: Optional[int]=256,
20 | decoder_embed_dim: Optional[int]=4096,
21 | decoder_nheads: Optional[int]=16,
22 | proj_drop: Optional[float]=0.2,
23 | attn_drop: Optional[float]=0.2,
24 | decoder_depth: Optional[int]=16,
25 | norm_layer: nn.Module = nn.LayerNorm,
26 | mlp_ratio: Optional[float]=4.0,
27 | feature_mapper_path_or_name: Optional[str]="bert-base-uncased",
28 | num_hidden_layers: Optional[int]=6,
29 | cross_attention_freq: Optional[int]=2,
30 | initializer_range: Optional[float]=0.02,
31 | **kwargs) -> None:
32 | super().__init__()
33 | self.token_feat_dim = token_feat_dim
34 |
35 | self.patch_size = patch_size
36 |         self.height = self.width = image_size // patch_size
37 |         self.num_mask_token = self.height * self.width
38 | self.hidden_dim = hidden_dim
39 |
40 | query_tokens = nn.Parameter(torch.zeros(1, self.num_mask_token, self.hidden_dim))
41 | query_tokens.data.normal_(mean=0.0, std=initializer_range)
42 | self.mask_tokens = query_tokens
43 |
44 | self.decoder_embed_dim = decoder_embed_dim
45 | self.mapper_fc_in = nn.Linear(self.token_feat_dim, self.hidden_dim)
46 | self.decoder_fc_in = nn.Linear(self.hidden_dim, self.decoder_embed_dim)
47 |
48 | self.decoder_norm = norm_layer(self.decoder_embed_dim)
49 | self.pixel_decoder = nn.ModuleList([
50 | Block(self.decoder_embed_dim, decoder_nheads, mlp_ratio, qkv_bias=True, norm_layer=norm_layer, proj_drop=proj_drop, attn_drop=attn_drop) for _ in range(decoder_depth)
51 | ])
52 | self.position_embedding = PositionalEncoding2D(self.hidden_dim)
53 | self.initialize_weights()
54 | self.init_feature_mapper(feature_mapper_path_or_name, self.hidden_dim, self.num_mask_token, num_hidden_layers, cross_attention_freq)
55 |
56 | def initialize_weights(self):
57 | self.apply(self._init_weights)
58 |
59 | def _init_weights(self, m):
60 | if isinstance(m, nn.Linear):
61 | # we use xavier_uniform following official JAX ViT:
62 | torch.nn.init.xavier_uniform_(m.weight)
63 | if isinstance(m, nn.Linear) and m.bias is not None:
64 | nn.init.constant_(m.bias, 0)
65 | elif isinstance(m, nn.LayerNorm):
66 | if m.bias is not None:
67 | nn.init.constant_(m.bias, 0)
68 | if m.weight is not None:
69 | nn.init.constant_(m.weight, 1.0)
70 |
71 | def init_feature_mapper(
72 | self,
73 | feature_mapper_path_or_name: str,
74 | vision_width: int,
75 | num_mask_token: int,
76 | num_hidden_layers: int,
77 | cross_attention_freq: int
78 | ):
79 | print("feature_mapper_path_or_name: ", feature_mapper_path_or_name)
80 | mapper_config = BertConfig.from_pretrained(feature_mapper_path_or_name)
81 |
82 | mapper_config.encoder_width = vision_width
83 | # insert cross-attention layer every other block
84 | mapper_config.add_cross_attention = True
85 |
86 | mapper_config.cross_attention_freq = cross_attention_freq
87 | mapper_config.query_length = num_mask_token
88 | mapper_config.num_hidden_layers = num_hidden_layers
89 |
90 | self.mapper = BertModel.from_pretrained(feature_mapper_path_or_name, config=mapper_config)
91 | self.mapper.cls = None
92 | self.mapper.embeddings.word_embeddings = None
93 | self.mapper.embeddings.position_embeddings = None
94 | for layer in self.mapper.encoder.layer:
95 | layer.output = None
96 | layer.intermediate = None
97 |
98 | def load_model(self):
99 | pass
100 |
101 | def forward(self, x, attention_masks):
102 |
103 | mask_tokens = self.mask_tokens.expand(x.shape[0], -1, -1)
104 | x = self.mapper_fc_in(x)
105 | x = self.mapper(
106 | query_embeds=mask_tokens,
107 | encoder_hidden_states=x,
108 | encoder_attention_mask=attention_masks,
109 | return_dict=True).last_hidden_state
110 |
111 | x = self.decoder_fc_in(x) # b, h*w, c
112 |         _x = rearrange(x, 'B (h w) C -> B h w C', h=self.height, w=self.width)
113 | pos_emb = self.position_embedding(_x)
114 | pos_emb = rearrange(pos_emb, 'B h w C -> B (h w) C')
115 | x = x + pos_emb
116 |
117 | for block in self.pixel_decoder:
118 | x = block(x)
119 |
120 | x = self.decoder_norm(x)
121 | 
122 |         # return the decoded patch-grid features
123 |         return x
124 | 
--------------------------------------------------------------------------------
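
SetokDeTokenizer expands a variable number of concept tokens back onto a fixed patch grid: the learned mask tokens attend to the projected input tokens through the BERT cross-attention mapper, 2D positional encodings are added, and a stack of ViT blocks decodes the grid into patch features. A shape-level sketch with sizes chosen so the mapper width matches bert-base-uncased (768); all values are illustrative:

import torch

detok = SetokDeTokenizer(token_feat_dim=4096, hidden_dim=768, image_size=224, patch_size=14,
                         decoder_embed_dim=768, decoder_nheads=12, decoder_depth=4)
tokens = torch.randn(2, 37, 4096)                            # 37 concept tokens per image (illustrative)
mask = torch.ones(2, 37, dtype=torch.long)
out = detok(tokens, mask)                                    # expected (2, 256, 768): one feature per cell of the 16x16 grid
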
/src/model/setok/model.py:
--------------------------------------------------------------------------------
1 |
2 | import torch
3 | import torch.nn as nn
4 | import torch.nn.functional as f
5 | from typing import Union, Dict, List,Tuple, Optional
6 | import numpy as np
7 | from dataclasses import dataclass, asdict, is_dataclass
8 | from transformers.utils import ModelOutput
9 | from .utils import *
10 | from .tokenizer import SetokTokenizer
11 | from .detokenizer import SetokDeTokenizer
12 | from ..loss import GANLoss, MultilabelContrastiveLoss
13 | from ..utils import instantiate_from_config
14 |
15 |
16 | @dataclass
17 | class SetokOutput(ModelOutput):
18 | token_emb: torch.FloatTensor = None
19 | predict_emb: torch.FloatTensor = None
20 |     loss: torch.FloatTensor = None
21 |     loss_log: Dict = None
22 |
23 |
24 |
25 | class SeTok(nn.Module):
26 | def __init__(self,
27 | tokenizer_config,
28 | detokenizer_config,
29 | rec_loss_config=None,
30 | contrastive_loss_config=None,
31 | is_training=False,
32 | **kwargs) -> None:
33 |         super().__init__()
34 |
35 | self.tokenizer_config = tokenizer_config
36 | self.detokenizer_config = detokenizer_config
37 |
38 | self.tokenizer = SetokTokenizer(**asdict(tokenizer_config))
39 |
40 | self.detokenizer = SetokDeTokenizer(**asdict(detokenizer_config))
41 |
42 | self.is_training = is_training
43 | if is_training:
44 | self.rec_loss = GANLoss(**asdict(rec_loss_config))
45 | self.contrastive_loss = MultilabelContrastiveLoss(**asdict(contrastive_loss_config))
46 |
47 | def get_tokenizer_config(self):
48 | return self.tokenizer_config
49 |
50 | def get_detokenizer_config(self):
51 | return self.detokenizer_config
52 |
53 | def get_tokenizer(self):
54 | return self.tokenizer
55 |
56 | def get_detokenizer(self):
57 | return self.detokenizer
58 |
59 | def tokenize(self, x):
60 | return self.tokenizer(x)
61 |
62 | def detokenize(self, x):
63 | return self.detokenizer(x)
64 |
65 | def sample_orders(self, bsz):
66 | # generate a batch of random generation orders
67 | orders = []
68 | for _ in range(bsz):
69 | order = np.array(list(range(self.seq_len)))
70 | np.random.shuffle(order)
71 | orders.append(order)
72 | orders = torch.Tensor(np.array(orders)).cuda().long()
73 | return orders
74 |
75 |
76 |
77 | def compute_rec_loss(self, prediction, target, current_step):
78 | loss, log_dict = self.rec_loss(target, prediction, current_step)
79 | return loss, log_dict
80 |
81 | def compute_contrastive_loss(self, prediction, text):
82 | loss, log_dict = self.contrastive_loss(prediction, text)
83 | return loss, log_dict
84 |
85 |
86 | def forward(self, x, gold_image=None, text=None, return_dict=True, current_step=0):
87 | e_tokens, _, _ = self.tokenize(x)
88 | prediction = self.detokenize(e_tokens)
89 | loss = None
90 | loss_log = dict()
91 | if gold_image is not None:
92 | rec_loss, rec_loss_log = self.compute_rec_loss(prediction, gold_image, current_step)
93 | loss = rec_loss
94 | loss_log.update(**rec_loss_log)
95 |
96 | if text is not None:
97 | txt_contrastive_loss, txt_contrastive_log = self.compute_contrastive_loss(e_tokens, text)
98 |             loss = txt_contrastive_loss if loss is None else loss + txt_contrastive_loss
99 | loss_log.update(**txt_contrastive_log)
100 | if return_dict:
101 |             return SetokOutput(token_emb=e_tokens, predict_emb=prediction, loss=loss, loss_log=loss_log)
102 | else:
103 |             return loss, (e_tokens, prediction, loss_log)
104 |
105 |
106 |
107 |
--------------------------------------------------------------------------------
/src/model/setok/tokenizer.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 | from typing import Union, Dict, List, Optional
5 | import numpy as np
6 | from einops import rearrange, repeat
7 | from timm.models.layers import DropPath
8 | import math
9 | from .clip_encoder import CLIPVisionTower
10 | from .module import Block, PositionalEncoding2D
11 |
12 |
13 | class SetokTokenizer(nn.Module):
14 | def __init__(self,
15 | vision_tower: str = 'google/siglip-so400m-patch14-384',
16 | unfreeze_mm_vision_tower: Optional[bool] = False,
17 | mm_vision_select_feature: Optional[str] = 'patch',
18 | mm_vision_select_layer: Optional[int] = -2,
19 | delay_load: Optional[bool]= False,
20 | hidden_dim: Optional[int] = 4096,
21 | token_feat_dim: Optional[int] = 4096,
22 | min_cluster_num: Optional[int] = 64,
23 | threshold: Optional[float] = 0.5,
24 | nheads: Optional[int] = 2,
25 | dim_feedforward: Optional[int] = 4096,
26 | proj_drop: Optional[float] = 0.2,
27 | drop_path: Optional[float] = 0.0,
28 | inner_cluster_layers: Optional[int] = 2,
29 | intra_cluster_layers: Optional[int] = 2,
30 | attn_drop: Optional[float] = 0.0,
31 | act_layer: nn.Module = nn.GELU,
32 | norm_layer: nn.Module = nn.LayerNorm,
33 | **kwargs
34 | ) -> None:
35 | super().__init__()
36 |
37 | self.hidden_dim = hidden_dim
38 | self.token_feat_dim = token_feat_dim
39 |
40 | self.inner_encoder = Block(self.hidden_dim, nheads, dim_feedforward, proj_drop=proj_drop, attn_drop=attn_drop, drop_path=drop_path, act_layer=act_layer, norm_layer=norm_layer, depth=inner_cluster_layers)
41 | self.inter_encoder = Block(self.hidden_dim, nheads, dim_feedforward, proj_drop=proj_drop, attn_drop=attn_drop, drop_path=drop_path, act_layer=act_layer, norm_layer=norm_layer, depth=intra_cluster_layers)
42 | self.position_embedding = PositionalEncoding2D(self.hidden_dim)
43 | self.out = nn.Linear(self.hidden_dim, self.token_feat_dim)
44 |
45 | self.min_cluster_num = min_cluster_num
46 | self.threshold = threshold
47 |
48 | self.initialize_weights()
49 |
50 | self.image_feature_encoder = CLIPVisionTower(vision_tower,
51 | unfreeze_mm_vision_tower=unfreeze_mm_vision_tower,
52 | mm_vision_select_feature=mm_vision_select_feature,
53 | mm_vision_select_layer=mm_vision_select_layer,
54 | delay_load=delay_load
55 | )
56 | self.image_processor = self.image_feature_encoder.image_processor
57 |
58 |
59 | def initialize_weights(self):
60 | self.apply(self._init_weights)
61 |
62 | def _init_weights(self, m):
63 | if isinstance(m, nn.Linear):
64 | # we use xavier_uniform following official JAX ViT:
65 | torch.nn.init.xavier_uniform_(m.weight)
66 | if isinstance(m, nn.Linear) and m.bias is not None:
67 | nn.init.constant_(m.bias, 0)
68 | elif isinstance(m, nn.LayerNorm):
69 | if m.bias is not None:
70 | nn.init.constant_(m.bias, 0)
71 | if m.weight is not None:
72 | nn.init.constant_(m.weight, 1.0)
73 |
74 | @property
75 | def dtype(self):
76 |         return self.out.weight.dtype
77 |
78 | def cluster_dpc_knn(self, x, k, token_mask=None, threshold=0.53):
79 | with torch.no_grad():
80 | N, C = x.shape
81 |
82 |             dist_matrix = torch.cdist(x, x) / (C ** 0.5)  # N * N
83 |
84 | if token_mask is not None:
85 | token_mask = token_mask > 0
86 | dist_matrix = dist_matrix * token_mask[None, :] + (dist_matrix.max() + 1) * (~token_mask[None, :])
87 |
88 |             dist_nearest, index_nearest = torch.topk(dist_matrix, k=k, dim=-1, largest=False)  # N * k
89 |
90 |             density = (-(dist_nearest ** 2).mean(dim=-1)).exp()  # N
91 |             density = density + torch.rand(density.shape, device=density.device, dtype=density.dtype) * 1e-6  # N
92 |
93 | if token_mask is not None:
94 | density = density * token_mask
95 |
96 |             mask = density[None, :] > density[:, None]  # N * N
97 | mask = mask.type(x.dtype)
98 | dist_max = dist_matrix.flatten(1).max(dim=-1)[0][None, None] # C * C
99 | dist, index_parent = (dist_matrix * mask + dist_max * (1 - mask)).min(dim=-1) # 1 * C, 1 * C
100 |
101 | score = dist * density
102 |
103 | index_down = torch.nonzero(score.reshape(-1)>threshold).reshape(-1) # obtain the index of the center
104 | if index_down.numel() == 0:
105 | _, index_down = torch.topk(score, k=self.min_cluster_num, dim=-1)
106 | index_down = torch.sort(index_down).values
107 | index_down = index_down.reshape(-1)
108 |
109 | # obtain the index of the cluster that each token belongs to
110 | # dist_matrix = index_points(dist_matrix, index_down.squeeze()) # the cluster_num * C
111 | dist_matrix = dist_matrix[index_down, :] # the cluster_num * C
112 |
113 | idx_cluster = dist_matrix.argmin(dim=0) # the cluster_num
114 |
115 | # B = 1
116 | # idx_batch = torch.arange(B, device=x.device)[:, None].expand(cluster_num)
117 | cluster_num = index_down.size(0)
118 | idx_tmp = torch.arange(cluster_num, device=x.device)[None, :]
119 | idx_cluster[index_down] = idx_tmp.reshape(-1)
120 |
121 | return index_down, idx_cluster, score
122 |
123 | def group_encoding(self, x, centers, labels):
124 | """
125 |         We apply a transformer within each group to model the intra-group features.
126 |         The tokens assigned to a cluster are encoded together, and their mean-pooled
127 |         representation serves as the final representation of the group, i.e., the concept-level visual token.
128 |         Args:
129 |             x: patch features of size (W, C)
130 |             centers: cluster-center features of size (L, C)
131 |             labels: cluster assignment of size (W,)
132 | 
133 |         Return:
134 |             group features of size (L, C)
135 | """
136 |
137 | W, C = x.size()
138 | L, _ = centers.size()
139 |
140 | # Compute masks for each unique label
141 | unique_labels, label_counts = labels.unique(return_counts=True)
142 | # print('unique_labels: ', unique_labels)
143 | masks = [labels == cur_label for cur_label in unique_labels] # L, W
144 | # centers = centers.index_select(1, unique_labels)
145 |
146 | group_features = []
147 | for i, m in enumerate(masks):
148 | # _m = m.unsqueeze(1).expand(W, C)
149 | # cur_length = torch.sum(m).item()
150 | _cur_cluster_feat = self.inner_encoder(x[m].unsqueeze(0))
151 | _cur_cluster_feat = _cur_cluster_feat.squeeze(0).mean(dim=0)
152 | group_features.append(_cur_cluster_feat)
153 | group_features = torch.stack(group_features, dim=0)
154 |
155 | return group_features
156 |
157 | def forward(self, x, k=None, threshold=None, token_mask=None):
158 | """
159 | Expected Input: x of size (B, h, w, C)
160 | """
161 | x = self.image_feature_encoder(x)
162 | x = x.unsqueeze(0)
163 | B, hw, C = x.shape
164 | h = w = int(math.sqrt(x.shape[1]))
165 | _x = rearrange(x, 'B (h w) C -> B h w C', h=h, w=w)
166 | pos_emb = self.position_embedding(_x)
167 | pos_emb = rearrange(pos_emb, 'B h w C -> B (h w) C')
168 | x = x + pos_emb
169 | x = x.squeeze(0)
170 |
171 | _threshold = threshold if threshold else self.threshold
172 | _k = k if k else self.min_cluster_num
173 |
174 | index_down, idx_cluster, score = self.cluster_dpc_knn(x, _k, token_mask, _threshold)
175 | # index_down: the center index w.r.t the input x for each cluster
176 | # idx_cluster: the cluster index for each token in x, which is the index of the center in list [index_down]
177 | centers = x[index_down, :]
178 | group_features = self.group_encoding(x, centers, idx_cluster)
179 | group_features = self.inter_encoder(group_features)
180 | group_features = self.out(group_features)
181 |
182 | return group_features, idx_cluster, score
--------------------------------------------------------------------------------
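
cluster_dpc_knn follows the DPC-KNN recipe: each token's density is estimated from its k nearest neighbours, its separation is the distance to the nearest higher-density token, and tokens whose density times separation exceeds the threshold become cluster centers (falling back to the top min_cluster_num scores). A self-contained illustration of that scoring, independent of the class; names and values are illustrative:

import torch

def dpc_knn_centers(x, k=5, threshold=0.5, min_centers=4):
    # x: (N, C) token features
    N, C = x.shape
    dist = torch.cdist(x, x) / (C ** 0.5)                    # N * N pairwise distances
    dist_nearest, _ = torch.topk(dist, k=k, dim=-1, largest=False)
    density = (-(dist_nearest ** 2).mean(dim=-1)).exp()      # local density per token
    higher = density[None, :] > density[:, None]             # mask of strictly denser tokens
    delta = torch.where(higher, dist, dist.max()).min(dim=-1).values   # gap to nearest denser token
    score = delta * density
    centers = torch.nonzero(score > threshold).flatten()
    if centers.numel() == 0:                                 # fallback: keep the top-scoring tokens
        centers = torch.topk(score, k=min_centers).indices.sort().values
    return centers

centers = dpc_knn_centers(torch.randn(64, 16))
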
/src/model/setok/utils.py:
--------------------------------------------------------------------------------
1 |
2 | import torch
3 |
4 |
5 | def get_emb(sin_inp):
6 | """
7 | Gets a base embedding for one dimension with sin and cos intertwined
8 | """
9 | emb = torch.stack((sin_inp.sin(), sin_inp.cos()), dim=-1)
10 | return torch.flatten(emb, -2, -1)
11 |
12 |
13 |
14 | def mask_by_order(mask_len, order, bsz, seq_len):
15 | masking = torch.zeros(bsz, seq_len).cuda()
16 | masking = torch.scatter(masking, dim=-1, index=order[:, :mask_len.long()], src=torch.ones(bsz, seq_len).cuda()).bool()
17 | return masking
--------------------------------------------------------------------------------
/src/model/utils.py:
--------------------------------------------------------------------------------
1 | from transformers import AutoConfig
2 | import importlib
3 |
4 |
5 |
6 | def instantiate_from_config(config):
7 | if not "target" in config:
8 | if config == '__is_first_stage__':
9 | return None
10 | elif config == "__is_unconditional__":
11 | return None
12 | raise KeyError("Expected key `target` to instantiate.")
13 | return get_obj_from_str(config["target"])(**config.get("params", dict()))
14 |
15 |
16 | def get_obj_from_str(string, reload=False):
17 | module, cls = string.rsplit(".", 1)
18 | if reload:
19 | module_imp = importlib.import_module(module)
20 | importlib.reload(module_imp)
21 | return getattr(importlib.import_module(module, package=None), cls)
22 |
23 |
24 | def auto_upgrade(config):
25 | cfg = AutoConfig.from_pretrained(config)
26 | if 'llava' in config and 'llava' not in cfg.model_type:
27 | assert cfg.model_type == 'llama'
28 | print("You are using newer LLaVA code base, while the checkpoint of v0 is from older code base.")
29 | print("You must upgrade the checkpoint to the new code base (this can be done automatically).")
30 | confirm = input("Please confirm that you want to upgrade the checkpoint. [Y/N]")
31 | if confirm.lower() in ["y", "yes"]:
32 | print("Upgrading checkpoint...")
33 | assert len(cfg.architectures) == 1
34 | setattr(cfg.__class__, "model_type", "llava")
35 | cfg.architectures[0] = 'LlavaLlamaForCausalLM'
36 | cfg.save_pretrained(config)
37 | print("Checkpoint upgraded.")
38 | else:
39 | print("Checkpoint upgrade aborted.")
40 | exit(1)
41 |
--------------------------------------------------------------------------------
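
instantiate_from_config expects a dict with a dotted 'target' class path and optional 'params'. A minimal sketch with an illustrative target:

cfg = {"target": "torch.nn.Linear", "params": {"in_features": 8, "out_features": 4}}
layer = instantiate_from_config(cfg)                         # equivalent to torch.nn.Linear(8, 4)
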
/src/train/llama_flash_attn_monkey_patch.py:
--------------------------------------------------------------------------------
1 | from typing import Optional, Tuple
2 | import warnings
3 |
4 | import torch
5 |
6 | import transformers
7 | from transformers.models.llama.modeling_llama import apply_rotary_pos_emb, repeat_kv
8 |
9 | try:
10 | from flash_attn.flash_attn_interface import flash_attn_unpadded_qkvpacked_func
11 | except ImportError:
12 | from flash_attn.flash_attn_interface import flash_attn_varlen_qkvpacked_func as flash_attn_unpadded_qkvpacked_func
13 | from flash_attn.bert_padding import unpad_input, pad_input
14 |
15 |
16 | def forward(
17 | self,
18 | hidden_states: torch.Tensor,
19 | attention_mask: Optional[torch.Tensor] = None,
20 | position_ids: Optional[torch.Tensor] = None,
21 | past_key_value: Optional[Tuple[torch.Tensor]] = None,
22 | output_attentions: bool = False,
23 | use_cache: bool = False,
24 | ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
25 | if output_attentions:
26 | warnings.warn(
27 | "Output attentions is not supported for patched `LlamaAttention`, returning `None` instead."
28 | )
29 |
30 | bsz, q_len, _ = hidden_states.size()
31 |
32 | query_states = (
33 | self.q_proj(hidden_states)
34 | .view(bsz, q_len, self.num_heads, self.head_dim)
35 | .transpose(1, 2)
36 | )
37 | key_states = (
38 | self.k_proj(hidden_states)
39 | .view(bsz, q_len, self.num_key_value_heads, self.head_dim)
40 | .transpose(1, 2)
41 | )
42 | value_states = (
43 | self.v_proj(hidden_states)
44 | .view(bsz, q_len, self.num_key_value_heads, self.head_dim)
45 | .transpose(1, 2)
46 | ) # shape: (b, num_heads, s, head_dim)
47 |
48 | kv_seq_len = key_states.shape[-2]
49 | if past_key_value is not None:
50 | kv_seq_len += past_key_value[0].shape[-2]
51 |
52 | cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
53 | query_states, key_states = apply_rotary_pos_emb(
54 | query_states, key_states, cos, sin, position_ids
55 | )
56 |
57 | if past_key_value is not None:
58 | # reuse k, v
59 | key_states = torch.cat([past_key_value[0], key_states], dim=2)
60 | value_states = torch.cat([past_key_value[1], value_states], dim=2)
61 |
62 | past_key_value = (key_states, value_states) if use_cache else None
63 |
64 | # repeat k/v heads if n_kv_heads < n_heads
65 | key_states = repeat_kv(key_states, self.num_key_value_groups)
66 | value_states = repeat_kv(value_states, self.num_key_value_groups)
67 |
68 | # Transform the data into the format required by flash attention
69 | qkv = torch.stack([query_states, key_states, value_states], dim=2)
70 | qkv = qkv.transpose(1, 3) # shape: [b, s, 3, num_heads, head_dim]
71 | key_padding_mask = attention_mask
72 |
73 | if key_padding_mask is None:
74 | qkv = qkv.reshape(-1, 3, self.num_heads, self.head_dim)
75 | cu_q_lens = torch.arange(
76 | 0, (bsz + 1) * q_len, step=q_len, dtype=torch.int32, device=qkv.device
77 | )
78 | max_s = q_len
79 | output = flash_attn_unpadded_qkvpacked_func(
80 | qkv, cu_q_lens, max_s, 0.0, softmax_scale=None, causal=True
81 | )
82 | output = output.view(bsz, q_len, -1)
83 | else:
84 | qkv = qkv.reshape(bsz, q_len, -1)
85 | qkv, indices, cu_q_lens, max_s = unpad_input(qkv, key_padding_mask)
86 | qkv = qkv.view(-1, 3, self.num_heads, self.head_dim)
87 | output_unpad = flash_attn_unpadded_qkvpacked_func(
88 | qkv, cu_q_lens, max_s, 0.0, softmax_scale=None, causal=True
89 | )
90 | output_unpad = output_unpad.reshape(-1, self.num_heads * self.head_dim)
91 | output = pad_input(output_unpad, indices, bsz, q_len)
92 |
93 | return self.o_proj(output), None, past_key_value
94 |
95 |
96 | # Disable the transformation of the attention mask in LlamaModel as the flash attention
97 | # requires the attention mask to be the same as the key_padding_mask
98 | def _prepare_decoder_attention_mask(
99 | self, attention_mask, input_shape, inputs_embeds, past_key_values_length
100 | ):
101 | # [bsz, seq_len]
102 | return attention_mask
103 |
104 |
105 | def replace_llama_attn_with_flash_attn():
106 | cuda_major, cuda_minor = torch.cuda.get_device_capability()
107 | if cuda_major < 8:
108 | warnings.warn(
109 | "Flash attention is only supported on A100 or H100 GPU during training due to head dim > 64 backward."
110 | "ref: https://github.com/HazyResearch/flash-attention/issues/190#issuecomment-1523359593"
111 | )
112 | transformers.models.llama.modeling_llama.LlamaModel._prepare_decoder_attention_mask = (
113 | _prepare_decoder_attention_mask
114 | )
115 | transformers.models.llama.modeling_llama.LlamaAttention.forward = forward
116 |
--------------------------------------------------------------------------------
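
The monkey patch has to be applied before the LLaMA model is instantiated so the patched forward and attention-mask hook take effect; train_mem.py in this repo presumably does this. A minimal sketch (the checkpoint name is illustrative):

from src.train.llama_flash_attn_monkey_patch import replace_llama_attn_with_flash_attn

replace_llama_attn_with_flash_attn()                         # patch LlamaAttention.forward first

from transformers import AutoModelForCausalLM
model = AutoModelForCausalLM.from_pretrained('meta-llama/Llama-2-7b-hf', torch_dtype='auto')
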
/src/train/llama_xformers_attn_monkey_patch.py:
--------------------------------------------------------------------------------
1 | """
2 | Directly copied the code from https://raw.githubusercontent.com/oobabooga/text-generation-webui/main/modules/llama_attn_hijack.py and made some adjustments
3 | """
4 |
5 | import logging
6 | import math
7 | from typing import Optional, Tuple
8 |
9 | import torch
10 | import transformers.models.llama.modeling_llama
11 | from torch import nn
12 |
13 | try:
14 | import xformers.ops
15 | except ImportError:
16 | logging.error("xformers not found! Please install it before trying to use it.")
17 |
18 |
19 | def replace_llama_attn_with_xformers_attn():
20 | transformers.models.llama.modeling_llama.LlamaAttention.forward = xformers_forward
21 |
22 |
23 | def xformers_forward(
24 | self,
25 | hidden_states: torch.Tensor,
26 | attention_mask: Optional[torch.Tensor] = None,
27 | position_ids: Optional[torch.LongTensor] = None,
28 | past_key_value: Optional[Tuple[torch.Tensor]] = None,
29 | output_attentions: bool = False,
30 | use_cache: bool = False,
31 | ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
32 | # pylint: disable=duplicate-code
33 | bsz, q_len, _ = hidden_states.size()
34 |
35 | query_states = (
36 | self.q_proj(hidden_states)
37 | .view(bsz, q_len, self.num_heads, self.head_dim)
38 | .transpose(1, 2)
39 | )
40 | key_states = (
41 | self.k_proj(hidden_states)
42 | .view(bsz, q_len, self.num_heads, self.head_dim)
43 | .transpose(1, 2)
44 | )
45 | value_states = (
46 | self.v_proj(hidden_states)
47 | .view(bsz, q_len, self.num_heads, self.head_dim)
48 | .transpose(1, 2)
49 | )
50 |
51 | kv_seq_len = key_states.shape[-2]
52 | if past_key_value is not None:
53 | kv_seq_len += past_key_value[0].shape[-2]
54 | cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
55 | (
56 | query_states,
57 | key_states,
58 | ) = transformers.models.llama.modeling_llama.apply_rotary_pos_emb(
59 | query_states, key_states, cos, sin, position_ids
60 | )
61 | # [bsz, nh, t, hd]
62 |
63 | if past_key_value is not None:
64 | # reuse k, v, self_attention
65 | key_states = torch.cat([past_key_value[0], key_states], dim=2)
66 | value_states = torch.cat([past_key_value[1], value_states], dim=2)
67 |
68 | past_key_value = (key_states, value_states) if use_cache else None
69 |
70 | # We only apply xformers optimizations if we don't need to output the whole attention matrix
71 | if not output_attentions:
72 | query_states = query_states.transpose(1, 2)
73 | key_states = key_states.transpose(1, 2)
74 | value_states = value_states.transpose(1, 2)
75 |
76 | # This is a nasty hack. We know attention_mask in transformers is either LowerTriangular or all Zeros.
77 | # We therefore check if one element in the upper triangular portion is zero. If it is, then the mask is all zeros.
78 | if attention_mask is None or attention_mask[0, 0, 0, 1] == 0:
79 | # input and output should be of form (bsz, q_len, num_heads, head_dim)
80 | attn_output = xformers.ops.memory_efficient_attention(
81 | query_states, key_states, value_states, attn_bias=None
82 | )
83 | else:
84 | # input and output should be of form (bsz, q_len, num_heads, head_dim)
85 | attn_output = xformers.ops.memory_efficient_attention(
86 | query_states,
87 | key_states,
88 | value_states,
89 | attn_bias=xformers.ops.LowerTriangularMask(),
90 | )
91 | attn_weights = None
92 | else:
93 | attn_weights = torch.matmul(
94 | query_states, key_states.transpose(2, 3)
95 | ) / math.sqrt(self.head_dim)
96 |
97 | if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len):
98 | raise ValueError(
99 |                 f"Attention weights should be of size {(bsz, self.num_heads, q_len, kv_seq_len)}, but is"
100 | f" {attn_weights.size()}"
101 | )
102 |
103 | if attention_mask is not None:
104 | if attention_mask.size() != (bsz, 1, q_len, kv_seq_len):
105 | raise ValueError(
106 | f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}"
107 | )
108 | attn_weights = attn_weights + attention_mask
109 | attn_weights = torch.max(
110 | attn_weights, torch.tensor(torch.finfo(attn_weights.dtype).min)
111 | )
112 |
113 | # upcast attention to fp32
114 | attn_weights = nn.functional.softmax(
115 | attn_weights, dim=-1, dtype=torch.float32
116 | ).to(query_states.dtype)
117 | attn_output = torch.matmul(attn_weights, value_states)
118 |
119 | if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
120 | raise ValueError(
121 | f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is"
122 | f" {attn_output.size()}"
123 | )
124 |
125 | attn_output = attn_output.transpose(1, 2)
126 |
127 | attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
128 | attn_output = self.o_proj(attn_output)
129 | return attn_output, attn_weights, past_key_value
130 |
--------------------------------------------------------------------------------
/src/train/setok_trainer.py:
--------------------------------------------------------------------------------
1 | import os
2 | import torch
3 | import torch.nn as nn
4 | import math
5 | from torch.utils.data import Sampler
6 | import time
7 | import sys
8 |
9 | from transformers import Trainer
10 | from transformers.trainer import (
11 | is_sagemaker_mp_enabled,
12 | get_parameter_names,
13 | has_length,
14 | ALL_LAYERNORM_LAYERS,
15 | logger,
16 | )
17 | # from transformers.trainer import *
18 | from typing import List, Optional, Dict, Union, Any
19 |
20 |
21 | def maybe_zero_3(param, ignore_status=False, name=None):
22 | from deepspeed import zero
23 | from deepspeed.runtime.zero.partition_parameters import ZeroParamStatus
24 | if hasattr(param, "ds_id"):
25 | if param.ds_status == ZeroParamStatus.NOT_AVAILABLE:
26 | if not ignore_status:
27 | print(name, 'no ignore status')
28 | with zero.GatheredParameters([param]):
29 | param = param.data.detach().cpu().clone()
30 | else:
31 | param = param.detach().cpu().clone()
32 | return param
33 |
34 |
35 | def get_mm_adapter_state_maybe_zero_3(named_params, keys_to_match):
36 | to_return = {k: t for k, t in named_params if any(key_match in k for key_match in keys_to_match)}
37 | to_return = {k: maybe_zero_3(v, ignore_status=True, name=k).cpu() for k, v in to_return.items()}
38 | return to_return
39 |
40 |
41 | def split_to_even_chunks(indices, lengths, num_chunks):
42 | """
43 | Split a list of indices into `chunks` chunks of roughly equal lengths.
44 | """
45 |
46 | if len(indices) % num_chunks != 0:
47 | return [indices[i::num_chunks] for i in range(num_chunks)]
48 |
49 | num_indices_per_chunk = len(indices) // num_chunks
50 |
51 | chunks = [[] for _ in range(num_chunks)]
52 | chunks_lengths = [0 for _ in range(num_chunks)]
53 | for index in indices:
54 | shortest_chunk = chunks_lengths.index(min(chunks_lengths))
55 | chunks[shortest_chunk].append(index)
56 | chunks_lengths[shortest_chunk] += lengths[index]
57 | if len(chunks[shortest_chunk]) == num_indices_per_chunk:
58 | chunks_lengths[shortest_chunk] = float("inf")
59 |
60 | return chunks
61 |
62 |
63 | def get_modality_length_grouped_indices(lengths, batch_size, world_size, generator=None):
64 | # We need to use torch for the random part as a distributed sampler will set the random seed for torch.
65 | assert all(l != 0 for l in lengths), "Should not have zero length."
66 | if all(l > 0 for l in lengths) or all(l < 0 for l in lengths):
67 | # all samples are in the same modality
68 | return get_length_grouped_indices(lengths, batch_size, world_size, generator=generator)
69 | mm_indices, mm_lengths = zip(*[(i, l) for i, l in enumerate(lengths) if l > 0])
70 | lang_indices, lang_lengths = zip(*[(i, -l) for i, l in enumerate(lengths) if l < 0])
71 |
72 | mm_shuffle = [mm_indices[i] for i in get_length_grouped_indices(mm_lengths, batch_size, world_size, generator=None)]
73 | lang_shuffle = [lang_indices[i] for i in get_length_grouped_indices(lang_lengths, batch_size, world_size, generator=None)]
74 | megabatch_size = world_size * batch_size
75 | mm_megabatches = [mm_shuffle[i : i + megabatch_size] for i in range(0, len(mm_shuffle), megabatch_size)]
76 | lang_megabatches = [lang_shuffle[i : i + megabatch_size] for i in range(0, len(lang_shuffle), megabatch_size)]
77 |
78 | last_mm = mm_megabatches[-1]
79 | last_lang = lang_megabatches[-1]
80 | additional_batch = last_mm + last_lang
81 | megabatches = mm_megabatches[:-1] + lang_megabatches[:-1]
82 | megabatch_indices = torch.randperm(len(megabatches), generator=generator)
83 | megabatches = [megabatches[i] for i in megabatch_indices]
84 |
85 | if len(additional_batch) > 0:
86 | megabatches.append(sorted(additional_batch))
87 |
88 | return [i for megabatch in megabatches for i in megabatch]
89 |
90 |
91 | def get_length_grouped_indices(lengths, batch_size, world_size, generator=None, merge=True):
92 | # We need to use torch for the random part as a distributed sampler will set the random seed for torch.
93 | indices = torch.randperm(len(lengths), generator=generator)
94 | megabatch_size = world_size * batch_size
95 | megabatches = [indices[i : i + megabatch_size].tolist() for i in range(0, len(lengths), megabatch_size)]
96 | megabatches = [sorted(megabatch, key=lambda i: lengths[i], reverse=True) for megabatch in megabatches]
97 | megabatches = [split_to_even_chunks(megabatch, lengths, world_size) for megabatch in megabatches]
98 |
99 | return [i for megabatch in megabatches for batch in megabatch for i in batch]
100 |
101 |
102 | class LengthGroupedSampler(Sampler):
103 | r"""
104 | Sampler that samples indices in a way that groups together features of the dataset of roughly the same length while
105 | keeping a bit of randomness.
106 | """
107 |
108 | def __init__(
109 | self,
110 | batch_size: int,
111 | world_size: int,
112 | lengths: Optional[List[int]] = None,
113 | generator=None,
114 | group_by_modality: bool = False,
115 | ):
116 | if lengths is None:
117 | raise ValueError("Lengths must be provided.")
118 |
119 | self.batch_size = batch_size
120 | self.world_size = world_size
121 | self.lengths = lengths
122 | self.generator = generator
123 | self.group_by_modality = group_by_modality
124 |
125 | def __len__(self):
126 | return len(self.lengths)
127 |
128 | def __iter__(self):
129 | if self.group_by_modality:
130 | indices = get_modality_length_grouped_indices(self.lengths, self.batch_size, self.world_size, generator=self.generator)
131 | else:
132 | indices = get_length_grouped_indices(self.lengths, self.batch_size, self.world_size, generator=self.generator)
133 | return iter(indices)
134 |
135 |
136 | class SetokTrainer(Trainer):
137 |
138 | def _get_train_sampler(self) -> Optional[torch.utils.data.Sampler]:
139 | if self.train_dataset is None or not has_length(self.train_dataset):
140 | return None
141 |
142 | if self.args.group_by_modality_length:
143 | lengths = self.train_dataset.modality_lengths
144 | return LengthGroupedSampler(
145 | self.args.train_batch_size,
146 | world_size=self.args.world_size * self.args.gradient_accumulation_steps,
147 | lengths=lengths,
148 | group_by_modality=True,
149 | )
150 | else:
151 | return super()._get_train_sampler()
152 |
153 | def create_optimizer(self):
154 | """
155 | Setup the optimizer.
156 |
157 | We provide a reasonable default that works well. If you want to use something else, you can pass a tuple in the
158 | Trainer's init through `optimizers`, or subclass and override this method in a subclass.
159 | """
160 | if is_sagemaker_mp_enabled():
161 | return super().create_optimizer()
162 |
163 | opt_model = self.model
164 |
165 | if self.optimizer is None:
166 | optimizer_cls, optimizer_kwargs = Trainer.get_optimizer_cls_and_kwargs(self.args)
167 |
168 | self.optimizer = optimizer_cls(opt_model.parameters(), **optimizer_kwargs)
169 | if optimizer_cls.__name__ == "Adam8bit":
170 | import bitsandbytes
171 |
172 | manager = bitsandbytes.optim.GlobalOptimManager.get_instance()
173 |
174 | skipped = 0
175 | for module in opt_model.modules():
176 | if isinstance(module, nn.Embedding):
177 | skipped += sum({p.data_ptr(): p.numel() for p in module.parameters()}.values())
178 | logger.info(f"skipped {module}: {skipped/2**20}M params")
179 | manager.register_module_override(module, "weight", {"optim_bits": 32})
180 | logger.debug(f"bitsandbytes: will optimize {module} in fp32")
181 | logger.info(f"skipped: {skipped/2**20}M params")
182 |
183 | return self.optimizer
184 |
185 | def _save_checkpoint(self, model, trial, metrics=None):
186 | if getattr(self.args, 'tune_mm_mlp_adapter', False):
187 | from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR
188 | checkpoint_folder = f"{PREFIX_CHECKPOINT_DIR}-{self.state.global_step}"
189 |
190 | run_dir = self._get_output_dir(trial=trial)
191 | output_dir = os.path.join(run_dir, checkpoint_folder)
192 |
193 | # Only save Adapter
194 | keys_to_match = ['mm_projector', 'vision_resampler']
195 | if getattr(self.args, "use_im_start_end", False):
196 | keys_to_match.extend(['embed_tokens', 'embed_in'])
197 |
198 | weight_to_save = get_mm_adapter_state_maybe_zero_3(self.model.named_parameters(), keys_to_match)
199 |
200 | if self.args.local_rank == 0 or self.args.local_rank == -1:
201 | self.model.config.save_pretrained(output_dir)
202 | torch.save(weight_to_save, os.path.join(output_dir, f'mm_projector.bin'))
203 | else:
204 |             super(SetokTrainer, self)._save_checkpoint(model, trial, metrics)
205 |
206 | def _save(self, output_dir: Optional[str] = None, state_dict=None):
207 | if getattr(self.args, 'tune_mm_mlp_adapter', False):
208 | pass
209 | else:
210 |             super(SetokTrainer, self)._save(output_dir, state_dict)
--------------------------------------------------------------------------------
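Note: a minimal sanity-check sketch for the length-grouped sampling helpers above, assuming the repo root is on PYTHONPATH so `src.train.setok_trainer` imports cleanly (illustrative, not part of the original sources):

```python
import torch
from src.train.setok_trainer import LengthGroupedSampler, get_length_grouped_indices

lengths = [12, 7, 30, 5, 22, 18, 9, 27]  # per-sample sequence lengths
g = torch.Generator().manual_seed(0)

# Each world_size * batch_size megabatch is sorted by length and then split into
# per-rank chunks of roughly equal total length.
print(get_length_grouped_indices(lengths, batch_size=2, world_size=2, generator=g))

sampler = LengthGroupedSampler(batch_size=2, world_size=2, lengths=lengths)
print(list(iter(sampler)))  # a permutation of range(len(lengths))
```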
/src/train/setokim_trainer.py:
--------------------------------------------------------------------------------
1 | import os
2 | import torch
3 | import torch.nn as nn
4 | import math
5 | from torch.utils.data import Sampler
6 | import time
7 | import sys
8 |
9 | from transformers import Trainer
10 | from transformers.trainer import (
11 | is_sagemaker_mp_enabled,
12 | get_parameter_names,
13 | has_length,
14 | ALL_LAYERNORM_LAYERS,
15 | logger,
16 | )
17 | # from transformers.trainer import *
18 | from typing import List, Optional, Dict, Union, Any
19 |
20 |
21 | def maybe_zero_3(param, ignore_status=False, name=None):
22 | from deepspeed import zero
23 | from deepspeed.runtime.zero.partition_parameters import ZeroParamStatus
24 | if hasattr(param, "ds_id"):
25 | if param.ds_status == ZeroParamStatus.NOT_AVAILABLE:
26 | if not ignore_status:
27 | print(name, 'no ignore status')
28 | with zero.GatheredParameters([param]):
29 | param = param.data.detach().cpu().clone()
30 | else:
31 | param = param.detach().cpu().clone()
32 | return param
33 |
34 |
35 | def get_mm_adapter_state_maybe_zero_3(named_params, keys_to_match):
36 | to_return = {k: t for k, t in named_params if any(key_match in k for key_match in keys_to_match)}
37 | to_return = {k: maybe_zero_3(v, ignore_status=True, name=k).cpu() for k, v in to_return.items()}
38 | return to_return
39 |
40 |
41 | def split_to_even_chunks(indices, lengths, num_chunks):
42 | """
43 |     Split a list of indices into `num_chunks` chunks of roughly equal total length.
44 | """
45 |
46 | if len(indices) % num_chunks != 0:
47 | return [indices[i::num_chunks] for i in range(num_chunks)]
48 |
49 | num_indices_per_chunk = len(indices) // num_chunks
50 |
51 | chunks = [[] for _ in range(num_chunks)]
52 | chunks_lengths = [0 for _ in range(num_chunks)]
53 | for index in indices:
54 | shortest_chunk = chunks_lengths.index(min(chunks_lengths))
55 | chunks[shortest_chunk].append(index)
56 | chunks_lengths[shortest_chunk] += lengths[index]
57 | if len(chunks[shortest_chunk]) == num_indices_per_chunk:
58 | chunks_lengths[shortest_chunk] = float("inf")
59 |
60 | return chunks
61 |
62 |
63 | def get_modality_length_grouped_indices(lengths, batch_size, world_size, generator=None):
64 | # We need to use torch for the random part as a distributed sampler will set the random seed for torch.
65 | assert all(l != 0 for l in lengths), "Should not have zero length."
66 | if all(l > 0 for l in lengths) or all(l < 0 for l in lengths):
67 | # all samples are in the same modality
68 | return get_length_grouped_indices(lengths, batch_size, world_size, generator=generator)
69 | mm_indices, mm_lengths = zip(*[(i, l) for i, l in enumerate(lengths) if l > 0])
70 | lang_indices, lang_lengths = zip(*[(i, -l) for i, l in enumerate(lengths) if l < 0])
71 |
72 | mm_shuffle = [mm_indices[i] for i in get_length_grouped_indices(mm_lengths, batch_size, world_size, generator=None)]
73 | lang_shuffle = [lang_indices[i] for i in get_length_grouped_indices(lang_lengths, batch_size, world_size, generator=None)]
74 | megabatch_size = world_size * batch_size
75 | mm_megabatches = [mm_shuffle[i : i + megabatch_size] for i in range(0, len(mm_shuffle), megabatch_size)]
76 | lang_megabatches = [lang_shuffle[i : i + megabatch_size] for i in range(0, len(lang_shuffle), megabatch_size)]
77 |
78 | last_mm = mm_megabatches[-1]
79 | last_lang = lang_megabatches[-1]
80 | additional_batch = last_mm + last_lang
81 | megabatches = mm_megabatches[:-1] + lang_megabatches[:-1]
82 | megabatch_indices = torch.randperm(len(megabatches), generator=generator)
83 | megabatches = [megabatches[i] for i in megabatch_indices]
84 |
85 | if len(additional_batch) > 0:
86 | megabatches.append(sorted(additional_batch))
87 |
88 | return [i for megabatch in megabatches for i in megabatch]
89 |
90 |
91 | def get_length_grouped_indices(lengths, batch_size, world_size, generator=None, merge=True):
92 | # We need to use torch for the random part as a distributed sampler will set the random seed for torch.
93 | indices = torch.randperm(len(lengths), generator=generator)
94 | megabatch_size = world_size * batch_size
95 | megabatches = [indices[i : i + megabatch_size].tolist() for i in range(0, len(lengths), megabatch_size)]
96 | megabatches = [sorted(megabatch, key=lambda i: lengths[i], reverse=True) for megabatch in megabatches]
97 | megabatches = [split_to_even_chunks(megabatch, lengths, world_size) for megabatch in megabatches]
98 |
99 | return [i for megabatch in megabatches for batch in megabatch for i in batch]
100 |
101 |
102 | class LengthGroupedSampler(Sampler):
103 | r"""
104 | Sampler that samples indices in a way that groups together features of the dataset of roughly the same length while
105 | keeping a bit of randomness.
106 | """
107 |
108 | def __init__(
109 | self,
110 | batch_size: int,
111 | world_size: int,
112 | lengths: Optional[List[int]] = None,
113 | generator=None,
114 | group_by_modality: bool = False,
115 | ):
116 | if lengths is None:
117 | raise ValueError("Lengths must be provided.")
118 |
119 | self.batch_size = batch_size
120 | self.world_size = world_size
121 | self.lengths = lengths
122 | self.generator = generator
123 | self.group_by_modality = group_by_modality
124 |
125 | def __len__(self):
126 | return len(self.lengths)
127 |
128 | def __iter__(self):
129 | if self.group_by_modality:
130 | indices = get_modality_length_grouped_indices(self.lengths, self.batch_size, self.world_size, generator=self.generator)
131 | else:
132 | indices = get_length_grouped_indices(self.lengths, self.batch_size, self.world_size, generator=self.generator)
133 | return iter(indices)
134 |
135 |
136 | class SetokimTrainer(Trainer):
137 |
138 | def _get_train_sampler(self) -> Optional[torch.utils.data.Sampler]:
139 | if self.train_dataset is None or not has_length(self.train_dataset):
140 | return None
141 |
142 | if self.args.group_by_modality_length:
143 | lengths = self.train_dataset.modality_lengths
144 | return LengthGroupedSampler(
145 | self.args.train_batch_size,
146 | world_size=self.args.world_size * self.args.gradient_accumulation_steps,
147 | lengths=lengths,
148 | group_by_modality=True,
149 | )
150 | else:
151 | return super()._get_train_sampler()
152 |
153 | def create_optimizer(self):
154 | """
155 | Setup the optimizer.
156 |
157 | We provide a reasonable default that works well. If you want to use something else, you can pass a tuple in the
158 | Trainer's init through `optimizers`, or subclass and override this method in a subclass.
159 | """
160 | if is_sagemaker_mp_enabled():
161 | return super().create_optimizer()
162 |
163 | opt_model = self.model
164 |
165 | if self.optimizer is None:
166 | decay_parameters = get_parameter_names(opt_model, ALL_LAYERNORM_LAYERS)
167 | decay_parameters = [name for name in decay_parameters if "bias" not in name]
168 | if self.args.mm_in_projector_lr is not None:
169 | projector_parameters = [name for name, _ in opt_model.named_parameters() if "mm_in_projector" in name]
170 | if self.args.mm_out_projector_lr is not None:
171 | projector_parameters.extend([name for name, _ in opt_model.named_parameters() if "mm_out_projector" in name ])
172 | optimizer_grouped_parameters = [
173 | {
174 | "params": [
175 | p for n, p in opt_model.named_parameters() if (n in decay_parameters and n not in projector_parameters and p.requires_grad)
176 | ],
177 | "weight_decay": self.args.weight_decay,
178 | },
179 | {
180 | "params": [
181 | p for n, p in opt_model.named_parameters() if (n not in decay_parameters and n not in projector_parameters and p.requires_grad)
182 | ],
183 | "weight_decay": 0.0,
184 | },
185 | {
186 | "params": [
187 | p for n, p in opt_model.named_parameters() if (n in decay_parameters and n in projector_parameters and p.requires_grad)
188 | ],
189 | "weight_decay": self.args.weight_decay,
190 |                     "lr": self.args.mm_in_projector_lr,  # TrainingArguments defines mm_in/out_projector_lr (no mm_projector_lr); both projector groups share this value here
191 | },
192 | {
193 | "params": [
194 | p for n, p in opt_model.named_parameters() if (n not in decay_parameters and n in projector_parameters and p.requires_grad)
195 | ],
196 | "weight_decay": 0.0,
197 |                     "lr": self.args.mm_in_projector_lr,
198 | },
199 | ]
200 | else:
201 | optimizer_grouped_parameters = [
202 | {
203 | "params": [
204 | p for n, p in opt_model.named_parameters() if (n in decay_parameters and p.requires_grad)
205 | ],
206 | "weight_decay": self.args.weight_decay,
207 | },
208 | {
209 | "params": [
210 | p for n, p in opt_model.named_parameters() if (n not in decay_parameters and p.requires_grad)
211 | ],
212 | "weight_decay": 0.0,
213 | },
214 | ]
215 |
216 | optimizer_cls, optimizer_kwargs = Trainer.get_optimizer_cls_and_kwargs(self.args)
217 |
218 | self.optimizer = optimizer_cls(optimizer_grouped_parameters, **optimizer_kwargs)
219 | if optimizer_cls.__name__ == "Adam8bit":
220 | import bitsandbytes
221 | manager = bitsandbytes.optim.GlobalOptimManager.get_instance()
222 |
223 | skipped = 0
224 | for module in opt_model.modules():
225 | if isinstance(module, nn.Embedding):
226 | skipped += sum({p.data_ptr(): p.numel() for p in module.parameters()}.values())
227 | logger.info(f"skipped {module}: {skipped/2**20}M params")
228 | manager.register_module_override(module, "weight", {"optim_bits": 32})
229 | logger.debug(f"bitsandbytes: will optimize {module} in fp32")
230 | logger.info(f"skipped: {skipped/2**20}M params")
231 |
232 | return self.optimizer
233 |
234 | def _save_checkpoint(self, model, trial, metrics=None):
235 | if getattr(self.args, 'tune_mm_mlp_adapter', False):
236 | from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR
237 | checkpoint_folder = f"{PREFIX_CHECKPOINT_DIR}-{self.state.global_step}"
238 |
239 | run_dir = self._get_output_dir(trial=trial)
240 | output_dir = os.path.join(run_dir, checkpoint_folder)
241 |
242 | # Only save Adapter
243 | keys_to_match = ['mm_in_projector', 'mm_out_projector']
244 | if getattr(self.args, "use_im_start_end", False):
245 | keys_to_match.extend(['embed_tokens', 'embed_in'])
246 |
247 | weight_to_save = get_mm_adapter_state_maybe_zero_3(self.model.named_parameters(), keys_to_match)
248 |
249 | if self.args.local_rank == 0 or self.args.local_rank == -1:
250 | self.model.config.save_pretrained(output_dir)
251 | torch.save(weight_to_save, os.path.join(output_dir, f'mm_projector.bin'))
252 | else:
253 | super(SetokimTrainer, self)._save_checkpoint(model, trial, metrics)
254 |
255 | def _save(self, output_dir: Optional[str] = None, state_dict=None):
256 | if getattr(self.args, 'tune_mm_mlp_adapter', False):
257 | pass
258 | else:
259 | super(SetokimTrainer, self)._save(output_dir, state_dict)
--------------------------------------------------------------------------------
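Note: a sketch of the modality-aware grouping that kicks in when `--group_by_modality_length` is set; positive lengths mark multimodal samples and negative lengths mark text-only ones, per the convention in `get_modality_length_grouped_indices` (illustrative, not part of the original sources):

```python
import torch
from src.train.setokim_trainer import get_modality_length_grouped_indices

lengths = [15, 23, 9, 31, 11, 27, 14, 6,          # multimodal samples (positive)
           -8, -12, -20, -5, -17, -9, -25, -13]   # language-only samples (negative)

order = get_modality_length_grouped_indices(
    lengths, batch_size=2, world_size=2, generator=torch.Generator().manual_seed(0)
)
# Every full megabatch (world_size * batch_size = 4 indices) is single-modality;
# the leftover tail from both modalities is appended as one final mixed batch.
print(order)
```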
/src/train/train_mem.py:
--------------------------------------------------------------------------------
1 |
2 | if __name__ == "__main__":
3 | # training setokim
4 | # from train_setokim import train
5 | # train(attn_implementation="flash_attention_2")
6 |
7 | # training setok
8 | from train_setok import train
9 | train(attn_implementation="flash_attention_2")
10 |
--------------------------------------------------------------------------------
/src/train/train_setok.py:
--------------------------------------------------------------------------------
1 | # Adopted from https://github.com/lm-sys/FastChat. Below is the original copyright:
2 | # Adopted from tatsu-lab@stanford_alpaca. Below is the original copyright:
3 | # Copyright 2023 Rohan Taori, Ishaan Gulrajani, Tianyi Zhang, Yann Dubois, Xuechen Li
4 | #
5 | # Licensed under the Apache License, Version 2.0 (the "License");
6 | # you may not use this file except in compliance with the License.
7 | # You may obtain a copy of the License at
8 | #
9 | # http://www.apache.org/licenses/LICENSE-2.0
10 | #
11 | # Unless required by applicable law or agreed to in writing, software
12 | # distributed under the License is distributed on an "AS IS" BASIS,
13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 | # See the License for the specific language governing permissions and
15 | # limitations under the License.
16 |
17 | import os
18 | import copy
19 | from dataclasses import dataclass, field
20 | import json
21 | import logging
22 | import pathlib
23 | from typing import Dict, Optional, Sequence, List
24 | import torch
25 | import transformers
26 | import tokenizers
27 | from torch.utils.data import Dataset
28 | from src.model import SeTok
29 | from src.dataset import *
30 | from setok_trainer import SetokTrainer
31 | from training_utils import *
32 |
33 |
34 |
35 | def rank0_print(*args):
36 | if local_rank == 0:
37 | print(*args)
38 |
39 |
40 | from packaging import version
41 | IS_TOKENIZER_GREATER_THAN_0_14 = version.parse(tokenizers.__version__) >= version.parse('0.14')
42 |
43 |
44 | def safe_save_model_for_hf_trainer(trainer: transformers.Trainer, output_dir: str):
45 | """Collects the state dict and dump to disk."""
46 |
47 | trainer.model.config.save_pretrained(output_dir)
48 |
49 | current_folder = output_dir.split('/')[-1]
50 | parent_folder = os.path.dirname(output_dir)
51 | if trainer.args.local_rank == 0 or trainer.args.local_rank == -1:
52 | if current_folder.startswith('checkpoint-'):
53 | mm_projector_folder = os.path.join(parent_folder, "mm_projector")
54 | os.makedirs(mm_projector_folder, exist_ok=True)
55 | torch.save(trainer.model.state_dict(), os.path.join(mm_projector_folder, f'{current_folder}.bin'))
56 | else:
57 | torch.save(trainer.model.state_dict(), os.path.join(output_dir, f'mm_projector.bin'))
58 | return
59 |
60 |
61 | def make_supervised_data_module(tokenizer: transformers.PreTrainedTokenizer,
62 | data_args,
63 | siglipTokenizer: Optional[transformers.PreTrainedTokenizer]=None,
64 | ) -> Dict:
65 | """Make dataset and collator for supervised fine-tuning."""
66 | train_dataset = TextImagePairDataset(
67 | tokenizer=tokenizer,
68 | data_path=data_args.data_path,
69 | data_args=data_args,
70 | constrative_tokenizer=siglipTokenizer
71 | )
72 | data_collator = DataCollatorForSupervisedDataset(tokenizer=tokenizer, siglipTokenizer=siglipTokenizer)
73 | return dict(train_dataset=train_dataset,
74 | eval_dataset=None,
75 | data_collator=data_collator)
76 |
77 |
78 |
79 | def train(attn_implementation=None):
80 | global local_rank
81 | parser = transformers.HfArgumentParser(
82 | (ModelArguments, DataArguments, TrainingArguments, VisionTowerArguments, VisionGeneratorArguments, ReconstructionLossArguments, ConstrastiveLossArguments))
83 | model_args, data_args, training_args, vision_tower_args, vision_generator_args, rec_loss_args, constrative_loss_args = parser.parse_args_into_dataclasses()
84 | local_rank = training_args.local_rank
85 | compute_dtype = (torch.float16 if training_args.fp16 else (torch.bfloat16 if training_args.bf16 else torch.float32))
86 |
87 | print('model_args:', model_args)
88 | print('data_args: ', data_args)
89 |
90 |     model = SeTok(tokenizer_config=vision_tower_args,
91 |                   detokenizer_config=vision_generator_args,
92 |                   rec_loss_config=rec_loss_args,
93 |                   contrastive_loss_config=constrative_loss_args)
94 |
95 | tokenizer = transformers.AutoTokenizer.from_pretrained(
96 | model_args.model_name_or_path,
97 | cache_dir=training_args.cache_dir,
98 | model_max_length=training_args.model_max_length,
99 | padding_side="right",
100 | use_fast=False,
101 | )
102 |
103 | siglipTokenizer = transformers.AutoTokenizer.from_pretrained(
104 | vision_tower_args.vision_tower,
105 | cache_dir=training_args.cache_dir,
106 | padding_side="right",
107 | use_fast=False,
108 | )
109 |
110 |
111 | data_module = make_supervised_data_module(tokenizer=tokenizer,
112 | data_args=data_args,
113 | siglipTokenizer=siglipTokenizer)
114 | trainer = SetokTrainer(model=model,
115 | tokenizer=tokenizer,
116 | args=training_args,
117 | **data_module)
118 |
119 | if list(pathlib.Path(training_args.output_dir).glob("checkpoint-*")):
120 | trainer.train(resume_from_checkpoint=True)
121 | else:
122 | trainer.train()
123 | trainer.save_state()
124 | model.config.use_cache = True
125 | safe_save_model_for_hf_trainer(trainer=trainer,
126 | output_dir=training_args.output_dir)
127 |
128 |
129 | if __name__ == "__main__":
130 | train()
131 |
--------------------------------------------------------------------------------
/src/train/train_xformers.py:
--------------------------------------------------------------------------------
1 | # Make it more memory efficient by monkey patching the LLaMA model with xformers attention.
2 |
3 | # Need to call this before importing transformers.
4 | from src.train.llama_xformers_attn_monkey_patch import (
5 | replace_llama_attn_with_xformers_attn,
6 | )
7 |
8 | replace_llama_attn_with_xformers_attn()
9 |
10 | from src.train.train_setokim import train  # this repo ships train_setokim.py / train_setok.py rather than a train.py
11 |
12 | if __name__ == "__main__":
13 | train()
14 |
--------------------------------------------------------------------------------
/src/train/training_utils.py:
--------------------------------------------------------------------------------
1 | from dataclasses import dataclass, field
2 | from typing import Dict, Optional, Sequence, List, Union
3 | import transformers
4 |
5 |
6 | @dataclass
7 | class ModelArguments:
8 | model_name_or_path: Optional[str] = field(default="facebook/opt-125m")
9 | vision_tokenizer: Optional[str] = field(default='setok')
10 | vision_generator: Optional[str] = field(default='setok')
11 | diff_loss: Optional[str] = field(default='diff_loss')
12 | version: Optional[str] = field(default="v0")
13 | freeze_backbone: bool = field(default=False)
14 | mm_use_im_start_end: bool = field(default=True)
15 | mm_use_im_patch_token: bool = field(default=True)
16 | mm_patch_merge_type: Optional[str] = field(default='flat')
17 |
18 |
19 | @dataclass
20 | class VisionTowerArguments:
21 | vision_tower: Optional[str] = field(default='google/siglip-so400m-patch14-384')
22 | pretrain_vision_tokenizer: Optional[str] = field(default='')
23 | unfreeze_mm_vision_tower: bool = field(default=False)
24 | mm_vision_select_feature: Optional[str] = field(default="patch")
25 | mm_vision_select_layer: Optional[int] = field(default=-1)
26 | delay_load: bool = field(default=False)
27 | hidden_dim: Optional[int] = field(default=4096)
28 | token_feat_dim: Optional[int] = field(default=4096)
29 | min_cluster_num: Optional[int] = field(default=64)
30 | threshold: Optional[float] = field(default=0.55)
31 | nheads: Optional[int] = field(default=2)
32 | dim_feedforward: Optional[int] = field(default=4096)
33 | proj_drop: Optional[float] = field(default=0.2)
34 | attn_drop: Optional[float] = field(default=0.0)
35 | inner_cluster_layers: Optional[int] = field(default=2)
36 | intra_cluster_layers: Optional[int] = field(default=2)
37 |
38 | @dataclass
39 | class VisionInProjectionArguments:
40 | pretrain_mm_in_mlp_adapter: Optional[str] = field(default=None)
41 | mm_in_projector_type: Optional[str] = field(default='mlp')
42 | mm_hidden_size: Optional[int] = field(default=1052)
43 | hidden_size: Optional[int] = field(default=4096)
44 |
45 | @dataclass
46 | class VisionGeneratorArguments:
47 | pretrain_vision_detokenizer: Optional[str] = field(default='')
48 | patch_size: Optional[int] = field(default=14)
49 | out_image_size: Optional[int] = field(default=384)
50 | decoder_embed_dim: Optional[int] = field(default=4096)
51 | decoder_nheads: Optional[int] = field(default=8)
52 | decoder_depth: Optional[int] = field(default=16)
53 | mlp_ratio: Optional[int] = field(default=4.0)
54 | feature_mapper_path_or_name: Optional[str] = field(default="bert-base-uncased")
55 | num_hidden_layers: Optional[int] = field(default=6)
56 | cross_attention_freq: Optional[int] = field(default=2)
57 | initializer_range: Optional[float] = field(default=0.02)
58 |
59 | @dataclass
60 | class VisionOutProjectionArguments:
61 | pretrain_mm_out_mlp_adapter: Optional[str] = field(default=None)
62 | mm_out_projector_type: Optional[str] = field(default='mlp')
63 |
64 |
65 | @dataclass
66 | class ReconstructionLossArguments:
67 | disc_in_channels: Optional[int] = field(default=16)
68 | disc_num_layers: Optional[int] = field(default=2)
69 | disc_start: Optional[int] = field(default=5000)
70 | warm_up_end: Optional[int] = field(default=200)
71 |
72 | @dataclass
73 | class ConstrastiveLossArguments:
74 | text_encoder: Optional[str] = field(default='google/siglip-so400m-patch14-384')
75 | contrast_temperature: Optional[float] = field(default=0.07)
76 | multi_label: Optional[int] = field(default=0)
77 | share_temperature: bool = field(default=False)
78 | multi_label_loss_weight: Optional[float] = field(default=1.0)
79 |
80 | @dataclass
81 | class DiffLossArguments:
82 | diffloss_w: Optional[int] = field(default=3)
83 | diffloss_d: Optional[int] = field(default=1024)
84 | num_sampling_steps: Optional[str] = field(default='100')
85 | grad_checkpointing: bool = field(default=False)
86 | diffusion_batch_mul: Optional[int] = field(default=4)
87 | mask_ratio_min: Optional[float] = field(default=0.7)
88 |
89 |
90 | @dataclass
91 | class DataArguments:
92 | data_path: Union[List[str], str] = field(default=None, metadata={"help": "Path to the training data."})
93 | dataset_name: Union[List[str], str] = field(default=None)
94 |     data_multiple: Union[List[float], str] = field(default=None, metadata={"help": "Data multiplier for each dataset when mixed. None means direct concat."})
95 | lazy_preprocess: bool = False
96 | is_multimodal: bool = False
97 | image_folder: Union[List[str], str] = field(default=None)
98 | image_size: Optional[int] = field(default=448)
99 | image_aspect_ratio: str = 'square'
100 | task_type: Optional[str] = 'instruction'
101 |
102 |
103 | @dataclass
104 | class TrainingArguments(transformers.TrainingArguments):
105 | cache_dir: Optional[str] = field(default=None)
106 | optim: str = field(default="adamw_torch")
107 | remove_unused_columns: bool = field(default=False)
108 | tune_mm_in_mlp_adapter: bool = field(default=False)
109 | tune_mm_out_mlp_adapter: bool = field(default=False)
110 | freeze_mm_in_mlp_adapter: bool = field(default=False)
111 | freeze_mm_out_mlp_adapter: bool = field(default=False)
112 | mpt_attn_impl: Optional[str] = field(default="triton")
113 | model_max_length: int = field(
114 | default=512,
115 | metadata={
116 | "help":
117 | "Maximum sequence length. Sequences will be right padded (and possibly truncated)."
118 | },
119 | )
120 | double_quant: bool = field(
121 | default=True,
122 | metadata={"help": "Compress the quantization statistics through double quantization."}
123 | )
124 | quant_type: str = field(
125 | default="nf4",
126 | metadata={"help": "Quantization data type to use. Should be one of `fp4` or `nf4`."}
127 | )
128 | bits: int = field(
129 | default=16,
130 | metadata={"help": "How many bits to use."}
131 | )
132 | lora_enable: bool = False
133 | lora_r: int = 64
134 | lora_alpha: int = 16
135 | lora_dropout: float = 0.05
136 | lora_weight_path: str = ""
137 | lora_bias: str = "none"
138 | mm_in_projector_lr: Optional[float] = None
139 | mm_out_projector_lr: Optional[float] = None
140 | group_by_modality_length: bool = field(default=False)
141 |
142 |
143 |
--------------------------------------------------------------------------------
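Note: a minimal sketch of how `train_setok.py` consumes these dataclasses through `HfArgumentParser`, parsing an explicit argv list here instead of the real command line (illustrative only; the output path is a placeholder):

```python
import transformers
from src.train.training_utils import (
    ModelArguments, DataArguments, TrainingArguments, VisionTowerArguments,
    VisionGeneratorArguments, ReconstructionLossArguments, ConstrastiveLossArguments,
)

parser = transformers.HfArgumentParser(
    (ModelArguments, DataArguments, TrainingArguments, VisionTowerArguments,
     VisionGeneratorArguments, ReconstructionLossArguments, ConstrastiveLossArguments)
)
(model_args, data_args, training_args, tower_args, generator_args,
 rec_args, contrastive_args) = parser.parse_args_into_dataclasses(
    args=["--output_dir", "./checkpoints/setok-debug", "--model_max_length", "2048"]
)
print(tower_args.vision_tower, training_args.model_max_length)
```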
/src/utils.py:
--------------------------------------------------------------------------------
1 | import datetime
2 | import logging
3 | import logging.handlers
4 | import os
5 | import sys
6 | import math
7 | import random
8 |
9 | import requests
10 |
11 | from src.constants import LOGDIR
12 |
13 | server_error_msg = "**NETWORK ERROR DUE TO HIGH TRAFFIC. PLEASE REGENERATE OR REFRESH THIS PAGE.**"
14 | moderation_msg = "YOUR INPUT VIOLATES OUR CONTENT MODERATION GUIDELINES. PLEASE TRY AGAIN."
15 |
16 | handler = None
17 |
18 |
19 | def build_logger(logger_name, logger_filename):
20 | global handler
21 |
22 | formatter = logging.Formatter(
23 | fmt="%(asctime)s | %(levelname)s | %(name)s | %(message)s",
24 | datefmt="%Y-%m-%d %H:%M:%S",
25 | )
26 |
27 | # Set the format of root handlers
28 | if not logging.getLogger().handlers:
29 | logging.basicConfig(level=logging.INFO)
30 | logging.getLogger().handlers[0].setFormatter(formatter)
31 |
32 | # Redirect stdout and stderr to loggers
33 | stdout_logger = logging.getLogger("stdout")
34 | stdout_logger.setLevel(logging.INFO)
35 | sl = StreamToLogger(stdout_logger, logging.INFO)
36 | sys.stdout = sl
37 |
38 | stderr_logger = logging.getLogger("stderr")
39 | stderr_logger.setLevel(logging.ERROR)
40 | sl = StreamToLogger(stderr_logger, logging.ERROR)
41 | sys.stderr = sl
42 |
43 | # Get logger
44 | logger = logging.getLogger(logger_name)
45 | logger.setLevel(logging.INFO)
46 |
47 | # Add a file handler for all loggers
48 | if handler is None:
49 | os.makedirs(LOGDIR, exist_ok=True)
50 | filename = os.path.join(LOGDIR, logger_filename)
51 | handler = logging.handlers.TimedRotatingFileHandler(
52 | filename, when='D', utc=True, encoding='UTF-8')
53 | handler.setFormatter(formatter)
54 |
55 | for name, item in logging.root.manager.loggerDict.items():
56 | if isinstance(item, logging.Logger):
57 | item.addHandler(handler)
58 |
59 | return logger
60 |
61 |
62 | class StreamToLogger(object):
63 | """
64 | Fake file-like stream object that redirects writes to a logger instance.
65 | """
66 | def __init__(self, logger, log_level=logging.INFO):
67 | self.terminal = sys.stdout
68 | self.logger = logger
69 | self.log_level = log_level
70 | self.linebuf = ''
71 |
72 | def __getattr__(self, attr):
73 | return getattr(self.terminal, attr)
74 |
75 | def write(self, buf):
76 | temp_linebuf = self.linebuf + buf
77 | self.linebuf = ''
78 | for line in temp_linebuf.splitlines(True):
79 | # From the io.TextIOWrapper docs:
80 | # On output, if newline is None, any '\n' characters written
81 | # are translated to the system default line separator.
82 | # By default sys.stdout.write() expects '\n' newlines and then
83 | # translates them so this is still cross platform.
84 | if line[-1] == '\n':
85 | self.logger.log(self.log_level, line.rstrip())
86 | else:
87 | self.linebuf += line
88 |
89 | def flush(self):
90 | if self.linebuf != '':
91 | self.logger.log(self.log_level, self.linebuf.rstrip())
92 | self.linebuf = ''
93 |
94 |
95 | def disable_torch_init():
96 | """
97 | Disable the redundant torch default initialization to accelerate model creation.
98 | """
99 | import torch
100 | setattr(torch.nn.Linear, "reset_parameters", lambda self: None)
101 | setattr(torch.nn.LayerNorm, "reset_parameters", lambda self: None)
102 |
103 |
104 | def violates_moderation(text):
105 | """
106 | Check whether the text violates OpenAI moderation API.
107 | """
108 | url = "https://api.openai.com/v1/moderations"
109 | headers = {"Content-Type": "application/json",
110 | "Authorization": "Bearer " + os.environ["OPENAI_API_KEY"]}
111 |     import json  # local import; json.dumps escapes quotes so the request body stays valid JSON
112 |     text = text.replace("\n", "")
113 |     data = json.dumps({"input": text}).encode("utf-8")
114 | try:
115 | ret = requests.post(url, headers=headers, data=data, timeout=5)
116 | flagged = ret.json()["results"][0]["flagged"]
117 | except requests.exceptions.RequestException as e:
118 | flagged = False
119 | except KeyError as e:
120 | flagged = False
121 |
122 | return flagged
123 |
124 |
125 | def pretty_print_semaphore(semaphore):
126 | if semaphore is None:
127 | return "None"
128 | return f"Semaphore(value={semaphore._value}, locked={semaphore.locked()})"
129 |
130 |
131 |
132 | def extend_list(original_list, multiplier):
133 | # Calculate how many elements to replicate and how many to select randomly
134 | replicate_elements = math.floor(multiplier)
135 | random_elements = multiplier - replicate_elements
136 |
137 | # Replicate the list
138 | replicated_list = original_list * replicate_elements
139 |
140 | # Calculate how many elements to randomly select
141 | select_elements = math.ceil(len(original_list) * random_elements)
142 |
143 | # Randomly select elements and append to the replicated list
144 | for _ in range(select_elements):
145 | random_element = random.choice(original_list)
146 | replicated_list.append(random_element)
147 |
148 | return replicated_list
--------------------------------------------------------------------------------
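Note: a small worked example of `extend_list`, assuming the repo root is on PYTHONPATH so `src.utils` imports (illustrative only):

```python
import random
from src.utils import extend_list

random.seed(0)
data = ["a", "b", "c", "d"]

# multiplier 2.5 -> two full copies plus ceil(0.5 * 4) = 2 randomly chosen extras
print(extend_list(data, 2.5))        # 10 items
print(len(extend_list(data, 1.25)))  # 4 + ceil(0.25 * 4) = 5
```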