├── README.md
├── assets
│   ├── motivation.png
│   └── pipeline.png
├── environment.yaml
├── evaluation.py
├── get_results.py
├── train_dreambooth_quant.py
├── train_dreambooth_quant.sh
└── utils
    ├── intlora_mul.py
    ├── intlora_shift.py
    ├── quant_layer.py
    ├── quant_model.py
    └── quant_utils.py
/README.md:
--------------------------------------------------------------------------------
1 | # IntLoRA: Integral Low-rank Adaptation of Quantized Diffusion Models
2 |
3 | ### [[Paper](https://arxiv.org/pdf/2410.21759)]
4 |
5 | ### Tune your customized diffusion model on ONE 3090 GPU!
6 |
7 | [Hang Guo](https://csguoh.github.io/)<sup>1</sup> | [Yawei Li](https://yaweili.bitbucket.io/)<sup>2</sup> | [Tao Dai](https://scholar.google.com/citations?user=MqJNdaAAAAAJ&hl=zh-CN)<sup>3</sup> | [Shu-Tao Xia](https://scholar.google.com.hk/citations?user=koAXTXgAAAAJ&hl=zh-CN)<sup>1,4</sup> | [Luca Benini](https://scholar.google.com.hk/citations?user=8riq3sYAAAAJ&hl=zh-CN&oi=ao)<sup>2</sup>
8 |
9 | <sup>1</sup>Tsinghua University, <sup>2</sup>ETH Zurich, <sup>3</sup>Shenzhen University, <sup>4</sup>Pengcheng Laboratory
10 |
11 |
12 |
13 |
14 | :star: If our IntLoRA is helpful to your research or projects, please help star this repo. Thanks! :hugs:
15 |
16 |
17 | ## TL; DR
18 | > Our IntLoRA offers three key advantages: (i) for fine-tuning, the pre-trained weights are quantized, reducing memory usage; (ii) for storage, both the pre-trained and the low-rank weights are kept in integer format, which consumes less disk space; (iii) for inference, the IntLoRA weights can be naturally merged into the quantized pre-trained weights through efficient integer multiplication or bit-shifting, eliminating additional post-training quantization.
19 |
20 |
21 | ![motivation](assets/motivation.png)
22 |
23 |
24 |
25 | ## 🔎 Framework overview
26 |
27 | ![pipeline](assets/pipeline.png)
28 |
29 |
30 |
31 |
32 |
33 | ## ⚙️ Dependencies and Installation
34 |
35 | ### Step-1 Download and Environment
36 |
37 | ```
38 | # clone this repository
39 | git clone https://github.com/csguoh/IntLoRA.git
40 | cd ./IntLoRA
41 | ```
42 |
43 | ```
44 | # create a conda environment
45 | conda env create -f environment.yaml
46 | conda activate intlora
47 | ```
48 |
49 | ### Step-2 Preparation
50 |
51 | - This code repository contains [DreamBooth](https://arxiv.org/abs/2208.12242) fine-tuning with our IntLoRA. You can download the subject-driven generation dataset [here](https://github.com/google/dreambooth/tree/main/dataset).
52 |
53 | - You also need to download the pre-trained model weights that will be fine-tuned with our IntLoRA. Here, we use [Stable Diffusion v1.5](https://huggingface.co/CompVis) as an example; a minimal download sketch is given below.
54 |
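- For reference, a minimal preparation sketch (the Hugging Face model id and the target folders below are assumptions; use whichever SD-1.5 checkpoint and directory layout you prefer):

```
# the subject-driven generation data lives in the dataset/ sub-folder of the official DreamBooth repo
git clone https://github.com/google/dreambooth.git

# pre-trained Stable Diffusion v1.5 weights (the repository id here is an assumption)
huggingface-cli download runwayml/stable-diffusion-v1-5 --local-dir ./stable-diffusion-v1-5
```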
55 |
56 | ### Step-3 Fine-tuning!
57 |
58 | - The main fine-tuning logic is defined in `train_dreambooth_quant.py`. We also provide an off-the-shelf configuration bash file, so you can directly train customized diffusion models with the following command.
59 |
60 | ```
61 | bash ./train_dreambooth_quant.sh
62 | ```
63 |
64 | - The following are some key parameters that you may want to modify (an example invocation is sketched after this list):
65 |
66 | - `rank`: the inner rank of the LoRA adapter
67 | - `intlora`: choose `MUL` to use our IntLoRA-MUL or `SHIFT` to use our IntLoRA-SHIFT
68 | - `nbits`: the bit-width used for weight quantization
69 | - `use_activation_quant`: whether to use activation quantization
70 | - `act_nbits`: the bit-width used for activation quantization
71 | - `gradient_checkpointing`: whether to use gradient checkpointing to further reduce GPU memory cost
72 |
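- For illustration, a minimal sketch of a direct invocation with these flags (the launcher, paths, and instance prompt below are placeholders/assumptions; `train_dreambooth_quant.sh` remains the reference configuration):

```
# sketch only: all paths and the instance prompt are placeholders
accelerate launch train_dreambooth_quant.py \
  --pretrained_model_name_or_path ./stable-diffusion-v1-5 \
  --instance_data_dir ./dreambooth/dataset/dog \
  --instance_prompt "a photo of a qwe dog" \
  --output_dir ./log_quant \
  --resolution 512 \
  --train_batch_size 1 \
  --learning_rate 5e-4 \
  --max_train_steps 1000 \
  --rank 4 \
  --intlora MUL \
  --nbits 8 \
  --use_activation_quant \
  --act_nbits 8 \
  --gradient_checkpointing
```
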
73 | - After running the fine-tuning command above, you can find the generated results in the `./log_quant` folder.
74 |
75 | ## 😄 Evaluation
76 | - After generating the images, you can evaluate the quality of each generated image with the following command:
77 |
78 | ```
79 | python evaluation.py
80 | ```
81 |
82 | - It will generate a `.json` file that contains the IQA results of each subject. Then you can obtain the overall evaluation result with:
83 |
84 | ```
85 | python get_results.py
86 | ```
87 |
88 |
89 |
90 | ## 🎓Citations
91 | If our code helps your research or work, please consider citing our paper.
92 | The following is a BibTeX reference:
93 |
94 |
95 | ```
96 | @article{guo2024intlora,
97 | title={IntLoRA: Integral Low-rank Adaptation of Quantized Diffusion Models},
98 | author={Guo, Hang and Li, Yawei and Dai, Tao and Xia, Shu-Tao and Benini, Luca},
99 | journal={arXiv preprint arXiv:2410.21759},
100 | year={2024}
101 | }
102 | ```
103 |
104 |
105 |
106 |
--------------------------------------------------------------------------------
/assets/motivation.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/csguoh/IntLoRA/65a8257a4311e0feab9b33477c9748d13d8ce17b/assets/motivation.png
--------------------------------------------------------------------------------
/assets/pipeline.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/csguoh/IntLoRA/65a8257a4311e0feab9b33477c9748d13d8ce17b/assets/pipeline.png
--------------------------------------------------------------------------------
/environment.yaml:
--------------------------------------------------------------------------------
1 | name: intlora
2 | channels:
3 | - pytorch
4 | - nvidia
5 | - defaults
6 | dependencies:
7 | - _libgcc_mutex=0.1=main
8 | - _openmp_mutex=5.1=1_gnu
9 | - blas=1.0=mkl
10 | - brotli-python=1.0.9=py39h6a678d5_7
11 | - bzip2=1.0.8=h7b6447c_0
12 | - ca-certificates=2024.7.2=h06a4308_0
13 | - cffi=1.16.0=py39h5eee18b_0
14 | - cryptography=41.0.7=py39hdda0065_0
15 | - cuda-cudart=11.8.89=0
16 | - cuda-cupti=11.8.87=0
17 | - cuda-libraries=11.8.0=0
18 | - cuda-nvrtc=11.8.89=0
19 | - cuda-nvtx=11.8.86=0
20 | - cuda-runtime=11.8.0=0
21 | - ffmpeg=4.3=hf484d3e_0
22 | - filelock=3.13.1=py39h06a4308_0
23 | - freetype=2.12.1=h4a9f257_0
24 | - giflib=5.2.1=h5eee18b_3
25 | - gmp=6.2.1=h295c915_3
26 | - gmpy2=2.1.2=py39heeb90bb_0
27 | - gnutls=3.6.15=he1e5248_0
28 | - idna=3.4=py39h06a4308_0
29 | - intel-openmp=2023.1.0=hdb19cb5_46306
30 | - jinja2=3.1.2=py39h06a4308_0
31 | - jpeg=9e=h5eee18b_1
32 | - lame=3.100=h7b6447c_0
33 | - lcms2=2.12=h3be6417_0
34 | - ld_impl_linux-64=2.38=h1181459_1
35 | - lerc=3.0=h295c915_0
36 | - libcublas=11.11.3.6=0
37 | - libcufft=10.9.0.58=0
38 | - libcufile=1.8.1.2=0
39 | - libcurand=10.3.4.101=0
40 | - libcusolver=11.4.1.48=0
41 | - libcusparse=11.7.5.86=0
42 | - libdeflate=1.17=h5eee18b_1
43 | - libffi=3.4.4=h6a678d5_0
44 | - libgcc-ng=11.2.0=h1234567_1
45 | - libgomp=11.2.0=h1234567_1
46 | - libiconv=1.16=h7f8727e_2
47 | - libidn2=2.3.4=h5eee18b_0
48 | - libjpeg-turbo=2.0.0=h9bf148f_0
49 | - libnpp=11.8.0.86=0
50 | - libnvjpeg=11.9.0.86=0
51 | - libpng=1.6.39=h5eee18b_0
52 | - libstdcxx-ng=11.2.0=h1234567_1
53 | - libtasn1=4.19.0=h5eee18b_0
54 | - libtiff=4.5.1=h6a678d5_0
55 | - libunistring=0.9.10=h27cfd23_0
56 | - libwebp=1.3.2=h11a3e52_0
57 | - libwebp-base=1.3.2=h5eee18b_0
58 | - llvm-openmp=14.0.6=h9e868ea_0
59 | - lz4-c=1.9.4=h6a678d5_0
60 | - mkl=2023.1.0=h213fc3f_46344
61 | - mkl-service=2.4.0=py39h5eee18b_1
62 | - mkl_fft=1.3.8=py39h5eee18b_0
63 | - mkl_random=1.2.4=py39hdb19cb5_0
64 | - mpc=1.1.0=h10f8cd9_1
65 | - mpfr=4.0.2=hb69a4c5_1
66 | - mpmath=1.3.0=py39h06a4308_0
67 | - ncurses=6.4=h6a678d5_0
68 | - nettle=3.7.3=hbbd107a_1
69 | - numpy-base=1.26.2=py39hb5e798b_0
70 | - openh264=2.1.1=h4ff587b_0
71 | - openjpeg=2.4.0=h3ad879b_0
72 | - openssl=3.0.14=h5eee18b_0
73 | - pip=23.3=py39h06a4308_0
74 | - pycparser=2.21=pyhd3eb1b0_0
75 | - pyopenssl=23.2.0=py39h06a4308_0
76 | - pysocks=1.7.1=py39h06a4308_0
77 | - python=3.9.18=h955ad1f_0
78 | - pytorch=2.1.0=py3.9_cuda11.8_cudnn8.7.0_0
79 | - pytorch-cuda=11.8=h7e8668a_5
80 | - pytorch-mutex=1.0=cuda
81 | - pyyaml=6.0.1=py39h5eee18b_0
82 | - readline=8.2=h5eee18b_0
83 | - requests=2.31.0=py39h06a4308_0
84 | - setuptools=68.0.0=py39h06a4308_0
85 | - sqlite=3.41.2=h5eee18b_0
86 | - tbb=2021.8.0=hdb19cb5_0
87 | - tk=8.6.12=h1ccaba5_0
88 | - torchaudio=2.1.0=py39_cu118
89 | - torchtriton=2.1.0=py39
90 | - typing_extensions=4.7.1=py39h06a4308_0
91 | - wheel=0.41.2=py39h06a4308_0
92 | - xz=5.4.2=h5eee18b_0
93 | - yaml=0.2.5=h7b6447c_0
94 | - zlib=1.2.13=h5eee18b_0
95 | - zstd=1.5.5=hc292b87_0
96 | - pip:
97 | - absl-py==2.1.0
98 | - accelerate==0.25.0
99 | - addict==2.4.0
100 | - aiofiles==23.2.1
101 | - aiohttp==3.8.6
102 | - aiosignal==1.3.1
103 | - annotated-types==0.7.0
104 | - antlr4-python3-runtime==4.9.3
105 | - anyio==3.7.1
106 | - asttokens==2.4.1
107 | - async-timeout==4.0.3
108 | - attrs==23.1.0
109 | - axial-positional-embedding==0.2.1
110 | - bitsandbytes==0.43.1
111 | - cachetools==5.3.2
112 | - certifi==2023.7.22
113 | - chardet==5.2.0
114 | - charset-normalizer==3.3.1
115 | - click==8.1.7
116 | - colorama==0.4.6
117 | - contourpy==1.1.1
118 | - cycler==0.12.1
119 | - dataclasses-json==0.6.1
120 | - deprecated==1.2.14
121 | - diffusers==0.26.0
122 | - docker-pycreds==0.4.0
123 | - einops==0.4.0
124 | - exceptiongroup==1.1.3
125 | - executing==2.0.1
126 | - facexlib==0.3.0
127 | - fastapi==0.112.2
128 | - ffmpy==0.4.0
129 | - filterpy==1.4.5
130 | - fonttools==4.43.1
131 | - frozenlist==1.4.0
132 | - fsspec==2023.10.0
133 | - ftfy==6.2.3
134 | - future==1.0.0
135 | - gitdb==4.0.11
136 | - gitpython==3.1.43
137 | - google-auth==2.26.2
138 | - google-auth-oauthlib==1.2.0
139 | - gradio==4.42.0
140 | - gradio-client==1.3.0
141 | - greenlet==3.0.1
142 | - grpcio==1.59.2
143 | - grpcio-tools==1.59.2
144 | - h11==0.14.0
145 | - h2==4.1.0
146 | - hpack==4.0.0
147 | - httpcore==0.18.0
148 | - httpx==0.25.0
149 | - huggingface-hub==0.24.0
150 | - hyperframe==6.0.1
151 | - icecream==2.1.3
152 | - imageio==2.35.1
153 | - imgaug==0.4.0
154 | - importlib-metadata==7.0.1
155 | - importlib-resources==6.1.0
156 | - joblib==1.3.2
157 | - jsonpatch==1.33
158 | - jsonpointer==2.4
159 | - kiwisolver==1.4.5
160 | - langchain==0.0.327
161 | - langsmith==0.0.54
162 | - lazy-loader==0.4
163 | - lightning-utilities==0.11.6
164 | - llvmlite==0.41.1
165 | - lmdb==1.5.1
166 | - local-attention==1.4.4
167 | - lpips==0.1.4
168 | - markdown==3.5.2
169 | - markdown-it-py==3.0.0
170 | - markupsafe==2.1.3
171 | - marshmallow==3.20.1
172 | - matplotlib==3.7.0
173 | - mdurl==0.1.2
174 | - multidict==6.0.4
175 | - mypy-extensions==1.0.0
176 | - networkx==3.2.1
177 | - nltk==3.8.1
178 | - numba==0.58.1
179 | - numpy==1.23.5
180 | - nvidia-cublas-cu11==11.10.3.66
181 | - nvidia-cublas-cu12==12.1.3.1
182 | - nvidia-cuda-cupti-cu12==12.1.105
183 | - nvidia-cuda-nvrtc-cu11==11.7.99
184 | - nvidia-cuda-nvrtc-cu12==12.1.105
185 | - nvidia-cuda-runtime-cu11==11.7.99
186 | - nvidia-cuda-runtime-cu12==12.1.105
187 | - nvidia-cudnn-cu11==8.5.0.96
188 | - nvidia-cudnn-cu12==8.9.2.26
189 | - nvidia-cufft-cu12==11.0.2.54
190 | - nvidia-curand-cu12==10.3.2.106
191 | - nvidia-cusolver-cu12==11.4.5.107
192 | - nvidia-cusparse-cu12==12.1.0.106
193 | - nvidia-nccl-cu12==2.18.1
194 | - nvidia-nvjitlink-cu12==12.3.52
195 | - nvidia-nvtx-cu12==12.1.105
196 | - oauthlib==3.2.2
197 | - omegaconf==2.3.0
198 | - open-clip-torch==2.26.1
199 | - openai==0.28.1
200 | - openai-clip==1.0.1
201 | - opencv-python==4.9.0.80
202 | - opencv-python-headless==4.10.0.84
203 | - orjson==3.10.7
204 | - packaging==23.2
205 | - pandas==1.5.3
206 | - patool==1.12
207 | - patsy==0.5.4
208 | - pdfminer==20191125
209 | - peft==0.7.1
210 | - pillow==10.1.0
211 | - platformdirs==4.2.2
212 | - product-key-memory==0.1.10
213 | - protobuf==4.23.4
214 | - psutil==5.9.6
215 | - pyarrow==13.0.0
216 | - pyasn1==0.5.1
217 | - pyasn1-modules==0.3.0
218 | - pycryptodome==3.19.0
219 | - pydantic==2.8.2
220 | - pydantic-core==2.20.1
221 | - pydub==0.25.1
222 | - pygments==2.18.0
223 | - pyiqa==0.1.10
224 | - pyparsing==3.1.1
225 | - python-dateutil==2.8.2
226 | - python-dotenv==1.0.0
227 | - python-multipart==0.0.9
228 | - pytorch-lightning==2.1.0
229 | - pytz==2023.3.post1
230 | - qdrant-client==1.1.1
231 | - reformer-pytorch==1.4.4
232 | - regex==2023.10.3
233 | - requests-oauthlib==1.3.1
234 | - rich==13.7.1
235 | - rsa==4.9
236 | - ruff==0.6.2
237 | - safetensors==0.4.4
238 | - scikit-base==0.6.1
239 | - scikit-image==0.24.0
240 | - scikit-learn==1.2.2
241 | - scipy==1.10.1
242 | - seaborn==0.13.2
243 | - semantic-version==2.10.0
244 | - sentence-transformers==2.2.2
245 | - sentencepiece==0.1.99
246 | - sentry-sdk==2.10.0
247 | - setproctitle==1.3.3
248 | - shapely==2.0.6
249 | - shellingham==1.5.4
250 | - six==1.16.0
251 | - sktime==0.16.1
252 | - smmap==5.0.1
253 | - sniffio==1.3.0
254 | - sqlalchemy==2.0.22
255 | - starlette==0.38.2
256 | - statsmodels==0.14.1
257 | - sympy==1.11.1
258 | - tenacity==8.2.3
259 | - tensorboard==2.15.1
260 | - tensorboard-data-server==0.7.2
261 | - threadpoolctl==3.2.0
262 | - tifffile==2024.8.24
263 | - timm==0.9.7
264 | - tokenizers==0.15.2
265 | - tomli==2.0.1
266 | - tomlkit==0.12.0
267 | - torch==1.13.0
268 | - torchmetrics==1.4.1
269 | - torchvision==0.16.0
270 | - tqdm==4.64.1
271 | - transformers==4.37.2
272 | - triton==2.1.0
273 | - typer==0.12.5
274 | - typing-extensions==4.12.2
275 | - typing-inspect==0.8.0
276 | - tzdata==2023.3
277 | - urllib3==2.2.2
278 | - uvicorn==0.21.1
279 | - wandb==0.17.5
280 | - wcwidth==0.2.13
281 | - websockets==12.0
282 | - werkzeug==3.0.1
283 | - wrapt==1.16.0
284 | - yapf==0.40.2
285 | - yarl==1.9.2
286 | - zipp==3.17.0
287 |
288 |
--------------------------------------------------------------------------------
/evaluation.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | # coding=utf-8
3 | # Copyright 2023 The HuggingFace Inc. team. All rights reserved.
4 | #
5 | # Licensed under the Apache License, Version 2.0 (the "License");
6 | # you may not use this file except in compliance with the License.
7 | # You may obtain a copy of the License at
8 | #
9 | # http://www.apache.org/licenses/LICENSE-2.0
10 | #
11 | # Unless required by applicable law or agreed to in writing, software
12 | # distributed under the License is distributed on an "AS IS" BASIS,
13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 | # See the License for the specific language governing permissions and
15 | # limitations under the License.
16 | import argparse
17 | import hashlib
18 | import logging
19 | import math
20 | import os
21 | import warnings
22 | from pathlib import Path
23 |
24 | from functools import reduce
25 | import numpy as np
26 | import torch
27 | import torch.nn.functional as F
28 | import torch.utils.checkpoint
29 | import transformers
30 | from packaging import version
31 | from PIL import Image
32 | from torch.utils.data import Dataset, DataLoader
33 | from torchvision import transforms
34 | from tqdm.auto import tqdm
35 | from transformers import AutoTokenizer, PretrainedConfig, ViTFeatureExtractor, ViTModel
36 |
37 | import lpips
38 | import json
39 | from PIL import Image
40 | import requests
41 | from transformers import AutoProcessor, AutoTokenizer, CLIPModel
42 | import torchvision.transforms.functional as TF
43 | from torch.nn.functional import cosine_similarity
44 | from torchvision.transforms import Compose, ToTensor, Normalize, Resize, ToPILImage
45 |
46 |
47 |
48 | def get_prompt(subject_name, prompt_idx):
49 | subject_names = [
50 | "backpack", "backpack_dog", "bear_plushie", "berry_bowl", "can",
51 | "candle", "cat", "cat2", "clock", "colorful_sneaker",
52 | "dog", "dog2", "dog3", "dog5", "dog6",
53 | "dog7", "dog8", "duck_toy", "fancy_boot", "grey_sloth_plushie",
54 | "monster_toy", "pink_sunglasses", "poop_emoji", "rc_car", "red_cartoon",
55 | "robot_toy", "shiny_sneaker", "teapot", "vase", "wolf_plushie",
56 | ]
57 |
58 | class_tokens = [
59 | "backpack", "backpack", "stuffed animal", "bowl", "can",
60 | "candle", "cat", "cat", "clock", "sneaker",
61 | "dog", "dog", "dog", "dog", "dog",
62 | "dog", "dog", "toy", "boot", "stuffed animal",
63 | "toy", "glasses", "toy", "toy", "cartoon",
64 | "toy", "sneaker", "teapot", "vase", "stuffed animal",
65 | ]
66 |
67 | class_token = class_tokens[subject_names.index(subject_name)]
68 |
69 | prompt_list = [
70 | f"a qwe {class_token} in the jungle",
71 | f"a qwe {class_token} in the snow",
72 | f"a qwe {class_token} on the beach",
73 | f"a qwe {class_token} on a cobblestone street",
74 | f"a qwe {class_token} on top of pink fabric",
75 | f"a qwe {class_token} on top of a wooden floor",
76 | f"a qwe {class_token} with a city in the background",
77 | f"a qwe {class_token} with a mountain in the background",
78 | f"a qwe {class_token} with a blue house in the background",
79 | f"a qwe {class_token} on top of a purple rug in a forest",
80 | f"a qwe {class_token} wearing a red hat",
81 | f"a qwe {class_token} wearing a santa hat",
82 | f"a qwe {class_token} wearing a rainbow scarf",
83 | f"a qwe {class_token} wearing a black top hat and a monocle",
84 | f"a qwe {class_token} in a chef outfit",
85 | f"a qwe {class_token} in a firefighter outfit",
86 | f"a qwe {class_token} in a police outfit",
87 | f"a qwe {class_token} wearing pink glasses",
88 | f"a qwe {class_token} wearing a yellow shirt",
89 | f"a qwe {class_token} in a purple wizard outfit",
90 | f"a red qwe {class_token}",
91 | f"a purple qwe {class_token}",
92 | f"a shiny qwe {class_token}",
93 | f"a wet qwe {class_token}",
94 | f"a cube shaped qwe {class_token}",
95 | ]
96 |
97 | return prompt_list[int(prompt_idx)]
98 |
99 |
100 | class PromptDatasetCLIP(Dataset):
101 | def __init__(self, subject_name, data_dir_B, tokenizer, processor, epoch=None):
102 | self.data_dir_B = data_dir_B
103 |
104 | subject_name, prompt_idx = subject_name.split('-')
105 |
106 | data_dir_B = os.path.join(self.data_dir_B, str(epoch))
107 | self.image_lst = [os.path.join(data_dir_B, f) for f in os.listdir(data_dir_B) if f.endswith(".png")]
108 | self.prompt_lst = [get_prompt(subject_name, prompt_idx)] * len(self.image_lst)
109 |
110 | self.tokenizer = tokenizer
111 | self.processor = processor
112 | self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
113 |
114 | def __len__(self):
115 | return len(self.image_lst)
116 |
117 | def __getitem__(self, idx):
118 | image_path = self.image_lst[idx]
119 | image = Image.open(image_path)
120 | prompt = self.prompt_lst[idx]
121 |
122 | extrema = image.getextrema()
123 | if all(min_val == max_val == 0 for min_val, max_val in extrema):
124 | return None, None
125 | else:
126 | prompt_inputs = self.tokenizer([prompt], padding=True, return_tensors="pt")
127 | image_inputs = self.processor(images=image, return_tensors="pt")
128 |
129 | return image_inputs, prompt_inputs
130 |
131 |
132 | class PairwiseImageDatasetCLIP(Dataset):
133 | def __init__(self, subject_name, data_dir_A, data_dir_B, processor, epoch):
134 | self.data_dir_A = data_dir_A
135 | self.data_dir_B = data_dir_B
136 |
137 | subject_name, prompt_idx = subject_name.split('-')
138 |
139 | self.data_dir_A = os.path.join(self.data_dir_A, subject_name)
140 | self.image_files_A = [os.path.join(self.data_dir_A, f) for f in os.listdir(self.data_dir_A) if
141 | f.endswith(".jpg")]
142 |
143 | data_dir_B = os.path.join(self.data_dir_B, str(epoch))
144 | self.image_files_B = [os.path.join(data_dir_B, f) for f in os.listdir(data_dir_B) if f.endswith(".png")]
145 |
146 | self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
147 | self.processor = processor
148 |
149 | def __len__(self):
150 | return len(self.image_files_A) * len(self.image_files_B)
151 |
152 | def __getitem__(self, index):
153 | index_A = index // len(self.image_files_B)
154 | index_B = index % len(self.image_files_B)
155 |
156 | image_A = Image.open(self.image_files_A[index_A]) # .convert("RGB")
157 | image_B = Image.open(self.image_files_B[index_B]) # .convert("RGB")
158 |
159 | extrema_A = image_A.getextrema()
160 | extrema_B = image_B.getextrema()
161 | if all(min_val == max_val == 0 for min_val, max_val in extrema_A) or all(
162 | min_val == max_val == 0 for min_val, max_val in extrema_B):
163 | return None, None
164 | else:
165 | inputs_A = self.processor(images=image_A, return_tensors="pt")
166 | inputs_B = self.processor(images=image_B, return_tensors="pt")
167 |
168 | return inputs_A, inputs_B
169 |
170 |
171 | class PairwiseImageDatasetDINO(Dataset):
172 | def __init__(self, subject_name, data_dir_A, data_dir_B, feature_extractor, epoch):
173 | self.data_dir_A = data_dir_A
174 | self.data_dir_B = data_dir_B
175 |
176 | subject_name, prompt_idx = subject_name.split('-')
177 |
178 | self.data_dir_A = os.path.join(self.data_dir_A, subject_name)
179 | self.image_files_A = [os.path.join(self.data_dir_A, f) for f in os.listdir(self.data_dir_A) if
180 | f.endswith(".jpg")]
181 |
182 | data_dir_B = os.path.join(self.data_dir_B, str(epoch))
183 | self.image_files_B = [os.path.join(data_dir_B, f) for f in os.listdir(data_dir_B) if f.endswith(".png")]
184 |
185 | self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
186 | self.feature_extractor = feature_extractor
187 |
188 | def __len__(self):
189 | return len(self.image_files_A) * len(self.image_files_B)
190 |
191 | def __getitem__(self, index):
192 | index_A = index // len(self.image_files_B)
193 | index_B = index % len(self.image_files_B)
194 |
195 | image_A = Image.open(self.image_files_A[index_A]) # .convert("RGB")
196 | image_B = Image.open(self.image_files_B[index_B]) # .convert("RGB")
197 |
198 | extrema_A = image_A.getextrema()
199 | extrema_B = image_B.getextrema()
200 | if all(min_val == max_val == 0 for min_val, max_val in extrema_A) or all(
201 | min_val == max_val == 0 for min_val, max_val in extrema_B):
202 | return None, None
203 | else:
204 | inputs_A = self.feature_extractor(images=image_A, return_tensors="pt")
205 | inputs_B = self.feature_extractor(images=image_B, return_tensors="pt")
206 |
207 | return inputs_A, inputs_B
208 |
209 |
210 | class PairwiseImageDatasetLPIPS(Dataset):
211 | def __init__(self, subject_name, data_dir_A, data_dir_B, epoch):
212 | self.data_dir_A = data_dir_A
213 | self.data_dir_B = data_dir_B
214 |
215 | subject_name, prompt_idx = subject_name.split('-')
216 |
217 | self.data_dir_A = os.path.join(self.data_dir_A, subject_name)
218 | self.image_files_A = [os.path.join(self.data_dir_A, f) for f in os.listdir(self.data_dir_A) if
219 | f.endswith(".jpg")]
220 |
221 | data_dir_B = os.path.join(self.data_dir_B, str(epoch))
222 | self.image_files_B = [os.path.join(data_dir_B, f) for f in os.listdir(data_dir_B) if f.endswith(".png")]
223 |
224 | self.transform = Compose([
225 | Resize((512, 512)),
226 | ToTensor(),
227 | Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
228 | ])
229 | self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
230 |
231 | def __len__(self):
232 | return len(self.image_files_A) * len(self.image_files_B)
233 |
234 | def __getitem__(self, index):
235 | index_A = index // len(self.image_files_B)
236 | index_B = index % len(self.image_files_B)
237 |
238 | image_A = Image.open(self.image_files_A[index_A]) # .convert("RGB")
239 | image_B = Image.open(self.image_files_B[index_B]) # .convert("RGB")
240 |
241 | extrema_A = image_A.getextrema()
242 | extrema_B = image_B.getextrema()
243 | if all(min_val == max_val == 0 for min_val, max_val in extrema_A) or all(
244 | min_val == max_val == 0 for min_val, max_val in extrema_B):
245 | return None, None
246 | else:
247 | if self.transform:
248 | image_A = self.transform(image_A)
249 | image_B = self.transform(image_B)
250 |
251 | return image_A, image_B
252 |
253 |
254 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
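# NOTE: the CLIP paths below are local placeholders; point them to your own checkpoint
# (e.g. the Hugging Face model "openai/clip-vit-large-patch14").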
255 | clip_text_model = CLIPModel.from_pretrained("/data2/xxx/pretrained/clip-vit-large-patch14").to(device)
256 | clip_text_tokenizer = AutoTokenizer.from_pretrained("/data2/xxx/pretrained/clip-vit-large-patch14")
257 | clip_text_processor = AutoProcessor.from_pretrained("/data2/xxx/pretrained/clip-vit-large-patch14")
258 |
259 | def clip_text(subject_name, image_dir):
260 | criterion = 'clip_text'
261 |
262 | model = clip_text_model
263 | # Get the text features
264 | tokenizer = clip_text_tokenizer
265 | # Get the image features
266 |     processor = clip_text_processor
267 |
268 | epochs = sorted([int(epoch) for epoch in os.listdir(image_dir)])
269 | best_mean_similarity = 0
270 | mean_similarity_list = []
271 | for epoch in epochs:
272 | similarity = []
273 | dataset = PromptDatasetCLIP(subject_name, image_dir, tokenizer, processor, epoch)
274 |         dataloader = DataLoader(dataset, batch_size=32)  # note: this DataLoader is unused; samples are indexed directly in the loop below
275 | for i in range(len(dataset)):
276 | image_inputs, prompt_inputs = dataset[i]
277 | if image_inputs is not None and prompt_inputs is not None:
278 | image_inputs['pixel_values'] = image_inputs['pixel_values'].to(device)
279 | prompt_inputs['input_ids'] = prompt_inputs['input_ids'].to(device)
280 | prompt_inputs['attention_mask'] = prompt_inputs['attention_mask'].to(device)
281 | # print(prompt_inputs)
282 | image_features = model.get_image_features(**image_inputs)
283 | text_features = model.get_text_features(**prompt_inputs)
284 |
285 | sim = cosine_similarity(image_features, text_features)
286 |
287 | # image_features = image_features / image_features.norm(p=2, dim=-1, keepdim=True)
288 | # text_features = text_features / text_features.norm(p=2, dim=-1, keepdim=True)
289 | # logit_scale = model.logit_scale.exp()
290 | # sim = torch.matmul(text_features, image_features.t()) * logit_scale
291 | similarity.append(sim.item())
292 |
293 | if similarity:
294 | mean_similarity = torch.tensor(similarity).mean().item()
295 | mean_similarity_list.append(mean_similarity)
296 | best_mean_similarity = max(best_mean_similarity, mean_similarity)
297 | print(f'epoch: {epoch}, criterion: {criterion}, mean_similarity: {mean_similarity}({best_mean_similarity})')
298 | else:
299 | mean_similarity_list.append(0)
300 | print(f'epoch: {epoch}, criterion: {criterion}, mean_similarity: {0}({best_mean_similarity})')
301 |
302 | return mean_similarity_list
303 |
304 |
305 | clip_image_model = CLIPModel.from_pretrained("/data2/xxx/pretrained/clip-vit-large-patch14").to(device)
306 | clip_image_processor = AutoProcessor.from_pretrained("/data2/xxx/pretrained/clip-vit-large-patch14")
307 |
308 | def clip_image(subject_name, image_dir, dreambooth_dir='/data2/xxx/dataset/dreambooth'):
309 | criterion = 'clip_image'
310 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
311 | model = clip_image_model
312 | # Get the image features
313 | processor = clip_image_processor
314 |
315 | epochs = sorted([int(epoch) for epoch in os.listdir(image_dir)])
316 | best_mean_similarity = 0
317 | mean_similarity_list = []
318 | for epoch in epochs:
319 | similarity = []
320 | dataset = PairwiseImageDatasetCLIP(subject_name, dreambooth_dir, image_dir, processor, epoch)
321 | # dataset = SelfPairwiseImageDatasetCLIP(subject, './data', processor)
322 |
323 | for i in range(len(dataset)):
324 | inputs_A, inputs_B = dataset[i]
325 | if inputs_A is not None and inputs_B is not None:
326 | inputs_A['pixel_values'] = inputs_A['pixel_values'].to(device)
327 | inputs_B['pixel_values'] = inputs_B['pixel_values'].to(device)
328 |
329 | image_A_features = model.get_image_features(**inputs_A)
330 | image_B_features = model.get_image_features(**inputs_B)
331 |
332 | image_A_features = image_A_features / image_A_features.norm(p=2, dim=-1, keepdim=True)
333 | image_B_features = image_B_features / image_B_features.norm(p=2, dim=-1, keepdim=True)
334 |
335 | logit_scale = model.logit_scale.exp()
336 | sim = torch.matmul(image_A_features, image_B_features.t()) # * logit_scale
337 | similarity.append(sim.item())
338 |
339 | if similarity:
340 | mean_similarity = torch.tensor(similarity).mean().item()
341 | best_mean_similarity = max(best_mean_similarity, mean_similarity)
342 | mean_similarity_list.append(mean_similarity)
343 | print(f'epoch: {epoch}, criterion: {criterion}, mean_similarity: {mean_similarity}({best_mean_similarity})')
344 | else:
345 | mean_similarity_list.append(0)
346 | print(f'epoch: {epoch}, criterion: {criterion}, mean_similarity: {0}({best_mean_similarity})')
347 |
348 | return mean_similarity_list
349 |
350 |
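# NOTE: local placeholder path; point it to your own DINO checkpoint
# (e.g. the Hugging Face model "facebook/dino-vits16").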
351 | dino_model = ViTModel.from_pretrained('/data2/xxx/pretrained/dino-vits16').to(device)
352 | dino_feature_extractor = ViTFeatureExtractor.from_pretrained('/data2/xxx/pretrained/dino-vits16')
353 |
354 | def dino(subject_name, image_dir, dreambooth_dir='/data2/xxx/dataset/dreambooth'):
355 | criterion = 'dino'
356 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
357 | model = dino_model
358 |     feature_extractor = dino_feature_extractor
359 |
360 | epochs = sorted([int(epoch) for epoch in os.listdir(image_dir)])
361 | best_mean_similarity = 0
362 | mean_similarity_list = []
363 | for epoch in epochs:
364 | similarity = []
365 | # dataset = PairwiseImageDatasetDINO(subject, './data', image_dir, feature_extractor, epoch)
366 | dataset = PairwiseImageDatasetDINO(subject_name, dreambooth_dir, image_dir, feature_extractor, epoch)
367 | # dataset = SelfPairwiseImageDatasetDINO(subject, './data', feature_extractor)
368 |
369 | for i in range(len(dataset)):
370 | inputs_A, inputs_B = dataset[i]
371 | if inputs_A is not None and inputs_B is not None:
372 | inputs_A['pixel_values'] = inputs_A['pixel_values'].to(device)
373 | inputs_B['pixel_values'] = inputs_B['pixel_values'].to(device)
374 |
375 | outputs_A = model(**inputs_A)
376 | image_A_features = outputs_A.last_hidden_state[:, 0, :]
377 |
378 | outputs_B = model(**inputs_B)
379 | image_B_features = outputs_B.last_hidden_state[:, 0, :]
380 |
381 | image_A_features = image_A_features / image_A_features.norm(p=2, dim=-1, keepdim=True)
382 | image_B_features = image_B_features / image_B_features.norm(p=2, dim=-1, keepdim=True)
383 |
384 | sim = torch.matmul(image_A_features, image_B_features.t()) # * logit_scale
385 | similarity.append(sim.item())
386 |
387 | mean_similarity = torch.tensor(similarity).mean().item()
388 | best_mean_similarity = max(best_mean_similarity, mean_similarity)
389 | mean_similarity_list.append(mean_similarity)
390 | print(f'epoch: {epoch}, criterion: {criterion}, mean_similarity: {mean_similarity}({best_mean_similarity})')
391 |
392 | return mean_similarity_list
393 |
394 |
395 | lpips_loss_fn = lpips.LPIPS(net='vgg').to(device)
396 |
397 | def lpips_image(subject_name, image_dir,dreambooth_dir='/data2/xxx/dataset/dreambooth'):
398 | criterion = 'lpips_image'
399 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
400 | # Set up the LPIPS model (vgg=True uses the VGG-based model from the paper)
401 |
402 | loss_fn = lpips_loss_fn
403 |     # some epochs may not have been fully generated
404 | epochs = sorted([int(epoch) for epoch in os.listdir(image_dir)])
405 | mean_similarity_list = []
406 | best_mean_similarity = 0
407 | for epoch in epochs:
408 | similarity = []
409 | dataset = PairwiseImageDatasetLPIPS(subject_name, dreambooth_dir, image_dir, epoch)
410 | # dataset = SelfPairwiseImageDatasetLPIPS(subject, './data')
411 |
412 | for i in range(len(dataset)):
413 | image_A, image_B = dataset[i]
414 | if image_A is not None and image_B is not None:
415 | image_A = image_A.to(device)
416 | image_B = image_B.to(device)
417 |
418 | # Calculate LPIPS between the two images
419 | distance = loss_fn(image_A, image_B)
420 |
421 | similarity.append(distance.item())
422 |
423 | mean_similarity = torch.tensor(similarity).mean().item()
424 | best_mean_similarity = max(best_mean_similarity, mean_similarity)
425 | mean_similarity_list.append(mean_similarity)
426 | print(f'epoch: {epoch}, criterion: LPIPS distance, mean_similarity: {mean_similarity}({best_mean_similarity})')
427 |
428 | return mean_similarity_list
429 |
430 |
431 | if __name__ == "__main__":
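    # NOTE: hard-coded placeholder paths below; set image_dir to the folder of generated images
    # (e.g. ./log_quant) and results_path to where the per-subject results should be written.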
432 | image_dir = '/data2/xxx/ControlNet/oft-db/log_quant'
433 |
434 | subject_dirs, subject_names = [], []
435 | for name in os.listdir(image_dir):
436 | if os.path.isdir(os.path.join(image_dir, name)):
437 | subject_dirs.append(os.path.join(image_dir, name))
438 | subject_names.append(name)
439 |
440 | results_path = '/data2/xxx/ControlNet/oft-db/log_quant/results.json'
441 |
442 | results_dict = dict()
443 | if os.path.exists(results_path):
444 | with open(results_path, 'r') as f:
445 | results = f.__iter__()
446 | while True:
447 | try:
448 | result_json = json.loads(next(results))
449 | results_dict.update(result_json)
450 |
451 | except StopIteration:
452 | print("finish extraction.")
453 | break
454 |
455 | for idx in range(len(subject_names)):
456 | subject_name = subject_names[idx]
457 | subject_dir = subject_dirs[idx]
458 | print(f'evaluating {subject_dir}')
459 | dino_sim = dino(subject_name, subject_dir)
460 | clip_i_sim = clip_image(subject_name, subject_dir)
461 | clip_t_sim = clip_text(subject_name, subject_dir)
462 | lpips_sim = lpips_image(subject_name, subject_dir)
463 |
464 | subject_result = {'DINO': dino_sim, 'CLIP-I': clip_i_sim, 'CLIP-T': clip_t_sim, 'LPIPS': lpips_sim}
465 | print(subject_result)
466 |
467 | with open(results_path, 'a') as f:
468 | json_string = json.dumps({subject_name: subject_result})
469 | f.write(json_string + "\n")
470 |
--------------------------------------------------------------------------------
/get_results.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | # coding=utf-8
3 | # Copyright 2023 The HuggingFace Inc. team. All rights reserved.
4 | #
5 | # Licensed under the Apache License, Version 2.0 (the "License");
6 | # you may not use this file except in compliance with the License.
7 | # You may obtain a copy of the License at
8 | #
9 | # http://www.apache.org/licenses/LICENSE-2.0
10 | #
11 | # Unless required by applicable law or agreed to in writing, software
12 | # distributed under the License is distributed on an "AS IS" BASIS,
13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 | # See the License for the specific language governing permissions and
15 | # limitations under the License.
16 | import math
17 | import os
18 | from functools import reduce
19 | import numpy as np
20 |
21 | import json
22 |
23 | if __name__ == "__main__":
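    # NOTE: hard-coded placeholder; point results_path to the results.json written by evaluation.py.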
24 | results_path = '/data2/xxx/ControlNet/oft-db/log_quant/results.json'
25 | results_dict = dict()
26 | if os.path.exists(results_path):
27 | with open(results_path, 'r') as f:
28 | results = f.__iter__()
29 | while True:
30 | try:
31 | result_json = json.loads(next(results))
32 | results_dict.update(result_json)
33 | except StopIteration:
34 | print("finish extraction.")
35 | break
36 | else:
37 | raise NotImplementedError
38 | total_result = np.zeros(4)
39 | metric_name_list = ['DINO', 'CLIP-I', 'CLIP-T', 'LPIPS']
40 | except_list = []
41 | num_samples = 0
42 | print(len((results_dict.keys())))
43 | for subject_name, subject_results in results_dict.items():
44 | if subject_name in except_list:
45 | continue
46 | num_samples += 1
47 | metric_results_percent = None
48 | for metric_name, metric_results in subject_results.items():
49 | metric_results = [0 if np.isnan(r) else r for r in metric_results]
50 | try:
51 | metric_results_norm = np.array(metric_results) / (max(metric_results) - min(metric_results))
52 | except:
53 | print(subject_name)
54 | if metric_results_percent is None:
55 | metric_results_percent = metric_results_norm
56 | else:
57 | metric_results_percent += metric_results_norm
58 |
59 |         subject_results_max_idx = np.argmax(metric_results_percent)  # pick the epoch with the best combined (range-normalized) score
60 | for idx, metric_name in enumerate(metric_name_list):
61 | total_result[idx] += subject_results[metric_name][subject_results_max_idx]
62 | total_result /= num_samples
63 | print(f'DINO: {total_result[0]}, CLIP-I: {total_result[1]}, CLIP-T: {total_result[2]}, LPIPS: {total_result[3]}')
64 |
--------------------------------------------------------------------------------
/train_dreambooth_quant.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | # coding=utf-8
3 | # Copyright 2023 The HuggingFace Inc. team. All rights reserved.
4 | #
5 | # Licensed under the Apache License, Version 2.0 (the "License");
6 | # you may not use this file except in compliance with the License.
7 | # You may obtain a copy of the License at
8 | #
9 | # http://www.apache.org/licenses/LICENSE-2.0
10 | #
11 | # Unless required by applicable law or agreed to in writing, software
12 | # distributed under the License is distributed on an "AS IS" BASIS,
13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 | # See the License for the specific language governing permissions and
15 | # limitations under the License.
16 | import argparse
17 | import hashlib
18 | import logging
19 | import math
20 | import os
21 | import warnings
22 | from pathlib import Path
23 | import torch.nn as nn
24 | import numpy as np
25 | import torch
26 | import torch.nn.functional as F
27 | import torch.utils.checkpoint
28 | import transformers
29 | from accelerate import Accelerator
30 | from accelerate.logging import get_logger
31 | from accelerate.utils import ProjectConfiguration, set_seed
32 | from huggingface_hub import create_repo, upload_folder
33 | from packaging import version
34 | from PIL import Image
35 | from torch.utils.data import Dataset
36 | from torchvision import transforms
37 | from tqdm.auto import tqdm
38 | from transformers import AutoTokenizer, PretrainedConfig
39 |
40 | import diffusers
41 | from diffusers import (
42 | AutoencoderKL,
43 | DDPMScheduler,
44 | DiffusionPipeline,
45 | DPMSolverMultistepScheduler,
46 | UNet2DConditionModel,
47 | )
48 | from utils.quant_model import QuantUnetWarp
49 | from diffusers.optimization import get_scheduler
50 | from diffusers.utils import check_min_version, is_wandb_available
51 | from diffusers.utils.import_utils import is_xformers_available
52 | from utils.intlora_shift import IntLoRA_SHIFT
53 | from utils.intlora_mul import IntLoRA_MUL
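# IntLoRA_MUL merges the low-rank adapter into the quantized weights via integer multiplication,
# while IntLoRA_SHIFT does so via bit-shifting (see the README for details).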
54 |
55 | # Will error if the minimal version of diffusers is not installed. Remove at your own risks.
56 | check_min_version("0.16.0.dev0")
57 | logger = get_logger(__name__)
58 |
59 |
60 |
61 | def import_model_class_from_model_name_or_path(pretrained_model_name_or_path: str, revision: str):
62 | text_encoder_config = PretrainedConfig.from_pretrained(
63 | pretrained_model_name_or_path,
64 | subfolder="text_encoder",
65 | revision=revision,
66 | )
67 | model_class = text_encoder_config.architectures[0]
68 |
69 | if model_class == "CLIPTextModel":
70 | from transformers import CLIPTextModel
71 |
72 | return CLIPTextModel
73 | elif model_class == "RobertaSeriesModelWithTransformation":
74 | from diffusers.pipelines.alt_diffusion.modeling_roberta_series import RobertaSeriesModelWithTransformation
75 |
76 | return RobertaSeriesModelWithTransformation
77 | else:
78 | raise ValueError(f"{model_class} is not supported.")
79 |
80 |
81 | def parse_args(input_args=None):
82 | parser = argparse.ArgumentParser(description="Simple example of a training script.")
83 | parser.add_argument(
84 | "--pretrained_model_name_or_path",
85 | type=str,
86 | default=None,
87 | required=True,
88 | help="Path to pretrained model or model identifier from huggingface.co/models.",
89 | )
90 |
91 | parser.add_argument(
92 | "--intlora",
93 | type=str,
94 | default='MUL',
95 | required=False,
96 | )
97 |
98 |
99 | parser.add_argument(
100 | "--nbits",
101 | type=int,
102 | default=8,
103 | required=False,
104 | )
105 |
106 |     parser.add_argument("--rank", type=int, default=4, help="The inner rank of the LoRA layer")
107 |
108 |     parser.add_argument("--act_nbits", type=int, default=8, help="The bit-width of the activation quantization")
109 |
110 |     parser.add_argument(
111 |         "--use_activation_quant",
112 |         default=False,
113 |         action="store_true",
114 |         help="Whether to use activation quantization")
115 |
116 | parser.add_argument(
117 | "--revision",
118 | type=str,
119 | default=None,
120 | required=False,
121 | help="Revision of pretrained model identifier from huggingface.co/models.",
122 | )
123 | parser.add_argument(
124 | "--tokenizer_name",
125 | type=str,
126 | default=None,
127 | help="Pretrained tokenizer name or path if not the same as model_name",
128 | )
129 | parser.add_argument(
130 | "--instance_data_dir",
131 | type=str,
132 | default=None,
133 | required=True,
134 | help="A folder containing the training data of instance images.",
135 | )
136 | parser.add_argument(
137 | "--class_data_dir",
138 | type=str,
139 | default=None,
140 | required=False,
141 | help="A folder containing the training data of class images.",
142 | )
143 | parser.add_argument(
144 | "--instance_prompt",
145 | type=str,
146 | default=None,
147 | required=True,
148 | help="The prompt with identifier specifying the instance",
149 | )
150 | parser.add_argument(
151 | "--class_prompt",
152 | type=str,
153 | default=None,
154 | help="The prompt to specify images in the same class as provided instance images.",
155 | )
156 | parser.add_argument(
157 | "--validation_prompt",
158 | type=str,
159 | default=None,
160 | help="A prompt that is used during validation to verify that the model is learning.",
161 | )
162 | parser.add_argument(
163 | "--test_prompt",
164 | type=str,
165 | default=None,
166 |         help="A prompt that is used during validation to verify that the model keeps the class prior.",
167 | )
168 | parser.add_argument(
169 | "--num_validation_images",
170 | type=int,
171 | default=8,
172 | help="Number of images that should be generated during validation with `validation_prompt`.",
173 | )
174 | parser.add_argument(
175 | "--validation_epochs",
176 | type=int,
177 | default=50,
178 | help=(
179 | "Run dreambooth validation every X epochs. Dreambooth validation consists of running the prompt"
180 | " `args.validation_prompt` multiple times: `args.num_validation_images`."
181 | ),
182 | )
183 | parser.add_argument(
184 | "--with_prior_preservation",
185 | default=False,
186 | action="store_true",
187 | help="Flag to add prior preservation loss.",
188 | )
189 | parser.add_argument("--prior_loss_weight", type=float, default=1.0, help="The weight of prior preservation loss.")
190 | parser.add_argument(
191 | "--num_class_images",
192 | type=int,
193 | default=100,
194 | help=(
195 | "Minimal class images for prior preservation loss. If there are not enough images already present in"
196 | " class_data_dir, additional images will be sampled with class_prompt."
197 | ),
198 | )
199 | parser.add_argument(
200 | "--output_dir",
201 | type=str,
202 | default="lora-dreambooth-model",
203 | help="The output directory where the model predictions and checkpoints will be written.",
204 | )
205 | parser.add_argument("--seed", type=int, default=None, help="A seed for reproducible training.")
206 | parser.add_argument(
207 | "--resolution",
208 | type=int,
209 | default=512,
210 | help=(
211 | "The resolution for input images, all the images in the train/validation dataset will be resized to this"
212 | " resolution"
213 | ),
214 | )
215 | parser.add_argument(
216 | "--center_crop",
217 | default=False,
218 | action="store_true",
219 | help=(
220 | "Whether to center crop the input images to the resolution. If not set, the images will be randomly"
221 | " cropped. The images will be resized to the resolution first before cropping."
222 | ),
223 | )
224 | parser.add_argument(
225 | "--train_batch_size", type=int, default=4, help="Batch size (per device) for the training dataloader."
226 | )
227 | parser.add_argument(
228 | "--sample_batch_size", type=int, default=4, help="Batch size (per device) for sampling images."
229 | )
230 | parser.add_argument("--num_train_epochs", type=int, default=1)
231 | parser.add_argument(
232 | "--max_train_steps",
233 | type=int,
234 | default=None,
235 | help="Total number of training steps to perform. If provided, overrides num_train_epochs.",
236 | )
237 | parser.add_argument(
238 | "--checkpointing_steps",
239 | type=int,
240 | default=500,
241 | help=(
242 | "Save a checkpoint of the training state every X updates. These checkpoints can be used both as final"
243 | " checkpoints in case they are better than the last checkpoint, and are also suitable for resuming"
244 | " training using `--resume_from_checkpoint`."
245 | ),
246 | )
247 | parser.add_argument(
248 | "--checkpoints_total_limit",
249 | type=int,
250 | default=None,
251 | help=(
252 | "Max number of checkpoints to store. Passed as `total_limit` to the `Accelerator` `ProjectConfiguration`."
253 | " See Accelerator::save_state https://huggingface.co/docs/accelerate/package_reference/accelerator#accelerate.Accelerator.save_state"
254 | " for more docs"
255 | ),
256 | )
257 | parser.add_argument(
258 | "--resume_from_checkpoint",
259 | type=str,
260 | default=None,
261 | help=(
262 | "Whether training should be resumed from a previous checkpoint. Use a path saved by"
263 | ' `--checkpointing_steps`, or `"latest"` to automatically select the last available checkpoint.'
264 | ),
265 | )
266 | parser.add_argument(
267 | "--gradient_accumulation_steps",
268 | type=int,
269 | default=1,
270 | help="Number of updates steps to accumulate before performing a backward/update pass.",
271 | )
272 | parser.add_argument(
273 | "--gradient_checkpointing",
274 | action="store_true",
275 | help="Whether or not to use gradient checkpointing to save memory at the expense of slower backward pass.",
276 | )
277 | parser.add_argument(
278 | "--learning_rate",
279 | type=float,
280 | default=5e-4,
281 | help="Initial learning rate (after the potential warmup period) to use.",
282 | )
283 | parser.add_argument(
284 | "--scale_lr",
285 | action="store_true",
286 | default=False,
287 | help="Scale the learning rate by the number of GPUs, gradient accumulation steps, and batch size.",
288 | )
289 | parser.add_argument(
290 | "--lr_scheduler",
291 | type=str,
292 | default="constant",
293 | help=(
294 | 'The scheduler type to use. Choose between ["linear", "cosine", "cosine_with_restarts", "polynomial",'
295 | ' "constant", "constant_with_warmup"]'
296 | ),
297 | )
298 | parser.add_argument(
299 | "--lr_warmup_steps", type=int, default=500, help="Number of steps for the warmup in the lr scheduler."
300 | )
301 | parser.add_argument(
302 | "--lr_num_cycles",
303 | type=int,
304 | default=1,
305 | help="Number of hard resets of the lr in cosine_with_restarts scheduler.",
306 | )
307 | parser.add_argument("--lr_power", type=float, default=1.0, help="Power factor of the polynomial scheduler.")
308 | parser.add_argument(
309 | "--dataloader_num_workers",
310 | type=int,
311 | default=0,
312 | help=(
313 | "Number of subprocesses to use for data loading. 0 means that the data will be loaded in the main process."
314 | ),
315 | )
316 | parser.add_argument(
317 | "--use_8bit_adam", action="store_true", help="Whether or not to use 8-bit Adam from bitsandbytes."
318 | )
319 | parser.add_argument("--adam_beta1", type=float, default=0.9, help="The beta1 parameter for the Adam optimizer.")
320 | parser.add_argument("--adam_beta2", type=float, default=0.999, help="The beta2 parameter for the Adam optimizer.")
321 | parser.add_argument("--adam_weight_decay", type=float, default=1e-2, help="Weight decay to use.")
322 | parser.add_argument("--adam_epsilon", type=float, default=1e-08, help="Epsilon value for the Adam optimizer")
323 | parser.add_argument("--max_grad_norm", default=1.0, type=float, help="Max gradient norm.")
324 | parser.add_argument("--push_to_hub", action="store_true", help="Whether or not to push the model to the Hub.")
325 | parser.add_argument("--hub_token", type=str, default=None, help="The token to use to push to the Model Hub.")
326 | parser.add_argument(
327 | "--hub_model_id",
328 | type=str,
329 | default=None,
330 | help="The name of the repository to keep in sync with the local `output_dir`.",
331 | )
332 | parser.add_argument(
333 | "--logging_dir",
334 | type=str,
335 | default="logs",
336 | help=(
337 | "[TensorBoard](https://www.tensorflow.org/tensorboard) log directory. Will default to"
338 | " *output_dir/runs/**CURRENT_DATETIME_HOSTNAME***."
339 | ),
340 | )
341 | parser.add_argument(
342 | "--allow_tf32",
343 | action="store_true",
344 | help=(
345 | "Whether or not to allow TF32 on Ampere GPUs. Can be used to speed up training. For more information, see"
346 | " https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices"
347 | ),
348 | )
349 | parser.add_argument(
350 | "--report_to",
351 | type=str,
352 | default="tensorboard",
353 | help=(
354 | 'The integration to report the results and logs to. Supported platforms are `"tensorboard"`'
355 | ' (default), `"wandb"` and `"comet_ml"`. Use `"all"` to report to all integrations.'
356 | ),
357 | )
358 | parser.add_argument(
359 | "--mixed_precision",
360 | type=str,
361 | default=None,
362 | choices=["no", "fp16", "bf16"],
363 | help=(
364 | "Whether to use mixed precision. Choose between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >="
365 | " 1.10.and an Nvidia Ampere GPU. Default to the value of accelerate config of the current system or the"
366 | " flag passed with the `accelerate.launch` command. Use this argument to override the accelerate config."
367 | ),
368 | )
369 | parser.add_argument(
370 | "--prior_generation_precision",
371 | type=str,
372 | default=None,
373 | choices=["no", "fp32", "fp16", "bf16"],
374 | help=(
375 | "Choose prior generation precision between fp32, fp16 and bf16 (bfloat16). Bf16 requires PyTorch >="
376 | " 1.10.and an Nvidia Ampere GPU. Default to fp16 if a GPU is available else fp32."
377 | ),
378 | )
379 | parser.add_argument("--local_rank", type=int, default=-1, help="For distributed training: local_rank")
380 | parser.add_argument(
381 | "--enable_xformers_memory_efficient_attention", action="store_true", help="Whether or not to use xformers."
382 | )
383 | parser.add_argument(
384 | "--name",
385 | type=str,
386 | help=(
387 |             "The name of the current experiment run, consisting of [data]-[prompt]"
388 | ),
389 | )
390 |
391 |
392 |
393 |
394 | if input_args is not None:
395 | args = parser.parse_args(input_args)
396 | else:
397 | args = parser.parse_args()
398 |
399 | env_local_rank = int(os.environ.get("LOCAL_RANK", -1))
400 | if env_local_rank != -1 and env_local_rank != args.local_rank:
401 | args.local_rank = env_local_rank
402 |
403 | if args.with_prior_preservation:
404 | if args.class_data_dir is None:
405 | raise ValueError("You must specify a data directory for class images.")
406 | if args.class_prompt is None:
407 | raise ValueError("You must specify prompt for class images.")
408 | else:
409 | # logger is not available yet
410 | if args.class_data_dir is not None:
411 | warnings.warn("You need not use --class_data_dir without --with_prior_preservation.")
412 | if args.class_prompt is not None:
413 | warnings.warn("You need not use --class_prompt without --with_prior_preservation.")
414 |
415 | return args
416 |
417 |
418 | class DreamBoothDataset(Dataset):
419 | """
420 | A dataset to prepare the instance and class images with the prompts for fine-tuning the model.
421 |     It pre-processes the images and tokenizes the prompts.
422 | """
423 |
424 | def __init__(
425 | self,
426 | instance_data_root,
427 | instance_prompt,
428 | tokenizer,
429 | class_data_root=None,
430 | class_prompt=None,
431 | class_num=None,
432 | size=512,
433 | center_crop=False,
434 | ):
435 | self.size = size
436 | self.center_crop = center_crop
437 | self.tokenizer = tokenizer
438 |
439 | self.instance_data_root = Path(instance_data_root)
440 | if not self.instance_data_root.exists():
441 |             raise ValueError("Instance images root doesn't exist.")
442 |
443 | self.instance_images_path = list(Path(instance_data_root).iterdir())
444 | self.num_instance_images = len(self.instance_images_path)
445 | self.instance_prompt = instance_prompt
446 | self._length = self.num_instance_images
447 |
448 | if class_data_root is not None:
449 | self.class_data_root = Path(class_data_root)
450 | self.class_data_root.mkdir(parents=True, exist_ok=True)
451 | self.class_images_path = list(self.class_data_root.iterdir())
452 | if class_num is not None:
453 | self.num_class_images = min(len(self.class_images_path), class_num)
454 | else:
455 | self.num_class_images = len(self.class_images_path)
456 | self._length = max(self.num_class_images, self.num_instance_images)
457 | self.class_prompt = class_prompt
458 | else:
459 | self.class_data_root = None
460 |
461 | self.image_transforms = transforms.Compose(
462 | [
463 | transforms.Resize(size, interpolation=transforms.InterpolationMode.BILINEAR),
464 | transforms.CenterCrop(size) if center_crop else transforms.RandomCrop(size),
465 | transforms.ToTensor(),
466 | transforms.Normalize([0.5], [0.5]),
467 | ]
468 | )
469 |
470 | def __len__(self):
471 | return self._length
472 |
473 | def __getitem__(self, index):
474 | example = {}
475 | instance_image = Image.open(self.instance_images_path[index % self.num_instance_images])
476 | if not instance_image.mode == "RGB":
477 | instance_image = instance_image.convert("RGB")
478 | example["instance_images"] = self.image_transforms(instance_image)
479 | example["instance_prompt_ids"] = self.tokenizer(
480 | self.instance_prompt,
481 | truncation=True,
482 | padding="max_length",
483 | max_length=self.tokenizer.model_max_length,
484 | return_tensors="pt",
485 | ).input_ids
486 |
487 | if self.class_data_root:
488 | class_image = Image.open(self.class_images_path[index % self.num_class_images])
489 | if not class_image.mode == "RGB":
490 | class_image = class_image.convert("RGB")
491 | example["class_images"] = self.image_transforms(class_image)
492 | example["class_prompt_ids"] = self.tokenizer(
493 | self.class_prompt,
494 | truncation=True,
495 | padding="max_length",
496 | max_length=self.tokenizer.model_max_length,
497 | return_tensors="pt",
498 | ).input_ids
499 |
500 | return example
501 |
502 |
503 | def collate_fn(examples, with_prior_preservation=False):
504 | input_ids = [example["instance_prompt_ids"] for example in examples]
505 | pixel_values = [example["instance_images"] for example in examples]
506 |
507 | # Concat class and instance examples for prior preservation.
508 | # We do this to avoid doing two forward passes.
509 | if with_prior_preservation:
510 | input_ids += [example["class_prompt_ids"] for example in examples]
511 | pixel_values += [example["class_images"] for example in examples]
512 |
513 | pixel_values = torch.stack(pixel_values)
514 | pixel_values = pixel_values.to(memory_format=torch.contiguous_format).float()
515 |
516 | input_ids = torch.cat(input_ids, dim=0)
517 |
518 | batch = {
519 | "input_ids": input_ids,
520 | "pixel_values": pixel_values,
521 | }
522 | return batch
523 |
524 |
525 | class PromptDataset(Dataset):
526 | "A simple dataset to prepare the prompts to generate class images on multiple GPUs."
527 |
528 | def __init__(self, prompt, num_samples):
529 | self.prompt = prompt
530 | self.num_samples = num_samples
531 |
532 | def __len__(self):
533 | return self.num_samples
534 |
535 | def __getitem__(self, index):
536 | example = {}
537 | example["prompt"] = self.prompt
538 | example["index"] = index
539 | return example
540 |
541 |
542 | def main(args):
543 | logging_dir = Path(args.output_dir, args.logging_dir)
544 |
545 | accelerator_project_config = ProjectConfiguration(project_dir=args.output_dir, logging_dir=logging_dir) # total_limit=args.checkpoints_total_limit)
546 |
547 | wandb_init = {
548 | "wandb": {
549 | "name": args.name,
550 | # "project": args.project,
551 | }
552 | }
553 |
554 | accelerator = Accelerator(
555 | gradient_accumulation_steps=args.gradient_accumulation_steps,
556 | mixed_precision=args.mixed_precision,
557 | log_with=args.report_to,
558 | project_config=accelerator_project_config,
559 | )
560 |
561 | if args.report_to == "wandb":
562 | if not is_wandb_available():
563 | raise ImportError("Make sure to install wandb if you want to use it for logging during training.")
564 | import wandb
565 |
566 | # Currently, it's not possible to do gradient accumulation when training two models with accelerate.accumulate
567 | # This will be enabled soon in accelerate. For now, we don't allow gradient accumulation when training two models.
568 | # TODO (patil-suraj): Remove this check when gradient accumulation with two models is enabled in accelerate.
569 | # Make one log on every process with the configuration for debugging.
570 | logging.basicConfig(
571 | format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
572 | datefmt="%m/%d/%Y %H:%M:%S",
573 | level=logging.INFO,
574 | )
575 | logger.info(accelerator.state, main_process_only=False)
576 | if accelerator.is_local_main_process:
577 | transformers.utils.logging.set_verbosity_warning()
578 | diffusers.utils.logging.set_verbosity_info()
579 | else:
580 | transformers.utils.logging.set_verbosity_error()
581 | diffusers.utils.logging.set_verbosity_error()
582 |
583 | # If passed along, set the training seed now.
584 | if args.seed is not None:
585 | set_seed(args.seed)
586 |
587 | # Generate class images if prior preservation is enabled.
588 | if args.with_prior_preservation:
589 | class_images_dir = Path(args.class_data_dir)
590 | if not class_images_dir.exists():
591 | class_images_dir.mkdir(parents=True)
592 | cur_class_images = len(list(class_images_dir.iterdir()))
593 |
594 | if cur_class_images < args.num_class_images:
595 | torch_dtype = torch.float16 if accelerator.device.type == "cuda" else torch.float32
596 | if args.prior_generation_precision == "fp32":
597 | torch_dtype = torch.float32
598 | elif args.prior_generation_precision == "fp16":
599 | torch_dtype = torch.float16
600 | elif args.prior_generation_precision == "bf16":
601 | torch_dtype = torch.bfloat16
602 | pipeline = DiffusionPipeline.from_pretrained(
603 | args.pretrained_model_name_or_path,
604 | torch_dtype=torch_dtype,
605 | safety_checker=None,
606 | revision=args.revision,
607 | )
608 | pipeline.set_progress_bar_config(disable=True)
609 |
610 | num_new_images = args.num_class_images - cur_class_images
611 | logger.info(f"Number of class images to sample: {num_new_images}.")
612 |
613 | sample_dataset = PromptDataset(args.class_prompt, num_new_images)
614 | sample_dataloader = torch.utils.data.DataLoader(sample_dataset, batch_size=args.sample_batch_size)
615 |
616 | sample_dataloader = accelerator.prepare(sample_dataloader)
617 | pipeline.to(accelerator.device)
618 |
619 | for example in tqdm(
620 | sample_dataloader, desc="Generating class images", disable=not accelerator.is_local_main_process
621 | ):
622 | images = pipeline(example["prompt"]).images
623 |
624 | for i, image in enumerate(images):
625 | hash_image = hashlib.sha1(image.tobytes()).hexdigest()
626 | image_filename = class_images_dir / f"{example['index'][i] + cur_class_images}-{hash_image}.jpg"
627 | image.save(image_filename)
628 |
629 | del pipeline
630 | if torch.cuda.is_available():
631 | torch.cuda.empty_cache()
632 |
633 | # Handle the repository creation
634 | if accelerator.is_main_process:
635 | if args.output_dir is not None:
636 | os.makedirs(args.output_dir, exist_ok=True)
637 |
638 | if args.push_to_hub:
639 | repo_id = create_repo(
640 | repo_id=args.hub_model_id or Path(args.output_dir).name, exist_ok=True, token=args.hub_token
641 | ).repo_id
642 |
643 | # Load the tokenizer
644 | if args.tokenizer_name:
645 | tokenizer = AutoTokenizer.from_pretrained(args.tokenizer_name, revision=args.revision, use_fast=False)
646 | elif args.pretrained_model_name_or_path:
647 | tokenizer = AutoTokenizer.from_pretrained(
648 | args.pretrained_model_name_or_path,
649 | subfolder="tokenizer",
650 | revision=args.revision,
651 | use_fast=False,
652 | )
653 |
654 | # import correct text encoder class
655 | text_encoder_cls = import_model_class_from_model_name_or_path(args.pretrained_model_name_or_path, args.revision)
656 |
657 | # Load scheduler and models
658 | noise_scheduler = DDPMScheduler.from_pretrained(args.pretrained_model_name_or_path, subfolder="scheduler")
659 | text_encoder = text_encoder_cls.from_pretrained(
660 | args.pretrained_model_name_or_path, subfolder="text_encoder", revision=args.revision
661 | )
662 | vae = AutoencoderKL.from_pretrained(args.pretrained_model_name_or_path, subfolder="vae", revision=args.revision)
663 | unet = UNet2DConditionModel.from_pretrained(
664 | args.pretrained_model_name_or_path, subfolder="unet", revision=args.revision
665 | )
666 |
667 |     # We only train the additional IntLoRA adapter layers; the pre-trained backbone stays frozen
668 | vae.requires_grad_(False)
669 | text_encoder.requires_grad_(False)
670 | unet.requires_grad_(False)
671 |
672 | # For mixed precision training we cast the text_encoder and vae weights to half-precision
673 | # as these models are only used for inference, keeping weights in full precision is not required.
674 | weight_dtype = torch.float32
675 | if accelerator.mixed_precision == "fp16":
676 | weight_dtype = torch.float16
677 | elif accelerator.mixed_precision == "bf16":
678 | weight_dtype = torch.bfloat16
679 |
680 | # Move unet, vae and text_encoder to device and cast to weight_dtype
681 | unet.to(accelerator.device, dtype=weight_dtype)
682 | vae.to(accelerator.device, dtype=weight_dtype)
683 | text_encoder.to(accelerator.device, dtype=weight_dtype)
684 |
685 | if args.enable_xformers_memory_efficient_attention:
686 | if is_xformers_available():
687 | import xformers
688 |
689 | xformers_version = version.parse(xformers.__version__)
690 | if xformers_version == version.parse("0.0.16"):
691 | logger.warn(
692 | "xFormers 0.0.16 cannot be used for training in some GPUs. If you observe problems during training, please update xFormers to at least 0.0.17. See https://huggingface.co/docs/diffusers/main/en/optimization/xformers for more details."
693 | )
694 | unet.enable_xformers_memory_efficient_attention()
695 | else:
696 | raise ValueError("xformers is not available. Make sure it is installed correctly")
697 |
698 |
699 | # ---------------------------------------------------------------------------------------------------------------
700 |     unet = QuantUnetWarp(unet, args)  # quantize the UNet and attach IntLoRA adapters to the attention projections
701 | opt_params = []
702 | for name, module in unet.named_modules():
703 | if isinstance(module, (IntLoRA_MUL,IntLoRA_SHIFT)):
704 | opt_params += list(module.loraA.parameters()) + list(module.loraB.parameters())
705 |
706 | # ---------------------------------------------------------------------------------------------------------------
707 | # Enable TF32 for faster training on Ampere GPUs,
708 | # cf https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices
709 | if args.allow_tf32:
710 | torch.backends.cuda.matmul.allow_tf32 = True
711 |
712 | if args.scale_lr:
713 | args.learning_rate = (
714 | args.learning_rate * args.gradient_accumulation_steps * args.train_batch_size * accelerator.num_processes
715 | )
716 |
717 | # Use 8-bit Adam for lower memory usage or to fine-tune the model in 16GB GPUs
718 | if args.use_8bit_adam:
719 | try:
720 | import bitsandbytes as bnb
721 | except ImportError:
722 | raise ImportError(
723 | "To use 8-bit Adam, please install the bitsandbytes library: `pip install bitsandbytes`."
724 | )
725 |
726 | optimizer_class = bnb.optim.AdamW8bit
727 | else:
728 | optimizer_class = torch.optim.AdamW
729 |
730 | # Optimizer creation
731 |     optimizer = optimizer_class(
732 |         opt_params,  # only the IntLoRA low-rank parameters are optimized
733 |         lr=args.learning_rate,
734 |         betas=(args.adam_beta1, args.adam_beta2),
735 |         weight_decay=args.adam_weight_decay,
736 |         eps=args.adam_epsilon,
737 |     )
738 |
739 | # Dataset and DataLoaders creation:
740 | train_dataset = DreamBoothDataset(
741 | instance_data_root=args.instance_data_dir,
742 | instance_prompt=args.instance_prompt,
743 | class_data_root=args.class_data_dir if args.with_prior_preservation else None,
744 | class_prompt=args.class_prompt,
745 | class_num=args.num_class_images,
746 | tokenizer=tokenizer,
747 | size=args.resolution,
748 | center_crop=args.center_crop,
749 | )
750 |
751 | train_dataloader = torch.utils.data.DataLoader(
752 | train_dataset,
753 | batch_size=args.train_batch_size,
754 | shuffle=True,
755 | collate_fn=lambda examples: collate_fn(examples, args.with_prior_preservation),
756 | num_workers=args.dataloader_num_workers,
757 | )
758 |
759 | # Scheduler and math around the number of training steps.
760 | overrode_max_train_steps = False
761 | num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
762 | if args.max_train_steps is None:
763 | args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
764 | overrode_max_train_steps = True
765 |
766 | lr_scheduler = get_scheduler(
767 | args.lr_scheduler,
768 | optimizer=optimizer,
769 | num_warmup_steps=args.lr_warmup_steps * args.gradient_accumulation_steps,
770 | num_training_steps=args.max_train_steps * args.gradient_accumulation_steps,
771 | num_cycles=args.lr_num_cycles,
772 | power=args.lr_power,
773 | )
774 |
775 | # Prepare everything with our `accelerator`.
776 | unet, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
777 | unet, optimizer, train_dataloader, lr_scheduler
778 | )
779 |
780 | # We need to recalculate our total training steps as the size of the training dataloader may have changed.
781 | num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
782 | if overrode_max_train_steps:
783 | args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
784 | # Afterwards we recalculate our number of training epochs
785 | args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)
786 |
787 | # We need to initialize the trackers we use, and also store our configuration.
788 |     # The trackers are initialized automatically on the main process.
789 | if accelerator.is_main_process:
790 | accelerator.init_trackers("dreambooth-quant", config=vars(args), init_kwargs=wandb_init)
791 |
792 | # Train!
793 | total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps
794 |
795 | logger.info("***** Running training *****")
796 | logger.info(f" Num examples = {len(train_dataset)}")
797 | logger.info(f" Num batches each epoch = {len(train_dataloader)}")
798 | logger.info(f" Num Epochs = {args.num_train_epochs}")
799 | logger.info(f" Instantaneous batch size per device = {args.train_batch_size}")
800 | logger.info(f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}")
801 | logger.info(f" Gradient Accumulation steps = {args.gradient_accumulation_steps}")
802 | logger.info(f" Total optimization steps = {args.max_train_steps}")
803 | global_step = 0
804 | first_epoch = 0
805 |
806 | # Potentially load in the weights and states from a previous save
807 | if args.resume_from_checkpoint:
808 | if args.resume_from_checkpoint != "latest":
809 | path = os.path.basename(args.resume_from_checkpoint)
810 | else:
811 |             # Get the most recent checkpoint
812 | dirs = os.listdir(args.output_dir)
813 | dirs = [d for d in dirs if d.startswith("checkpoint")]
814 | dirs = sorted(dirs, key=lambda x: int(x.split("-")[1]))
815 | path = dirs[-1] if len(dirs) > 0 else None
816 |
817 | if path is None:
818 | accelerator.print(
819 | f"Checkpoint '{args.resume_from_checkpoint}' does not exist. Starting a new training run."
820 | )
821 | args.resume_from_checkpoint = None
822 | else:
823 | accelerator.print(f"Resuming from checkpoint {path}")
824 | accelerator.load_state(os.path.join(args.output_dir, path))
825 | global_step = int(path.split("-")[1])
826 |
827 | resume_global_step = global_step * args.gradient_accumulation_steps
828 | first_epoch = global_step // num_update_steps_per_epoch
829 | resume_step = resume_global_step % (num_update_steps_per_epoch * args.gradient_accumulation_steps)
830 |
831 | # Only show the progress bar once on each machine.
832 | progress_bar = tqdm(range(global_step, args.max_train_steps), disable=not accelerator.is_local_main_process)
833 | progress_bar.set_description("Steps")
834 |
835 |
836 | for epoch in range(first_epoch, args.num_train_epochs):
837 | unet.train()
838 |
839 | for step, batch in enumerate(train_dataloader):
840 | # Skip steps until we reach the resumed step
841 | if args.resume_from_checkpoint and epoch == first_epoch and step < resume_step:
842 | if step % args.gradient_accumulation_steps == 0:
843 | progress_bar.update(1)
844 | continue
845 |
846 | with accelerator.accumulate(unet):
847 | # Convert images to latent space
848 | latents = vae.encode(batch["pixel_values"].to(dtype=weight_dtype)).latent_dist.sample()
849 | latents = latents * vae.config.scaling_factor
850 |
851 | # Sample noise that we'll add to the latents
852 | noise = torch.randn_like(latents)
853 | bsz = latents.shape[0]
854 | # Sample a random timestep for each image
855 | timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, (bsz,), device=latents.device)
856 | timesteps = timesteps.long()
857 |
858 | # Add noise to the latents according to the noise magnitude at each timestep
859 | # (this is the forward diffusion process)
860 | noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)
861 |
862 | # Get the text embedding for conditioning
863 | encoder_hidden_states = text_encoder(batch["input_ids"])[0]
864 |
865 | # Predict the noise residual
866 | model_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample
867 |
868 | # Get the target for loss depending on the prediction type
869 | if noise_scheduler.config.prediction_type == "epsilon":
870 | target = noise
871 | elif noise_scheduler.config.prediction_type == "v_prediction":
872 | target = noise_scheduler.get_velocity(latents, noise, timesteps)
873 | else:
874 | raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}")
875 |
876 | if args.with_prior_preservation:
877 | # Chunk the noise and model_pred into two parts and compute the loss on each part separately.
878 | model_pred, model_pred_prior = torch.chunk(model_pred, 2, dim=0)
879 | target, target_prior = torch.chunk(target, 2, dim=0)
880 |
881 | # Compute instance loss
882 | loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean")
883 |
884 | # Compute prior loss
885 | prior_loss = F.mse_loss(model_pred_prior.float(), target_prior.float(), reduction="mean")
886 |
887 | # Add the prior loss to the instance loss.
888 | loss = loss + args.prior_loss_weight * prior_loss
889 | else:
890 | loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean")
891 |
892 | accelerator.backward(loss)
893 | if accelerator.sync_gradients:
894 |                     params_to_clip = opt_params  # clip gradients of the IntLoRA parameters only
895 | accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm)
896 | optimizer.step()
897 | lr_scheduler.step()
898 | optimizer.zero_grad()
899 |
900 | # Checks if the accelerator has performed an optimization step behind the scenes
901 | if accelerator.sync_gradients:
902 | progress_bar.update(1)
903 | global_step += 1
904 |
905 | if global_step % args.checkpointing_steps == 0:
906 | if accelerator.is_main_process:
907 | save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}")
908 | accelerator.save_state(save_path)
909 | logger.info(f"Saved state to {save_path}")
910 |
911 | logs = {"loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0]}
912 | progress_bar.set_postfix(**logs)
913 | accelerator.log(logs, step=global_step)
914 |
915 | if global_step >= args.max_train_steps:
916 | break
917 |
918 | if accelerator.is_main_process:
919 | if args.validation_prompt is not None and epoch % args.validation_epochs == 0: # and epoch > 1:
920 | logger.info(
921 | f"Running validation... \n Generating {args.num_validation_images} images with prompt:"
922 | f" {args.validation_prompt}."
923 | )
924 |
925 | # create pipeline
926 | pipeline = DiffusionPipeline.from_pretrained(
927 | args.pretrained_model_name_or_path,
928 |                     unet=accelerator.unwrap_model(unet.model) if isinstance(unet, QuantUnetWarp) else accelerator.unwrap_model(unet),
929 | text_encoder=accelerator.unwrap_model(text_encoder),
930 | revision=args.revision,
931 | torch_dtype=weight_dtype,
932 | )
933 | pipeline.scheduler = DPMSolverMultistepScheduler.from_config(pipeline.scheduler.config)
934 | pipeline = pipeline.to(accelerator.device)
935 | pipeline.set_progress_bar_config(disable=True)
936 |
937 | # run inference
938 | generator = torch.Generator(device=accelerator.device).manual_seed(args.seed)
939 | images = [
940 | pipeline(args.validation_prompt, num_inference_steps=25, generator=generator).images[0]
941 | for _ in range(args.num_validation_images)
942 | ]
943 |
944 | for tracker in accelerator.trackers:
945 | if tracker.name == "tensorboard":
946 | np_images = np.stack([np.asarray(img) for img in images])
947 | tracker.writer.add_images("validation", np_images, epoch, dataformats="NHWC")
948 | if tracker.name == "wandb":
949 | tracker.log(
950 | {
951 | "validation": [
952 | wandb.Image(image, caption=f"{i}: {args.validation_prompt}")
953 | for i, image in enumerate(images)
954 | ]
955 | }
956 | )
957 |
958 | # Create the output directory if it doesn't exist
959 | tmp_dir = os.path.join(args.output_dir, str(epoch))
960 | if not os.path.exists(tmp_dir):
961 | os.makedirs(tmp_dir)
962 |
963 | for i, image in enumerate(images):
964 | np_image = np.array(image)
965 | pil_image = Image.fromarray(np_image)
966 | pil_image.save(os.path.join(args.output_dir, str(epoch), f"image_{i}.png"))
967 |
968 | del pipeline
969 | torch.cuda.empty_cache()
970 |
971 |
972 |     # Save the IntLoRA adapter layers
973 | accelerator.wait_for_everyone()
974 | if accelerator.is_main_process:
975 | unet = unet.to(torch.float32)
976 | unet.save_attn_procs(args.output_dir)
977 |
978 | # Final inference
979 | # Load previous pipeline
980 | pipeline = DiffusionPipeline.from_pretrained(
981 | args.pretrained_model_name_or_path, revision=args.revision, torch_dtype=weight_dtype
982 | )
983 | pipeline.scheduler = DPMSolverMultistepScheduler.from_config(pipeline.scheduler.config)
984 | pipeline = pipeline.to(accelerator.device)
985 |
986 | # load attention processors
987 | pipeline.unet.load_attn_procs(args.output_dir)
988 |
989 | # run inference
990 | if args.validation_prompt and args.num_validation_images > 0:
991 | generator = torch.Generator(device=accelerator.device).manual_seed(args.seed) if args.seed else None
992 | images = [
993 | pipeline(args.validation_prompt, num_inference_steps=25, generator=generator).images[0]
994 | for _ in range(args.num_validation_images)
995 | ]
996 |
997 | for tracker in accelerator.trackers:
998 | if tracker.name == "tensorboard":
999 | np_images = np.stack([np.asarray(img) for img in images])
1000 | tracker.writer.add_images("test", np_images, epoch, dataformats="NHWC")
1001 | if tracker.name == "wandb":
1002 | tracker.log(
1003 | {
1004 | "test": [
1005 | wandb.Image(image, caption=f"{i}: {args.validation_prompt}")
1006 | for i, image in enumerate(images)
1007 | ]
1008 | }
1009 | )
1010 |
1011 | accelerator.end_training()
1012 |
1013 |
1014 | if __name__ == "__main__":
1015 | args = parse_args()
1016 | main(args)
1017 |
--------------------------------------------------------------------------------
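In `train_dreambooth_quant.py`, only the `loraA`/`loraB` matrices gathered into `opt_params` receive gradients; the quantized backbone stays frozen. The snippet below is a minimal sketch (not a repository file; it assumes the same `args` produced by `parse_args` with the quantization flags set) of how one could count the trainable IntLoRA parameters after wrapping the UNet:

```
# Hedged sketch: count trainable IntLoRA parameters after wrapping, mirroring
# how opt_params is built in train_dreambooth_quant.py.
from utils.quant_model import QuantUnetWarp
from utils.intlora_mul import IntLoRA_MUL
from utils.intlora_shift import IntLoRA_SHIFT

def count_trainable_lora_params(unet, args):
    # args is assumed to come from parse_args() with --intlora/--nbits/--rank/--act_nbits set
    wrapped = QuantUnetWarp(unet, args)
    n_trainable = sum(
        p.numel()
        for m in wrapped.modules()
        if isinstance(m, (IntLoRA_MUL, IntLoRA_SHIFT))
        for p in list(m.loraA.parameters()) + list(m.loraB.parameters())
    )
    return wrapped, n_trainable
```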
/train_dreambooth_quant.sh:
--------------------------------------------------------------------------------
1 | export MODEL_NAME="/data/guohang/pretrained/stable-diffusion-v1-5"
2 |
3 | for idx in {0..749}
4 | do
5 | prompt_idx=$((idx % 25))
6 | class_idx=$((idx / 25))
7 |
8 |
9 | # Define the unique_token, class_tokens, and subject_names
10 | unique_token="qwe"
11 | subject_names=(
12 | "backpack" "backpack_dog" "bear_plushie" "berry_bowl" "can"
13 | "candle" "cat" "cat2" "clock" "colorful_sneaker"
14 | "dog" "dog2" "dog3" "dog5" "dog6"
15 | "dog7" "dog8" "duck_toy" "fancy_boot" "grey_sloth_plushie"
16 | "monster_toy" "pink_sunglasses" "poop_emoji" "rc_car" "red_cartoon"
17 | "robot_toy" "shiny_sneaker" "teapot" "vase" "wolf_plushie"
18 | )
19 |
20 | class_tokens=(
21 | "backpack" "backpack" "stuffed animal" "bowl" "can"
22 | "candle" "cat" "cat" "clock" "sneaker"
23 | "dog" "dog" "dog" "dog" "dog"
24 | "dog" "dog" "toy" "boot" "stuffed animal"
25 | "toy" "glasses" "toy" "toy" "cartoon"
26 | "toy" "sneaker" "teapot" "vase" "stuffed animal"
27 | )
28 |
29 | class_token=${class_tokens[$class_idx]}
30 | selected_subject=${subject_names[$class_idx]}
31 |
32 | if [[ $class_idx =~ ^(0|1|2|3|4|5|8|9|17|18|19|20|21|22|23|24|25|26|27|28|29)$ ]]; then
33 | prompt_list=(
34 | "a ${unique_token} ${class_token} in the jungle"
35 | "a ${unique_token} ${class_token} in the snow"
36 | "a ${unique_token} ${class_token} on the beach"
37 | "a ${unique_token} ${class_token} on a cobblestone street"
38 | "a ${unique_token} ${class_token} on top of pink fabric"
39 | "a ${unique_token} ${class_token} on top of a wooden floor"
40 | "a ${unique_token} ${class_token} with a city in the background"
41 | "a ${unique_token} ${class_token} with a mountain in the background"
42 | "a ${unique_token} ${class_token} with a blue house in the background"
43 | "a ${unique_token} ${class_token} on top of a purple rug in a forest"
44 | "a ${unique_token} ${class_token} with a wheat field in the background"
45 | "a ${unique_token} ${class_token} with a tree and autumn leaves in the background"
46 | "a ${unique_token} ${class_token} with the Eiffel Tower in the background"
47 | "a ${unique_token} ${class_token} floating on top of water"
48 | "a ${unique_token} ${class_token} floating in an ocean of milk"
49 | "a ${unique_token} ${class_token} on top of green grass with sunflowers around it"
50 | "a ${unique_token} ${class_token} on top of a mirror"
51 | "a ${unique_token} ${class_token} on top of the sidewalk in a crowded street"
52 | "a ${unique_token} ${class_token} on top of a dirt road"
53 | "a ${unique_token} ${class_token} on top of a white rug"
54 | "a red ${unique_token} ${class_token}"
55 | "a purple ${unique_token} ${class_token}"
56 | "a shiny ${unique_token} ${class_token}"
57 | "a wet ${unique_token} ${class_token}"
58 | "a cube shaped ${unique_token} ${class_token}"
59 | )
60 |
61 | prompt_test_list=(
62 | "a ${class_token} in the jungle"
63 | "a ${class_token} in the snow"
64 | "a ${class_token} on the beach"
65 | "a ${class_token} on a cobblestone street"
66 | "a ${class_token} on top of pink fabric"
67 | "a ${class_token} on top of a wooden floor"
68 | "a ${class_token} with a city in the background"
69 | "a ${class_token} with a mountain in the background"
70 | "a ${class_token} with a blue house in the background"
71 | "a ${class_token} on top of a purple rug in a forest"
72 | "a ${class_token} with a wheat field in the background"
73 | "a ${class_token} with a tree and autumn leaves in the background"
74 | "a ${class_token} with the Eiffel Tower in the background"
75 | "a ${class_token} floating on top of water"
76 | "a ${class_token} floating in an ocean of milk"
77 | "a ${class_token} on top of green grass with sunflowers around it"
78 | "a ${class_token} on top of a mirror"
79 | "a ${class_token} on top of the sidewalk in a crowded street"
80 | "a ${class_token} on top of a dirt road"
81 | "a ${class_token} on top of a white rug"
82 | "a red ${class_token}"
83 | "a purple ${class_token}"
84 | "a shiny ${class_token}"
85 | "a wet ${class_token}"
86 | "a cube shaped ${class_token}"
87 | )
88 |
89 | else
90 | prompt_list=(
91 | "a ${unique_token} ${class_token} in the jungle"
92 | "a ${unique_token} ${class_token} in the snow"
93 | "a ${unique_token} ${class_token} on the beach"
94 | "a ${unique_token} ${class_token} on a cobblestone street"
95 | "a ${unique_token} ${class_token} on top of pink fabric"
96 | "a ${unique_token} ${class_token} on top of a wooden floor"
97 | "a ${unique_token} ${class_token} with a city in the background"
98 | "a ${unique_token} ${class_token} with a mountain in the background"
99 | "a ${unique_token} ${class_token} with a blue house in the background"
100 | "a ${unique_token} ${class_token} on top of a purple rug in a forest"
101 | "a ${unique_token} ${class_token} wearing a red hat"
102 | "a ${unique_token} ${class_token} wearing a santa hat"
103 | "a ${unique_token} ${class_token} wearing a rainbow scarf"
104 | "a ${unique_token} ${class_token} wearing a black top hat and a monocle"
105 | "a ${unique_token} ${class_token} in a chef outfit"
106 | "a ${unique_token} ${class_token} in a firefighter outfit"
107 | "a ${unique_token} ${class_token} in a police outfit"
108 | "a ${unique_token} ${class_token} wearing pink glasses"
109 | "a ${unique_token} ${class_token} wearing a yellow shirt"
110 | "a ${unique_token} ${class_token} in a purple wizard outfit"
111 | "a red ${unique_token} ${class_token}"
112 | "a purple ${unique_token} ${class_token}"
113 | "a shiny ${unique_token} ${class_token}"
114 | "a wet ${unique_token} ${class_token}"
115 | "a cube shaped ${unique_token} ${class_token}"
116 | )
117 |
118 | prompt_test_list=(
119 | "a ${class_token} in the jungle"
120 | "a ${class_token} in the snow"
121 | "a ${class_token} on the beach"
122 | "a ${class_token} on a cobblestone street"
123 | "a ${class_token} on top of pink fabric"
124 | "a ${class_token} on top of a wooden floor"
125 | "a ${class_token} with a city in the background"
126 | "a ${class_token} with a mountain in the background"
127 | "a ${class_token} with a blue house in the background"
128 | "a ${class_token} on top of a purple rug in a forest"
129 | "a ${class_token} wearing a red hat"
130 | "a ${class_token} wearing a santa hat"
131 | "a ${class_token} wearing a rainbow scarf"
132 | "a ${class_token} wearing a black top hat and a monocle"
133 | "a ${class_token} in a chef outfit"
134 | "a ${class_token} in a firefighter outfit"
135 | "a ${class_token} in a police outfit"
136 | "a ${class_token} wearing pink glasses"
137 | "a ${class_token} wearing a yellow shirt"
138 | "a ${class_token} in a purple wizard outfit"
139 | "a red ${class_token}"
140 | "a purple ${class_token}"
141 | "a shiny ${class_token}"
142 | "a wet ${class_token}"
143 | "a cube shaped ${class_token}"
144 | )
145 | fi
146 |
147 |
148 | validation_prompt=${prompt_list[$prompt_idx]}
149 | test_prompt=${prompt_test_list[$prompt_idx]}
150 | name="${selected_subject}-${prompt_idx}"
151 | instance_prompt="a photo of ${unique_token} ${class_token}"
152 | class_prompt="a photo of ${class_token}"
153 |
154 | export OUTPUT_DIR="log_quant/${name}"
155 | export INSTANCE_DIR="/data/guohang/dataset/dreambooth/${selected_subject}"
156 | export CLASS_DIR="/data/guohang/ControlNet/oft-db/data/class_data/${class_token}"
157 |
158 |
159 |
160 | python train_dreambooth_quant.py \
161 | --pretrained_model_name_or_path=$MODEL_NAME \
162 | --instance_data_dir=$INSTANCE_DIR \
163 | --rank=4 \
164 | --intlora=MUL \
165 | --nbits=8 \
166 | --use_activation_quant \
167 | --act_nbits=8 \
168 | --class_data_dir="$CLASS_DIR" \
169 | --output_dir=$OUTPUT_DIR \
170 | --instance_prompt="$instance_prompt" \
171 | --with_prior_preservation \
172 | --prior_loss_weight=1.0 \
173 | --class_prompt="$class_prompt" \
174 | --resolution=512 \
175 | --train_batch_size=1 \
176 | --gradient_accumulation_steps=1 \
177 | --checkpointing_steps=5000 \
178 | --learning_rate=6e-5 \
179 | --report_to="wandb" \
180 | --lr_scheduler="constant" \
181 | --lr_warmup_steps=0 \
182 | --max_train_steps=2000 \
183 | --validation_prompt="$validation_prompt" \
184 | --num_validation_images=5 \
185 | --validation_epochs=1 \
186 | --seed="0" \
187 | --name="$name" \
188 | --num_class_images=200 \
189 | #--gradient_checkpointing  # uncomment to enable gradient checkpointing and further reduce GPU memory
190 | done
191 |
192 |
193 |
--------------------------------------------------------------------------------
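The loop above launches 750 fine-tuning jobs: 30 DreamBooth subjects times 25 prompts, with `class_idx = idx / 25` selecting the subject and `prompt_idx = idx % 25` selecting the prompt. A small Python illustration of the same index mapping (purely illustrative, not part of the pipeline):

```
# Illustrative only: the idx -> (subject, prompt) mapping used by the bash sweep above.
n_prompts = 25  # prompts per subject
for idx in (0, 24, 25, 749):
    class_idx, prompt_idx = divmod(idx, n_prompts)
    print(f"idx={idx:3d} -> subject #{class_idx}, prompt #{prompt_idx}")
# idx=  0 -> subject #0, prompt #0
# idx= 24 -> subject #0, prompt #24
# idx= 25 -> subject #1, prompt #0
# idx=749 -> subject #29, prompt #24
```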
/utils/intlora_mul.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 | import math
5 | from utils.quant_utils import batch_mse,batch_max,lp_loss,round_ste
6 | from utils.quant_layer import UniformAffineQuantizer
7 |
8 |
9 |
10 | class IntLoRA_MUL(nn.Module):
11 | # The implementation of our IntLoRA-MUL
12 | def __init__(self, org_module, n_bits=8, lora_bits=8, symmetric=False,channel_wise=True, rank=4, activation_params=None):
13 | super(IntLoRA_MUL, self).__init__()
14 |
15 | self.n_bits = n_bits
16 | self.lora_bits = lora_bits
17 | self.sym = symmetric
18 | self.scale_method = 'mse'
19 | self.always_zero = False
20 | self.n_levels = 2 ** self.n_bits if not self.sym else 2 ** (self.n_bits - 1) - 1
21 | self.channel_wise = channel_wise
22 | self.inited=False
23 | self.lora_levels = self.n_levels
24 |
25 | self.fwd_kwargs = dict()
26 | self.fwd_func = F.linear
27 | self.in_features = org_module.in_features
28 | self.out_features = org_module.out_features
29 |
30 |
31 | # save original weights and bias and keep them intact
32 | self.ori_weight_shape = org_module.weight.shape
33 |
34 | self.ori_weight = org_module.weight.view(self.out_features,-1).data.clone() # reshape here
35 | self.ori_bias = None if org_module.bias is None else org_module.bias.data.clone()
36 |
37 | self.act_quant_params = activation_params
38 | if self.act_quant_params is not None:
39 | self.act_quantizer = UniformAffineQuantizer(**self.act_quant_params)
40 |
41 |
42 |         # double quantization of the low-rank update ===========================
43 | self.quant_lora_weights = True
44 | self.double_inited = False
45 | self.double_quant_delta = torch.nn.Parameter(torch.zeros(self.out_features,1))
46 | self.double_quant_zero_point = torch.nn.Parameter(torch.zeros(self.out_features,1))
47 |
48 | r = rank
49 | self.alpha = 1.5
50 | self.loraA = nn.Linear(org_module.in_features, r, bias=False)
51 | self.loraB = nn.Linear(r, org_module.out_features, bias=False)
52 | nn.init.kaiming_uniform_(self.loraA.weight, a=math.sqrt(5))
53 | nn.init.zeros_(self.loraB.weight)
54 |
55 | # init the auxiliary matrix R in IntLoRA
56 | m = torch.distributions.laplace.Laplace(loc=torch.tensor([0.]),scale=torch.tensor([0.5]))
57 | aux_R = m.sample((org_module.out_features,org_module.in_features))[:,:,0]
58 | self.register_buffer('aux_R', aux_R)
59 | self.aux_R = self.aux_R.to(self.ori_weight.device).detach()
60 |
61 |
62 | def forward(self, input: torch.Tensor):
63 | if self.inited is False:
64 | aux_R_abs_max = torch.minimum(self.aux_R.max(dim=-1,keepdim=True)[0].abs(),self.aux_R.min(dim=-1,keepdim=True)[0].abs()).detach()
65 | ori_weight_abs_max = torch.maximum(self.ori_weight.max(dim=-1,keepdim=True)[0].abs(),self.ori_weight.min(dim=-1,keepdim=True)[0].abs()).detach()
66 | self.aux_R = ((ori_weight_abs_max)**self.alpha/(aux_R_abs_max+1e-8)**self.alpha)*self.aux_R
67 |
68 | ori_weight = self.ori_weight - self.aux_R
69 | delta, zero_point = self.init_quantization_scale(ori_weight, self.channel_wise,self.n_bits,self.sym)
70 | self.register_buffer('weight_quant_delta', delta)
71 | self.register_buffer('weight_quant_zero_point', zero_point)
72 | ori_weight_round = round_ste(ori_weight / self.weight_quant_delta) + self.weight_quant_zero_point
73 | if self.sym:
74 | ori_weight_round = torch.clamp(ori_weight_round, -self.n_levels - 1, self.n_levels)
75 | else:
76 | ori_weight_round = torch.clamp(ori_weight_round, 0, self.n_levels - 1)
77 |
78 | # delete the FP weights and save the int weights
79 | self.register_buffer('ori_weight_round', ori_weight_round) # int weight and keep it intact
80 | self.ori_weight = None
81 | torch.cuda.empty_cache()
82 | self.inited = True
83 |
84 | ori_weight_int = self.ori_weight_round - self.weight_quant_zero_point
85 |
86 | lora_weight = (self.aux_R + (self.loraB.weight @ self.loraA.weight)) / \
87 | torch.where(ori_weight_int == 0, torch.tensor(1).to(ori_weight_int.device), ori_weight_int)
88 |         weight_updates = self.weight_quant_delta + lora_weight  # per-channel delta broadcasts over the input dim
89 |
90 |
91 | if self.quant_lora_weights:
92 | if self.double_inited is False:
93 | delta, zero_point = self.init_quantization_scale(weight_updates, True, self.lora_bits)
94 | with torch.no_grad():
95 | self.double_quant_delta.copy_(delta)
96 | self.double_quant_zero_point.copy_(zero_point)
97 | self.double_inited = True
98 |
99 | weight_updates_round = round_ste(weight_updates / self.double_quant_delta) + self.double_quant_zero_point
100 | if self.sym:
101 | weight_updates_round = torch.clamp(weight_updates_round, -self.lora_levels - 1, self.lora_levels)
102 | else:
103 | weight_updates_round = torch.clamp(weight_updates_round, 0, self.lora_levels - 1)
104 |
105 | weight_updates_int = (weight_updates_round - self.double_quant_zero_point)
106 | weight_int_mul = weight_updates_int * ori_weight_int # simulated INT multiply
107 | weight = self.double_quant_delta * weight_int_mul
108 |
109 | else:
110 | weight = weight_updates*ori_weight_int
111 |
112 | bias = self.ori_bias
113 |
114 | # do activation quantization
115 | if self.act_quant_params is not None:
116 | input = self.act_quantizer(input)
117 |
118 | out = self.fwd_func(input, weight, bias, **self.fwd_kwargs)
119 |
120 | return out
121 |
122 |
123 |
124 | def init_quantization_scale(self, x: torch.Tensor, channel_wise: bool = False, n_bits: int = 8, sym: bool= False):
125 | n_levels = 2 ** n_bits if not sym else 2 ** (n_bits - 1) - 1
126 | delta, zero_point = None, None
127 | if channel_wise:
128 | x_clone = x.clone().detach()
129 | n_channels = x_clone.shape[0]
130 | if len(x.shape) == 4:
131 | x_max = x_clone.abs().max(dim=-1)[0].max(dim=-1)[0].max(dim=-1)[0]
132 | elif len(x.shape) == 3:
133 | x_max = x_clone.abs().max(dim=-1)[0].max(dim=-1)[0]
134 | else:
135 | x_max = x_clone.abs().max(dim=-1)[0]
136 | delta = x_max.clone()
137 | zero_point = x_max.clone()
138 | # determine the scale and zero point channel-by-channel
139 | if 'max' in self.scale_method:
140 | delta, zero_point = batch_max(x_clone.view(n_channels, -1), sym, 2 ** n_bits,
141 | self.always_zero)
142 |
143 | elif 'mse' in self.scale_method:
144 | delta, zero_point = batch_mse(x_clone.view(n_channels, -1), sym, 2 ** n_bits,
145 | self.always_zero)
146 |
147 | if len(x.shape) == 4:
148 | delta = delta.view(-1, 1, 1, 1)
149 | zero_point = zero_point.view(-1, 1, 1, 1)
150 | elif len(x.shape) == 3:
151 | delta = delta.view(-1, 1, 1)
152 | zero_point = zero_point.view(-1, 1, 1)
153 | else:
154 | delta = delta.view(-1, 1)
155 | zero_point = zero_point.view(-1, 1)
156 | else:
157 | if 'max' in self.scale_method:
158 | x_min = min(x.min().item(), 0)
159 | x_max = max(x.max().item(), 0)
160 | if 'scale' in self.scale_method:
161 | x_min = x_min * (n_bits + 2) / 8
162 | x_max = x_max * (n_bits + 2) / 8
163 |
164 | x_absmax = max(abs(x_min), x_max)
165 | if sym:
166 | # x_min, x_max = -x_absmax if x_min < 0 else 0, x_absmax
167 | delta = x_absmax / n_levels
168 | else:
169 | delta = float(x.max().item() - x.min().item()) / (n_levels - 1)
170 | if delta < 1e-8:
171 | delta = 1e-8
172 |
173 | zero_point = round(-x_min / delta) if not (sym or self.always_zero) else 0
174 | delta = torch.tensor(delta).type_as(x)
175 |
176 | elif self.scale_method == 'mse':
177 | x_max = x.max()
178 | x_min = x.min()
179 | best_score = 1e+10
180 | for i in range(80):
181 | new_max = x_max * (1.0 - (i * 0.01))
182 | new_min = x_min * (1.0 - (i * 0.01))
183 | x_q = self.quantize(x, new_max, new_min,n_bits,sym)
184 | score = lp_loss(x, x_q, p=2.4, reduction='all')
185 | if score < best_score:
186 | best_score = score
187 | delta = (new_max - new_min) / (2 ** n_bits - 1) \
188 | if not self.always_zero else new_max / (2 ** n_bits - 1)
189 | zero_point = (- new_min / delta).round() if not self.always_zero else 0
190 | else:
191 | raise NotImplementedError
192 |
193 | return delta, zero_point
194 |
195 | def quantize(self, x, max, min,n_bits,sym):
196 | n_levels = 2 ** n_bits if not sym else 2 ** (n_bits - 1) - 1
197 | delta = (max - min) / (2 ** n_bits - 1) if not self.always_zero else max / (2 ** n_bits - 1)
198 | zero_point = (- min / delta).round() if not self.always_zero else 0
199 | x_int = torch.round(x / delta)
200 | x_quant = torch.clamp(x_int + zero_point, 0,n_levels - 1)
201 | x_float_q = (x_quant - zero_point) * delta
202 | return x_float_q
203 |
--------------------------------------------------------------------------------
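The core of `IntLoRA_MUL.forward` is the merge `weight = double_quant_delta * (weight_updates_int * ori_weight_int)`: the quantized low-rank term and the quantized backbone are combined by an element-wise integer multiplication followed by a single per-channel rescale, so the adapted weights need no extra post-training quantization. A toy sketch of that arithmetic with made-up shapes and values (illustrative only):

```
# Toy illustration of the integer-multiplication merge in IntLoRA-MUL (made-up shapes/values).
import torch

torch.manual_seed(0)
out_features, in_features = 4, 8
ori_weight_int = torch.randint(-8, 8, (out_features, in_features)).float()      # quantized backbone
weight_updates_int = torch.randint(-8, 8, (out_features, in_features)).float()  # quantized LoRA-side term
double_quant_delta = torch.rand(out_features, 1) * 0.01                         # per-channel scale

# element-wise integer multiply, then one floating-point rescale
merged_weight = double_quant_delta * (weight_updates_int * ori_weight_int)
print(merged_weight.shape)  # torch.Size([4, 8])
```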
/utils/intlora_shift.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 | import math
5 | from utils.quant_utils import batch_mse, batch_max, lp_loss, round_ste
6 | import numpy as np
7 | from utils.quant_layer import UniformAffineQuantizer
8 |
9 |
10 | def round(x, rounding='deterministic'):
11 | assert (rounding in ['deterministic', 'stochastic'])
12 | if rounding == 'stochastic':
13 | x_floor = x.floor()
14 | return x_floor + torch.bernoulli(x - x_floor)
15 | else:
16 | return x.round()
17 |
18 |
19 | def get_shift_and_sign(x, rounding='deterministic'):
20 | sign = torch.sign(x)
21 |
22 | x_abs = torch.abs(x)
23 | if rounding == "floor":
24 | shift = torch.floor(torch.log(x_abs) / np.log(2))
25 | else:
26 | shift = round_ste(torch.log(x_abs) / np.log(2))
27 |
28 | return shift, sign
29 |
30 |
31 | def round_power_of_2(x, rounding='deterministic', q_bias=None, scale=None):
32 | if q_bias is not None:
33 | q_bias = q_bias.unsqueeze(1).expand_as(x)
34 | x = x - q_bias
35 | if scale is not None:
36 | scale = scale.unsqueeze(1).expand_as(x)
37 | x = x / scale
38 | shift, sign = get_shift_and_sign(x, rounding)
39 | x_rounded = (2.0 ** shift) * sign
40 | if scale is not None:
41 | x_rounded = x_rounded * scale
42 | if q_bias is not None:
43 | x_rounded = x_rounded + q_bias
44 | return x_rounded
45 |
46 |
47 | def additive_power_of_2(x, log_s):
48 | sign = torch.sign(x)
49 | x_abs = torch.abs(x)
50 |
51 | shift = round_ste(torch.log(x_abs) / np.log(2) + log_s)
52 |
53 | x_rounded = (2.0 ** shift) * sign
54 |
55 | return x_rounded
56 |
57 |
58 | class StraightThrough(nn.Module):
59 | def __init__(self, channel_num: int = 1):
60 | super().__init__()
61 |
62 | def forward(self, input):
63 | return input
64 |
65 |
66 |
67 |
68 | class IntLoRA_SHIFT(nn.Module):
69 | # The implementation of our IntLoRA-SHIFT
70 | def __init__(self, org_module, n_bits=8, lora_bits=8, symmetric=False, channel_wise=True, rank=4, activation_params=None):
71 | super(IntLoRA_SHIFT, self).__init__()
72 |
73 | self.n_bits = n_bits
74 | self.lora_bits = lora_bits
75 | self.sym = symmetric
76 | self.scale_method = 'mse'
77 | self.always_zero = False
78 | self.n_levels = 2 ** self.n_bits if not self.sym else 2 ** (self.n_bits - 1) - 1
79 | self.channel_wise = channel_wise
80 | self.inited = False
81 | self.lora_levels = self.n_levels
82 |
83 | self.fwd_kwargs = dict()
84 | self.fwd_func = F.linear
85 | self.in_features = org_module.in_features
86 | self.out_features = org_module.out_features
87 |
88 | # save original weights and bias and keep them intact
89 | self.ori_weight_shape = org_module.weight.shape
90 |
91 | self.ori_weight = org_module.weight.view(self.out_features, -1).data.clone() # reshape here
92 | self.ori_bias = None if org_module.bias is None else org_module.bias.data.clone()
93 |
94 | self.act_quant_params = activation_params
95 | if self.act_quant_params is not None:
96 | self.act_quantizer = UniformAffineQuantizer(**self.act_quant_params)
97 |
98 |         # power-of-two quantization of the low-rank update ===========================
99 |         self.quant_lora_weights = True
100 |         r = rank
101 | self.alpha = 1.5
102 |
103 | self.loraA = nn.Linear(org_module.in_features, r, bias=False)
104 | self.loraB = nn.Linear(r, org_module.out_features, bias=False)
105 | nn.init.kaiming_uniform_(self.loraA.weight, a=math.sqrt(5))
106 | nn.init.zeros_(self.loraB.weight)
107 |
108 | # init the auxiliary matrix R in IntLoRA
109 | m = torch.distributions.laplace.Laplace(loc=torch.tensor([0.]),scale=torch.tensor([0.5]))
110 | aux_R = m.sample((org_module.out_features,org_module.in_features))[:,:,0]
111 | self.register_buffer('aux_R', aux_R)
112 | self.aux_R = self.aux_R.to(self.ori_weight.device).detach()
113 |
114 |
115 | def forward(self, input: torch.Tensor):
116 | if self.inited is False:
117 | aux_R_abs_max = torch.minimum(self.aux_R.max(dim=-1, keepdim=True)[0].abs(),
118 | self.aux_R.min(dim=-1, keepdim=True)[0].abs()).detach()
119 | ori_weight_abs_max = torch.maximum(self.ori_weight.max(dim=-1, keepdim=True)[0].abs(),
120 | self.ori_weight.min(dim=-1, keepdim=True)[0].abs()).detach()
121 | self.aux_R = ((ori_weight_abs_max) ** self.alpha / (aux_R_abs_max + 1e-8) ** self.alpha) * self.aux_R
122 |
123 | ori_weight = self.ori_weight - self.aux_R
124 | delta, zero_point = self.init_quantization_scale(ori_weight, self.channel_wise, self.n_bits, self.sym)
125 | self.register_buffer('weight_quant_delta', delta)
126 | self.register_buffer('weight_quant_zero_point', zero_point)
127 | ori_weight_round = round_ste(ori_weight / self.weight_quant_delta) + self.weight_quant_zero_point
128 | if self.sym:
129 | ori_weight_round = torch.clamp(ori_weight_round, -self.n_levels - 1, self.n_levels)
130 | else:
131 | ori_weight_round = torch.clamp(ori_weight_round, 0, self.n_levels - 1)
132 |
133 | # delete the FP weights and save the int weights
134 | self.register_buffer('ori_weight_round', ori_weight_round) # int weight and keep it intact
135 | self.ori_weight = None
136 | torch.cuda.empty_cache()
137 | self.inited = True
138 |
139 | ori_weight_int = self.ori_weight_round - self.weight_quant_zero_point
140 |
141 |         # parameter-efficient adaptation of the quantization scale ==================================
142 | lora_weight = (self.aux_R + (self.loraB.weight @ self.loraA.weight)) / \
143 | torch.where(ori_weight_int == 0, torch.tensor(1).to(ori_weight_int.device), ori_weight_int)
144 |         weight_updates = self.weight_quant_delta + lora_weight  # per-channel delta broadcasts over the input dim
145 |
146 |
147 | if self.quant_lora_weights:
148 | weight_updates_sign = weight_updates.sign()
149 | weight_updates_abs = torch.abs(weight_updates)
150 | lora_shift = round_ste(torch.log2(weight_updates_abs + 1e-16))
151 | lora_rounded = 2.0 ** lora_shift
152 | weight = weight_updates_sign * lora_rounded * ori_weight_int
153 |             if torch.any(torch.isnan(weight_updates)):
154 |                 # abort early: NaNs here would silently corrupt the power-of-two rounding
155 |                 raise RuntimeError('NaN encountered in weight_updates during log2 quantization')
156 |
157 | else:
158 | weight = weight_updates * ori_weight_int
159 |
160 | # do activation quantization
161 | if self.act_quant_params is not None:
162 | input = self.act_quantizer(input)
163 |
164 | bias = self.ori_bias
165 | out = self.fwd_func(input, weight, bias, **self.fwd_kwargs)
166 |
167 | return out
168 |
169 | def init_quantization_scale(self, x: torch.Tensor, channel_wise: bool = False, n_bits: int = 8, sym: bool = False):
170 | n_levels = 2 ** n_bits if not sym else 2 ** (n_bits - 1) - 1
171 | delta, zero_point = None, None
172 | if channel_wise:
173 | x_clone = x.clone().detach()
174 | n_channels = x_clone.shape[0]
175 | if len(x.shape) == 4:
176 | x_max = x_clone.abs().max(dim=-1)[0].max(dim=-1)[0].max(dim=-1)[0]
177 | elif len(x.shape) == 3:
178 | x_max = x_clone.abs().max(dim=-1)[0].max(dim=-1)[0]
179 | else:
180 | x_max = x_clone.abs().max(dim=-1)[0]
181 | delta = x_max.clone()
182 | zero_point = x_max.clone()
183 | # determine the scale and zero point channel-by-channel
184 | if 'max' in self.scale_method:
185 | delta, zero_point = batch_max(x_clone.view(n_channels, -1), sym, 2 ** n_bits,
186 | self.always_zero)
187 |
188 | elif 'mse' in self.scale_method:
189 | delta, zero_point = batch_mse(x_clone.view(n_channels, -1), sym, 2 ** n_bits,
190 | self.always_zero)
191 |
192 | if len(x.shape) == 4:
193 | delta = delta.view(-1, 1, 1, 1)
194 | zero_point = zero_point.view(-1, 1, 1, 1)
195 | elif len(x.shape) == 3:
196 | delta = delta.view(-1, 1, 1)
197 | zero_point = zero_point.view(-1, 1, 1)
198 | else:
199 | delta = delta.view(-1, 1)
200 | zero_point = zero_point.view(-1, 1)
201 | else:
202 | # if self.leaf_param:
203 | # self.x_min = x.data.min()
204 | # self.x_max = x.data.max()
205 |
206 | if 'max' in self.scale_method:
207 | x_min = min(x.min().item(), 0)
208 | x_max = max(x.max().item(), 0)
209 | if 'scale' in self.scale_method:
210 | x_min = x_min * (n_bits + 2) / 8
211 | x_max = x_max * (n_bits + 2) / 8
212 |
213 | x_absmax = max(abs(x_min), x_max)
214 | if sym:
215 | # x_min, x_max = -x_absmax if x_min < 0 else 0, x_absmax
216 | delta = x_absmax / n_levels
217 | else:
218 | delta = float(x.max().item() - x.min().item()) / (n_levels - 1)
219 | if delta < 1e-8:
220 | delta = 1e-8
221 |
222 | zero_point = round(-x_min / delta) if not (sym or self.always_zero) else 0
223 | delta = torch.tensor(delta).type_as(x)
224 |
225 | elif self.scale_method == 'mse':
226 | x_max = x.max()
227 | x_min = x.min()
228 | best_score = 1e+10
229 | for i in range(80):
230 | new_max = x_max * (1.0 - (i * 0.01))
231 | new_min = x_min * (1.0 - (i * 0.01))
232 | x_q = self.quantize(x, new_max, new_min, n_bits, sym)
233 | score = lp_loss(x, x_q, p=2.4, reduction='all')
234 | if score < best_score:
235 | best_score = score
236 | delta = (new_max - new_min) / (2 ** n_bits - 1) \
237 | if not self.always_zero else new_max / (2 ** n_bits - 1)
238 | zero_point = (- new_min / delta).round() if not self.always_zero else 0
239 | else:
240 | raise NotImplementedError
241 |
242 | return delta, zero_point
243 |
244 | def quantize(self, x, max, min, n_bits, sym):
245 | n_levels = 2 ** n_bits if not sym else 2 ** (n_bits - 1) - 1
246 | delta = (max - min) / (2 ** n_bits - 1) if not self.always_zero else max / (2 ** n_bits - 1)
247 | zero_point = (- min / delta).round() if not self.always_zero else 0
248 | # we assume weight quantization is always signed
249 | x_int = torch.round(x / delta)
250 | x_quant = torch.clamp(x_int + zero_point, 0, n_levels - 1)
251 | x_float_q = (x_quant - zero_point) * delta
252 | return x_float_q
253 |
--------------------------------------------------------------------------------
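`IntLoRA_SHIFT` differs from the MUL variant only in how `weight_updates` is quantized: its magnitude is rounded to the nearest power of two (`2 ** round(log2|x|)`), so merging with the integer backbone weight reduces to a sign flip and a bit-shift. A toy sketch of that rounding (illustrative values only):

```
# Toy illustration of the power-of-two (bit-shift) rounding used by IntLoRA-SHIFT.
import torch

weight_updates = torch.tensor([[0.30, -0.07], [0.011, 0.50]])
sign = weight_updates.sign()
shift = torch.round(torch.log2(weight_updates.abs() + 1e-16))  # integer exponents
pot_magnitude = 2.0 ** shift                                   # power-of-two magnitudes

ori_weight_int = torch.tensor([[3.0, -5.0], [7.0, 2.0]])       # quantized backbone (integer grid)
# multiplying an integer by 2**k is a bit-shift by |k|; simulated here in floating point
merged_weight = sign * pot_magnitude * ori_weight_int
print(shift)  # tensor([[-2., -4.], [-7., -1.]])
```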
/utils/quant_layer.py:
--------------------------------------------------------------------------------
1 | import logging
2 | import warnings
3 | import torch
4 | import torch.nn as nn
5 | import torch.nn.functional as F
6 | from typing import Union
7 | from utils.quant_utils import batch_mse,batch_max
8 | logger = logging.getLogger(__name__)
9 |
10 |
11 | class StraightThrough(nn.Module):
12 | def __init__(self, channel_num: int = 1):
13 | super().__init__()
14 |
15 | def forward(self, input):
16 | return input
17 |
18 |
19 | def round_ste(x: torch.Tensor):
20 | """
21 | Implement Straight-Through Estimator for rounding operation.
22 | """
23 | return (x.round() - x).detach() + x
24 |
25 |
26 | def lp_loss(pred, tgt, p=2.0, reduction='none'):
27 | """
28 | loss function measured in L_p Norm
29 | """
30 | if reduction == 'none':
31 | return (pred-tgt).abs().pow(p).sum(1).mean()
32 | else:
33 | return (pred-tgt).abs().pow(p).mean()
34 |
35 |
36 |
37 | class UniformAffineQuantizer(nn.Module):
38 | """
39 | PyTorch Function that can be used for asymmetric quantization (also called uniform affine
40 | quantization). Quantizes its argument in the forward pass, passes the gradient 'straight
41 | through' on the backward pass, ignoring the quantization that occurred.
42 | Based on https://arxiv.org/abs/1806.08342.
43 |
44 | :param n_bits: number of bit for quantization
45 | :param symmetric: if True, the zero_point should always be 0
46 | :param channel_wise: if True, compute scale and zero_point in each channel
47 | :param scale_method: determines the quantization scale and zero point
48 | """
49 |
50 | def __init__(self, n_bits: int = 8, symmetric: bool = False, channel_wise: bool = False, scale_method: str = 'max',
51 | leaf_param: bool = False, always_zero: bool = False, **kwargs):
52 | super(UniformAffineQuantizer, self).__init__()
53 | self.sym = symmetric
54 | # assert 2 <= n_bits <= 8, 'bitwidth not supported'
55 | self.n_bits = n_bits
56 | self.n_levels = 2 ** self.n_bits if not self.sym else 2 ** (self.n_bits - 1) - 1
57 | self.delta = None
58 | self.zero_point = None
59 | self.inited = False
60 | self.leaf_param = leaf_param
61 | self.channel_wise = channel_wise
62 | self.scale_method = scale_method
63 | self.running_stat = False
64 | self.always_zero = always_zero
65 | if self.leaf_param:
66 | self.x_min, self.x_max = None, None
67 |
68 | def forward(self, x: torch.Tensor):
69 | if self.inited is False:
70 | if self.leaf_param:
71 | delta, zero_point = self.init_quantization_scale(x.detach(), self.channel_wise)
72 | self.delta = torch.nn.Parameter(delta)
73 | self.zero_point = torch.nn.Parameter(zero_point)
74 |
75 | else:
76 | self.delta, self.zero_point = self.init_quantization_scale(x, self.channel_wise)
77 | self.inited = True
78 |
79 | # start quantization
80 | # print(f"x shape {x.shape} delta shape {self.delta.shape} zero shape {self.zero_point.shape}")
81 | x_int = round_ste(x / self.delta) + self.zero_point
82 | # x_quant = torch.clamp(x_int, 0, self.n_levels - 1)
83 | if self.sym:
84 | x_quant = torch.clamp(x_int, -self.n_levels - 1, self.n_levels)
85 | else:
86 | x_quant = torch.clamp(x_int, 0, self.n_levels - 1)
87 | x_dequant = (x_quant - self.zero_point) * self.delta
88 | return x_dequant
89 |
90 | def act_momentum_update(self, x: torch.Tensor, act_range_momentum: float = 0.95):
91 | assert (self.inited)
92 | assert (self.leaf_param)
93 |
94 | x_min = x.data.min()
95 | x_max = x.data.max()
96 | self.x_min = self.x_min * act_range_momentum + x_min * (1 - act_range_momentum)
97 | self.x_max = self.x_max * act_range_momentum + x_max * (1 - act_range_momentum)
98 |
99 | if self.sym:
100 | # x_min, x_max = -x_absmax if x_min < 0 else 0, x_absmax
101 | delta = torch.max(self.x_min.abs(), self.x_max.abs()) / self.n_levels
102 | else:
103 | delta = (self.x_max - self.x_min) / (self.n_levels - 1) if not self.always_zero \
104 | else self.x_max / (self.n_levels - 1)
105 |
106 | delta = torch.clamp(delta, min=1e-8)
107 | if not self.sym:
108 | self.zero_point = (-self.x_min / delta).round() if not (self.sym or self.always_zero) else 0
109 | self.delta = torch.nn.Parameter(delta)
110 |
111 | def init_quantization_scale(self, x: torch.Tensor, channel_wise: bool = False):
112 | delta, zero_point = None, None
113 | if channel_wise:
114 | x_clone = x.clone().detach()
115 | n_channels = x_clone.shape[0]
116 | if len(x.shape) == 4:
117 | x_max = x_clone.abs().max(dim=-1)[0].max(dim=-1)[0].max(dim=-1)[0]
118 | elif len(x.shape) == 3:
119 | x_max = x_clone.abs().max(dim=-1)[0].max(dim=-1)[0]
120 | else:
121 | x_max = x_clone.abs().max(dim=-1)[0]
122 | delta = x_max.clone()
123 | zero_point = x_max.clone()
124 | # determine the scale and zero point channel-by-channel
125 | if 'max' in self.scale_method:
126 | delta, zero_point = batch_max(x_clone.view(n_channels, -1), self.sym, 2 ** self.n_bits,
127 | self.always_zero)
128 |
129 | elif 'mse' in self.scale_method:
130 | delta, zero_point = batch_mse(x_clone.view(n_channels, -1), self.sym, 2 ** self.n_bits,
131 | self.always_zero)
132 |
133 | if len(x.shape) == 4:
134 | delta = delta.view(-1, 1, 1, 1)
135 | zero_point = zero_point.view(-1, 1, 1, 1)
136 | elif len(x.shape) == 3:
137 | delta = delta.view(-1, 1, 1)
138 | zero_point = zero_point.view(-1, 1, 1)
139 | else:
140 | delta = delta.view(-1, 1)
141 | zero_point = zero_point.view(-1, 1)
142 | else:
143 | if self.leaf_param: # for momentum update
144 | self.x_min = x.data.min()
145 | self.x_max = x.data.max()
146 |
147 | if 'max' in self.scale_method:
148 | x_min = min(x.min().item(), 0)
149 | x_max = max(x.max().item(), 0)
150 | if 'scale' in self.scale_method:
151 | x_min = x_min * (self.n_bits + 2) / 8
152 | x_max = x_max * (self.n_bits + 2) / 8
153 |
154 | x_absmax = max(abs(x_min), x_max)
155 | if self.sym:
156 | # x_min, x_max = -x_absmax if x_min < 0 else 0, x_absmax
157 | delta = x_absmax / self.n_levels
158 | else:
159 | delta = float(x.max().item() - x.min().item()) / (self.n_levels - 1)
160 | if delta < 1e-8:
161 | warnings.warn('Quantization range close to zero: [{}, {}]'.format(x_min, x_max))
162 | delta = 1e-8
163 |
164 | zero_point = round(-x_min / delta) if not (self.sym or self.always_zero) else 0
165 | delta = torch.tensor(delta).type_as(x)
166 |
167 | elif self.scale_method == 'mse':
168 | x_max = x.max()
169 | x_min = x.min()
170 | best_score = 1e+10
171 | for i in range(80):
172 | new_max = x_max * (1.0 - (i * 0.01))
173 | new_min = x_min * (1.0 - (i * 0.01))
174 | x_q = self.quantize(x, new_max, new_min)
175 | # L_p norm minimization as described in LAPQ
176 | # https://arxiv.org/abs/1911.07190
177 | score = lp_loss(x, x_q, p=2.4, reduction='all')
178 | if score < best_score:
179 | best_score = score
180 | delta = (new_max - new_min) / (2 ** self.n_bits - 1) \
181 | if not self.always_zero else new_max / (2 ** self.n_bits - 1)
182 | zero_point = (- new_min / delta).round() if not self.always_zero else 0
183 | else:
184 | raise NotImplementedError
185 |
186 | return delta, zero_point
187 |
188 | def quantize(self, x, max, min):
189 | delta = (max - min) / (2 ** self.n_bits - 1) if not self.always_zero else max / (2 ** self.n_bits - 1)
190 | zero_point = (- min / delta).round() if not self.always_zero else 0
191 | # we assume weight quantization is always signed
192 | x_int = torch.round(x / delta)
193 | x_quant = torch.clamp(x_int + zero_point, 0, self.n_levels - 1)
194 | x_float_q = (x_quant - zero_point) * delta
195 | return x_float_q
196 |
197 | def bitwidth_refactor(self, refactored_bit: int):
198 | # assert 2 <= refactored_bit <= 8, 'bitwidth not supported'
199 | self.n_bits = refactored_bit
200 | self.n_levels = 2 ** self.n_bits
201 |
202 | def extra_repr(self):
203 | s = 'bit={n_bits}, scale_method={scale_method}, symmetric={sym}, channel_wise={channel_wise},' \
204 | ' leaf_param={leaf_param}'
205 | return s.format(**self.__dict__)
206 |
207 |
208 | class QuantLayerNormal(nn.Module):
209 | """
210 | Quantized Module that can perform quantized convolution or normal convolution.
211 | To activate quantization, please use set_quant_state function.
212 | """
213 | def __init__(self, org_module: Union[nn.Conv2d, nn.Linear, nn.Conv1d], weight_quant_params= None,
214 | activation_params= None,):
215 | super(QuantLayerNormal, self).__init__()
216 |
217 | self.weight_quant_params = weight_quant_params
218 | self.act_quant_params = activation_params
219 |
220 | if isinstance(org_module, nn.Conv2d):
221 | self.fwd_kwargs = dict(stride=org_module.stride, padding=org_module.padding,
222 | dilation=org_module.dilation, groups=org_module.groups)
223 | self.fwd_func = F.conv2d
224 | elif isinstance(org_module, nn.Conv1d):
225 | self.fwd_kwargs = dict(stride=org_module.stride, padding=org_module.padding,
226 | dilation=org_module.dilation, groups=org_module.groups)
227 | self.fwd_func = F.conv1d
228 | else:
229 | self.fwd_kwargs = dict()
230 | self.fwd_func = F.linear
231 | self.weight = org_module.weight
232 | self.org_weight = org_module.weight.data
233 | if org_module.bias is not None:
234 | self.bias = org_module.bias
235 | self.org_bias = org_module.bias.data
236 | else:
237 | self.bias = None
238 | self.org_bias = None
239 |         # quantization is enabled only when the corresponding quant params are provided
240 | self.use_weight_quant = self.weight_quant_params is not None
241 | self.use_act_quant = self.act_quant_params is not None
242 |
243 |
244 | # initialize quantizer
245 | self.weight_quantizer = UniformAffineQuantizer(**self.weight_quant_params)
246 | if self.use_act_quant:
247 | self.act_quantizer = UniformAffineQuantizer(**self.act_quant_params)
248 | self.activation_function = StraightThrough()
249 | self.extra_repr = org_module.extra_repr
250 |
251 | def forward(self, input: torch.Tensor):
252 | if self.use_act_quant: # Activation Quant
253 | input = self.act_quantizer(input)
254 |
255 | if self.use_weight_quant: # Weight Quant
256 | with torch.no_grad():
257 | weight = self.weight_quantizer(self.weight)
258 | bias = self.bias
259 | else:
260 | weight = self.org_weight
261 | bias = self.org_bias
262 | out = self.fwd_func(input, weight, bias, **self.fwd_kwargs)
263 | out = self.activation_function(out)
264 |
265 | return out
266 |
--------------------------------------------------------------------------------
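`UniformAffineQuantizer` implements standard asymmetric fake quantization: scale by `delta`, shift by `zero_point`, round (with a straight-through estimator during training), clamp, and map back. A self-contained sketch of that round trip using the per-tensor 'max' scale initialization (illustrative only; it re-implements the arithmetic instead of importing the class):

```
# Fake-quantization round trip, mirroring the per-tensor 'max' branch of UniformAffineQuantizer.
import torch

x = torch.tensor([-0.9, -0.1, 0.0, 0.4, 1.2])
n_bits = 8
n_levels = 2 ** n_bits

delta = (x.max() - x.min()) / (n_levels - 1)   # quantization step
zero_point = torch.round(-x.min() / delta)     # asymmetric offset

x_int = torch.round(x / delta) + zero_point
x_quant = torch.clamp(x_int, 0, n_levels - 1)
x_dequant = (x_quant - zero_point) * delta

print((x - x_dequant).abs().max() <= delta / 2)  # tensor(True): error bounded by half a step
```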
/utils/quant_model.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn.functional as F
3 | from torch import nn
4 | from utils.intlora_mul import IntLoRA_MUL
5 | from utils.intlora_shift import IntLoRA_SHIFT
6 | from utils.quant_layer import QuantLayerNormal
7 |
8 |
9 |
10 | class StraightThrough(nn.Module):
11 | def __init__(self, channel_num: int = 1):
12 | super().__init__()
13 |
14 | def forward(self, input):
15 | return input
16 |
17 |
18 |
19 | def check_in_special(ss, special_list):
20 | for itm in special_list:
21 | if itm in ss:
22 | return True
23 | return False
24 |
25 |
26 |
27 | class QuantUnetWarp(nn.Module):
28 | def __init__(self, model: nn.Module, args):
29 | super().__init__()
30 | self.model = model
31 | self.intlora_type = args.intlora
32 | lora_quant_params = {'n_bits': args.nbits, 'lora_bits':args.nbits, 'symmetric':False, 'channel_wise':True, 'rank':args.rank}
33 | other_quant_params = {'n_bits': args.nbits, 'symmetric': False, 'channel_wise': True, 'scale_method': 'mse'}
34 | activation_quant_params = {'n_bits': args.act_nbits, 'symmetric': False, 'channel_wise': False, 'scale_method': 'mse','leaf_param': True} if args.use_activation_quant else None
35 | special_list = ['to_q','to_k','to_v','to_out']
36 | assert self.intlora_type in ['MUL','SHIFT']
37 | self.IntLoRALayer = IntLoRA_MUL if self.intlora_type == 'MUL' else IntLoRA_SHIFT
38 |
39 | self.quant_module_refactor(self.model, lora_quant_params, other_quant_params, activation_quant_params, special_list)
40 |
41 |
42 |     def quant_module_refactor(self, module: nn.Module, lora_quant_params, other_quant_params, activation_quant_params, special_list, prev_name=''):
43 | for name, child_module in module.named_children():
44 | tmp_name = prev_name+'_'+name
45 | if isinstance(child_module, (nn.Conv2d, nn.Conv1d, nn.Linear)) and not ('downsample' in name and 'conv' in name):
46 |                 if check_in_special(tmp_name, special_list):
47 | setattr(module, name, self.IntLoRALayer(child_module, **lora_quant_params, activation_params=activation_quant_params))
48 | else:
49 |                     setattr(module, name, QuantLayerNormal(child_module, other_quant_params, activation_params=activation_quant_params))
50 | elif isinstance(child_module, StraightThrough):
51 | continue
52 | else:
53 |                 self.quant_module_refactor(child_module, lora_quant_params, other_quant_params, activation_quant_params, special_list, prev_name=tmp_name)
54 |
55 |
56 | def forward(self, image, t, context=None):
57 | return self.model(image, t, context)
58 |
59 |
60 |
61 |
--------------------------------------------------------------------------------
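
As a rough sketch of how `QuantUnetWarp` is meant to be wired up (the real entry point is `train_dreambooth_quant.py`), the following hypothetical snippet wraps a diffusers SD-1.5 UNet. Only the `args` attributes that `__init__` actually reads are filled in, the checkpoint path is a placeholder, and whether a forward pass runs out of the box depends on how the quantizers initialize their scales.

```
import argparse
import torch
from diffusers import UNet2DConditionModel
from utils.quant_model import QuantUnetWarp

# Only the attributes read in QuantUnetWarp.__init__ are needed here.
args = argparse.Namespace(intlora='MUL', nbits=8, rank=4,
                          use_activation_quant=False, act_nbits=8)

pretrained_model_path = './stable-diffusion-v1-5'   # placeholder: your local SD-1.5 checkpoint
unet = UNet2DConditionModel.from_pretrained(pretrained_model_path, subfolder='unet')

# Attention projections (to_q/to_k/to_v/to_out) become IntLoRA layers,
# the remaining Conv/Linear layers become QuantLayerNormal.
qunet = QuantUnetWarp(unet, args)

latents = torch.randn(1, 4, 64, 64)      # 512x512 image -> 64x64 latents for SD-1.5
timestep = torch.tensor([10])
context = torch.randn(1, 77, 768)        # CLIP text embeddings (SD-1.5 cross-attention dim)
with torch.no_grad():
    out = qunet(latents, timestep, context)   # a diffusers UNet returns an object with .sample
```
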
/utils/quant_utils.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn.functional as F
3 | from torch import nn
4 | from diffusers.utils import deprecate, logging
5 | from diffusers.utils.import_utils import is_xformers_available
6 | logger = logging.get_logger(__name__) # pylint: disable=invalid-name
7 |
8 | if is_xformers_available():
9 | import xformers
10 | import xformers.ops
11 | else:
12 | xformers = None
13 |
14 |
15 | def batch_mse(x: torch.Tensor,
16 | symmetric: bool = False,
17 | level: int = 256,
18 | always_zero: bool = False
19 | ) -> [torch.Tensor, torch.Tensor]:
20 | x_min, x_max = torch.min(x, dim=-1, keepdim=True)[0], torch.max(x, dim=-1, keepdim=True)[0] # [d_out]
21 | delta, zero_point = torch.zeros_like(x_min), torch.zeros_like(x_min)
22 |     s = torch.full((x.shape[0], 1), 30000, dtype=x.dtype, device=x.device)  # best (lowest) score per row so far, initialized large
23 |     for i in range(80):  # shrink the clipping range by 1% per step and keep the best (delta, zero_point)
24 | new_min = x_min * (1. - (i * 0.01))
25 | new_max = x_max * (1. - (i * 0.01))
26 | new_delta = (new_max - new_min) / (level - 1)
27 | new_zero_point = torch.round(-new_min / new_delta) if not (symmetric or always_zero) else torch.zeros_like(new_delta)
28 | NB, PB = -level // 2 if symmetric and not always_zero else 0, \
29 | level // 2 - 1 if symmetric and not always_zero else level - 1
30 | x_q = torch.clamp(torch.round(x / new_delta) + new_zero_point, NB, PB)
31 | x_dq = new_delta * (x_q - new_zero_point)
32 |         new_s = (x_dq - x).abs().pow(2.4).mean(dim=-1, keepdim=True)  # p=2.4 error between the de-quantized and original tensor, used to score this clipping ratio
33 |
34 | update_mask = new_s < s
35 | delta[update_mask] = new_delta[update_mask]
36 | zero_point[update_mask] = new_zero_point[update_mask]
37 | s[update_mask] = new_s[update_mask]
38 |
39 | return delta.squeeze(), zero_point.squeeze()
40 |
41 |
42 | def batch_max(x: torch.Tensor,
43 | symmetric: bool = False,
44 | level: int = 256,
45 | always_zero: bool = False
46 | ) -> [torch.Tensor, torch.Tensor]:
47 | x_min, x_max = torch.min(x, dim=-1, keepdim=True)[0], torch.max(x, dim=-1, keepdim=True)[0] # [d_out]
48 |
49 | x_absmax = torch.max(torch.cat([x_min.abs(), x_max], dim=-1), dim=-1, keepdim=True)[0]
50 |
51 | if symmetric:
52 | # x_min, x_max = -x_absmax if x_min < 0 else 0, x_absmax
53 | delta = x_absmax / level
54 | else:
55 | delta = (x_max - x_min) / (level - 1)
56 |
57 | delta = torch.clamp(delta, min=1e-8)
58 |     zero_point = torch.round(-x_min / delta) if not (symmetric or always_zero) else torch.zeros_like(delta)
59 |
60 | return delta.squeeze(), zero_point.squeeze()
61 |
62 |
63 |
64 | def round_ste(x: torch.Tensor):
65 |     return (x.round() - x).detach() + x  # straight-through estimator: round in the forward pass, identity gradient in the backward pass
66 |
67 |
68 |
69 | def lp_loss(pred, tgt, p=2.0, reduction='none'):
70 | if reduction == 'none':
71 |         return (pred-tgt).abs().pow(p).sum(1).mean()  # per-sample sum over dim 1, then mean over the batch
72 | else:
73 |         return (pred-tgt).abs().pow(p).mean()  # plain mean over all elements
74 |
75 |
76 |
77 |
--------------------------------------------------------------------------------
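
Finally, to make the role of the helpers in `quant_utils.py` concrete, here is a small self-contained round-trip using `batch_max` and `round_ste` on a random weight matrix. It is only a sketch: it assumes the repository root is on `PYTHONPATH`, and the fake-quantization formula mirrors the asymmetric path used inside `batch_mse`.

```
import torch
from utils.quant_utils import batch_max, round_ste

torch.manual_seed(0)
w = torch.randn(8, 16)          # a [d_out, d_in] weight matrix

level = 2 ** 8                  # 8-bit uniform grid
delta, zero_point = batch_max(w, symmetric=False, level=level)      # per-row scale and zero-point
delta, zero_point = delta.unsqueeze(-1), zero_point.unsqueeze(-1)   # restore the keepdim for broadcasting

w_q = torch.clamp(round_ste(w / delta) + zero_point, 0, level - 1)  # integer codes in [0, 255]
w_dq = (w_q - zero_point) * delta                                   # de-quantized weights

print((w - w_dq).abs().max())   # round-trip error, roughly bounded by delta/2 per row
```
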