├── README.md
├── assets
│   ├── motivation.png
│   └── pipeline.png
├── environment.yaml
├── evaluation.py
├── get_results.py
├── train_dreambooth_quant.py
├── train_dreambooth_quant.sh
└── utils
    ├── intlora_mul.py
    ├── intlora_shift.py
    ├── quant_layer.py
    ├── quant_model.py
    └── quant_utils.py

/README.md:
--------------------------------------------------------------------------------
# IntLoRA: Integral Low-rank Adaptation of Quantized Diffusion Models

### [[Paper](https://arxiv.org/pdf/2410.21759)]

### Tune your customized diffusion model on ONE 3090 GPU!

[Hang Guo](https://csguoh.github.io/)<sup>1</sup> | [Yawei Li](https://yaweili.bitbucket.io/)<sup>2</sup> | [Tao Dai](https://scholar.google.com/citations?user=MqJNdaAAAAAJ&hl=zh-CN)<sup>3</sup> | [Shu-Tao Xia](https://scholar.google.com.hk/citations?user=koAXTXgAAAAJ&hl=zh-CN)<sup>1,4</sup> | [Luca Benini](https://scholar.google.com.hk/citations?user=8riq3sYAAAAJ&hl=zh-CN&oi=ao)<sup>2</sup>

<sup>1</sup>Tsinghua University, <sup>2</sup>ETH Zurich, <sup>3</sup>Shenzhen University, <sup>4</sup>Pengcheng Laboratory

:star: If our IntLoRA is helpful for your images or projects, please help star this repo. Thanks! :hugs:

## TL;DR
> Our IntLoRA offers three key advantages: (i) for fine-tuning, the pre-trained weights are quantized, reducing memory usage; (ii) for storage, both the pre-trained and the low-rank weights are kept in integer (INT) format, which consumes less disk space; (iii) for inference, IntLoRA weights can be merged directly into the quantized pre-trained weights through efficient integer multiplication or bit-shifting, eliminating the need for additional post-training quantization.
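The integer merge in point (iii) can be illustrated with a generic NumPy sketch. It assumes simple symmetric per-tensor quantization and a purely multiplicative adapter; the helpers `quantize_symmetric`, `merge_mul` and `merge_shift` are hypothetical and are not the code in `utils/intlora_mul.py` or `utils/intlora_shift.py`. The sketch only shows why such a merge needs no re-quantization: if the pre-trained weight and the adapter are both stored as integer codes with floating-point scales, their element-wise product equals the integer product of the codes times the product of the scales. When every adapter code is a power of two, that integer product reduces to a left bit-shift.

```python
import numpy as np


def quantize_symmetric(x, nbits):
    """Symmetric per-tensor uniform quantization: x ~= scale * q with integer codes q."""
    qmax = 2 ** (nbits - 1) - 1
    scale = float(np.abs(x).max()) / qmax
    q = np.clip(np.round(x / scale), -qmax - 1, qmax).astype(np.int32)
    return q, scale


def merge_mul(q_w, s_w, q_a, s_a):
    """Integer merge of a multiplicative adapter (hypothetical helper).

    (s_w * q_w) * (s_a * q_a) == (s_w * s_a) * (q_w * q_a): the codes are multiplied
    as integers and only the two scalar scales are combined in floating point.
    """
    return q_w * q_a, s_w * s_a


def merge_shift(q_w, s_w, k, s_a):
    """Same merge when every adapter code is 2**k (k >= 0): q_w * 2**k == q_w << k."""
    return np.left_shift(q_w, k), s_w * s_a


rng = np.random.default_rng(0)
w = rng.normal(size=(4, 4))                # stand-in for a pre-trained weight
a = 1.0 + 0.1 * rng.normal(size=(4, 4))    # stand-in for a multiplicative adapter
q_w, s_w = quantize_symmetric(w, nbits=8)
q_a, s_a = quantize_symmetric(a, nbits=8)

q_m, s_m = merge_mul(q_w, s_w, q_a, s_a)   # merged weight is s_m * q_m with integer q_m
print(np.abs(s_m * q_m - w * a).max())     # residual error comes only from quantizing w and a
```

The merged codes `q_m` stay integers, so the merged layer can keep running in INT without another round of post-training quantization. How IntLoRA itself makes the adapter integer-valued is described in the paper and in `utils/`.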


## 🔎 Overview framework


## ⚙️ Dependencies and Installation

### Step-1 Download and Environment

```
## git clone this repository
git clone https://github.com/csguoh/IntLoRA.git
cd ./IntLoRA
```

```
# create a conda environment
conda env create -f environment.yaml
conda activate intlora
```

### Step-2 Preparation

- This repository contains [DreamBooth](https://arxiv.org/abs/2208.12242) fine-tuning with our IntLoRA. The subject-driven generation dataset can be downloaded [here](https://github.com/google/dreambooth/tree/main/dataset).

- You also need to download the pre-trained model weights that will be fine-tuned with IntLoRA. Here we use [Stable Diffusion-1.5](https://huggingface.co/CompVis) as an example.

### Step-3 Fine-tuning!

- The main fine-tuning script is `train_dreambooth_quant.py`. We also provide an off-the-shelf configuration script, so you can directly train customized diffusion models with the following command:

```
bash ./train_dreambooth_quant.sh
```

- The following are some key parameters that you may want to modify:

    - `rank`: the inner rank of the LoRA adapter
    - `intlora`: choose `'MUL'` to use IntLoRA-MUL or `'SHIFT'` to use IntLoRA-SHIFT
    - `nbits`: the bit-width of the weight quantization
    - `use_activation_quant`: whether to enable activation quantization
    - `act_nbits`: the bit-width of the activation quantization
    - `gradient_checkpointing`: whether to use gradient checkpointing to further reduce GPU memory usage

- After running the fine-tuning command above, you can find the generated results in the `./log_quant` folder.
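`nbits` and `act_nbits` set how many integer levels the weights and activations are mapped to. As a rough, generic illustration, here is asymmetric min-max uniform fake quantization: values are rounded onto `2**nbits` integer levels and then mapped back to floats, so you can see how the reconstruction error grows as the bit-width shrinks. This is an assumption for illustration only. The function name `fake_quant` is hypothetical, and the quantizers in `utils/quant_utils.py` may use different calibration or granularity.

```python
import numpy as np


def fake_quant(x, nbits):
    """Asymmetric min-max uniform fake quantization of x to `nbits` bits (hypothetical helper).

    Values are rounded onto the integer grid {0, ..., 2**nbits - 1} and mapped back
    to floats, so the output carries the rounding error of an nbits-bit format.
    """
    levels = 2 ** nbits - 1
    lo, hi = float(x.min()), float(x.max())
    scale = (hi - lo) / levels if hi > lo else 1.0
    zero_point = np.round(-lo / scale)                        # integer code representing 0.0
    q = np.clip(np.round(x / scale) + zero_point, 0, levels)  # integer codes
    return (q - zero_point) * scale                           # dequantize back to floats


x = np.random.default_rng(0).normal(size=10_000)  # stand-in for a weight or activation tensor
for nbits in (2, 4, 8):
    mse = np.mean((fake_quant(x, nbits) - x) ** 2)
    print(f"{nbits}-bit MSE: {mse:.2e}")
```

In a simulated-quantization setup, weights would typically go through such a function with `nbits`, and activations with `act_nbits` when `use_activation_quant` is enabled.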

## 😄 Evaluation
- After generating the images, you can evaluate the quality of each generated image with the following command:

```
python evaluation.py
```

- This writes a `.json` file containing the IQA results (DINO, CLIP-I, CLIP-T and LPIPS) for each subject. You can then obtain the overall evaluation result with:

```
python get_results.py
```

- Note that both scripts use hard-coded paths (the generated-image folder, the `results.json` location, the DreamBooth dataset folder and the CLIP/DINO checkpoint folders). Edit these paths before running the scripts.

## 🎓 Citations
If our code helps your research or work, please consider citing our paper.
The BibTeX reference is:

```
@article{guo2024intlora,
  title={IntLoRA: Integral Low-rank Adaptation of Quantized Diffusion Models},
  author={Guo, Hang and Li, Yawei and Dai, Tao and Xia, Shu-Tao and Benini, Luca},
  journal={arXiv preprint arXiv:2410.21759},
  year={2024}
}
```

--------------------------------------------------------------------------------
/assets/motivation.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/csguoh/IntLoRA/65a8257a4311e0feab9b33477c9748d13d8ce17b/assets/motivation.png
--------------------------------------------------------------------------------
/assets/pipeline.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/csguoh/IntLoRA/65a8257a4311e0feab9b33477c9748d13d8ce17b/assets/pipeline.png
--------------------------------------------------------------------------------
/environment.yaml:
--------------------------------------------------------------------------------
name: intlora
channels:
  - pytorch
  - nvidia
  - defaults
dependencies:
  - _libgcc_mutex=0.1=main
  - _openmp_mutex=5.1=1_gnu
  - blas=1.0=mkl
  - brotli-python=1.0.9=py39h6a678d5_7
  - bzip2=1.0.8=h7b6447c_0
  - ca-certificates=2024.7.2=h06a4308_0
  - cffi=1.16.0=py39h5eee18b_0
  - cryptography=41.0.7=py39hdda0065_0
  - cuda-cudart=11.8.89=0
  - cuda-cupti=11.8.87=0
  - cuda-libraries=11.8.0=0
  - cuda-nvrtc=11.8.89=0
  - cuda-nvtx=11.8.86=0
  - cuda-runtime=11.8.0=0
  - ffmpeg=4.3=hf484d3e_0
  - filelock=3.13.1=py39h06a4308_0
  - freetype=2.12.1=h4a9f257_0
  - giflib=5.2.1=h5eee18b_3
  - gmp=6.2.1=h295c915_3
  - gmpy2=2.1.2=py39heeb90bb_0
  - gnutls=3.6.15=he1e5248_0
  - idna=3.4=py39h06a4308_0
  - intel-openmp=2023.1.0=hdb19cb5_46306
  - jinja2=3.1.2=py39h06a4308_0
  - jpeg=9e=h5eee18b_1
  - lame=3.100=h7b6447c_0
  - lcms2=2.12=h3be6417_0
  - ld_impl_linux-64=2.38=h1181459_1
  - lerc=3.0=h295c915_0
  - libcublas=11.11.3.6=0
  - libcufft=10.9.0.58=0
  - libcufile=1.8.1.2=0
  - libcurand=10.3.4.101=0
  - libcusolver=11.4.1.48=0
  - libcusparse=11.7.5.86=0
  - libdeflate=1.17=h5eee18b_1
  - libffi=3.4.4=h6a678d5_0
  - libgcc-ng=11.2.0=h1234567_1
  - libgomp=11.2.0=h1234567_1
  - libiconv=1.16=h7f8727e_2
  - libidn2=2.3.4=h5eee18b_0
  - libjpeg-turbo=2.0.0=h9bf148f_0
  - libnpp=11.8.0.86=0
  - libnvjpeg=11.9.0.86=0
  - libpng=1.6.39=h5eee18b_0
  - libstdcxx-ng=11.2.0=h1234567_1
  - libtasn1=4.19.0=h5eee18b_0
  - libtiff=4.5.1=h6a678d5_0
  - libunistring=0.9.10=h27cfd23_0
  - libwebp=1.3.2=h11a3e52_0
  - libwebp-base=1.3.2=h5eee18b_0
  - llvm-openmp=14.0.6=h9e868ea_0
  - lz4-c=1.9.4=h6a678d5_0
  - mkl=2023.1.0=h213fc3f_46344
  - mkl-service=2.4.0=py39h5eee18b_1
  - mkl_fft=1.3.8=py39h5eee18b_0
  - mkl_random=1.2.4=py39hdb19cb5_0
  - mpc=1.1.0=h10f8cd9_1
  - mpfr=4.0.2=hb69a4c5_1
  - mpmath=1.3.0=py39h06a4308_0
  - ncurses=6.4=h6a678d5_0
  - nettle=3.7.3=hbbd107a_1
  - numpy-base=1.26.2=py39hb5e798b_0
  - openh264=2.1.1=h4ff587b_0
  - openjpeg=2.4.0=h3ad879b_0
  - openssl=3.0.14=h5eee18b_0
  - pip=23.3=py39h06a4308_0
  - pycparser=2.21=pyhd3eb1b0_0
  - pyopenssl=23.2.0=py39h06a4308_0
  - pysocks=1.7.1=py39h06a4308_0
  - python=3.9.18=h955ad1f_0
  - pytorch=2.1.0=py3.9_cuda11.8_cudnn8.7.0_0
  - pytorch-cuda=11.8=h7e8668a_5
  - pytorch-mutex=1.0=cuda
  - pyyaml=6.0.1=py39h5eee18b_0
  - readline=8.2=h5eee18b_0
  - requests=2.31.0=py39h06a4308_0
  - setuptools=68.0.0=py39h06a4308_0
  - sqlite=3.41.2=h5eee18b_0
  - tbb=2021.8.0=hdb19cb5_0
  - tk=8.6.12=h1ccaba5_0
  - torchaudio=2.1.0=py39_cu118
  - torchtriton=2.1.0=py39
  - typing_extensions=4.7.1=py39h06a4308_0
  - wheel=0.41.2=py39h06a4308_0
  - xz=5.4.2=h5eee18b_0
  - yaml=0.2.5=h7b6447c_0
  - zlib=1.2.13=h5eee18b_0
  - zstd=1.5.5=hc292b87_0
  - pip:
      - absl-py==2.1.0
      - accelerate==0.25.0
      - addict==2.4.0
      - aiofiles==23.2.1
      - aiohttp==3.8.6
      - aiosignal==1.3.1
      - annotated-types==0.7.0
      - antlr4-python3-runtime==4.9.3
      - anyio==3.7.1
      - asttokens==2.4.1
      - async-timeout==4.0.3
      - attrs==23.1.0
      - axial-positional-embedding==0.2.1
      - bitsandbytes==0.43.1
      - cachetools==5.3.2
      - certifi==2023.7.22
      - chardet==5.2.0
      - charset-normalizer==3.3.1
      - click==8.1.7
      - colorama==0.4.6
      - contourpy==1.1.1
      - cycler==0.12.1
      - dataclasses-json==0.6.1
      - deprecated==1.2.14
      - diffusers==0.26.0
      - docker-pycreds==0.4.0
      - einops==0.4.0
      - exceptiongroup==1.1.3
      - executing==2.0.1
      - facexlib==0.3.0
      - fastapi==0.112.2
      - ffmpy==0.4.0
      - filterpy==1.4.5
      - fonttools==4.43.1
      - frozenlist==1.4.0
      - fsspec==2023.10.0
      - ftfy==6.2.3
      - future==1.0.0
      - gitdb==4.0.11
      - gitpython==3.1.43
      - google-auth==2.26.2
      - google-auth-oauthlib==1.2.0
      - gradio==4.42.0
      - gradio-client==1.3.0
      - greenlet==3.0.1
      - grpcio==1.59.2
      - grpcio-tools==1.59.2
      - h11==0.14.0
      - h2==4.1.0
      - hpack==4.0.0
      - httpcore==0.18.0
      - httpx==0.25.0
      - huggingface-hub==0.24.0
      - hyperframe==6.0.1
      - icecream==2.1.3
      - imageio==2.35.1
      - imgaug==0.4.0
      - importlib-metadata==7.0.1
      - importlib-resources==6.1.0
      - joblib==1.3.2
      - jsonpatch==1.33
      - jsonpointer==2.4
      - kiwisolver==1.4.5
      - langchain==0.0.327
      - langsmith==0.0.54
      - lazy-loader==0.4
      - lightning-utilities==0.11.6
      - llvmlite==0.41.1
      - lmdb==1.5.1
      - local-attention==1.4.4
      - lpips==0.1.4
      - markdown==3.5.2
      - markdown-it-py==3.0.0
      - markupsafe==2.1.3
      - marshmallow==3.20.1
      - matplotlib==3.7.0
      - mdurl==0.1.2
      - multidict==6.0.4
      - mypy-extensions==1.0.0
      - networkx==3.2.1
      - nltk==3.8.1
      - numba==0.58.1
      - numpy==1.23.5
      - nvidia-cublas-cu11==11.10.3.66
      - nvidia-cublas-cu12==12.1.3.1
      - nvidia-cuda-cupti-cu12==12.1.105
      - nvidia-cuda-nvrtc-cu11==11.7.99
      - nvidia-cuda-nvrtc-cu12==12.1.105
      - nvidia-cuda-runtime-cu11==11.7.99
      - nvidia-cuda-runtime-cu12==12.1.105
      - nvidia-cudnn-cu11==8.5.0.96
      - nvidia-cudnn-cu12==8.9.2.26
      - nvidia-cufft-cu12==11.0.2.54
      - nvidia-curand-cu12==10.3.2.106
      - nvidia-cusolver-cu12==11.4.5.107
      - nvidia-cusparse-cu12==12.1.0.106
      - nvidia-nccl-cu12==2.18.1
      - nvidia-nvjitlink-cu12==12.3.52
      - nvidia-nvtx-cu12==12.1.105
      - oauthlib==3.2.2
      - omegaconf==2.3.0
      - open-clip-torch==2.26.1
      - openai==0.28.1
      - openai-clip==1.0.1
      - opencv-python==4.9.0.80
      - opencv-python-headless==4.10.0.84
      - orjson==3.10.7
      - packaging==23.2
      - pandas==1.5.3
      - patool==1.12
      - patsy==0.5.4
      - pdfminer==20191125
      - peft==0.7.1
      - pillow==10.1.0
      - platformdirs==4.2.2
      - product-key-memory==0.1.10
      - protobuf==4.23.4
      - psutil==5.9.6
      - pyarrow==13.0.0
      - pyasn1==0.5.1
      - pyasn1-modules==0.3.0
      - pycryptodome==3.19.0
      - pydantic==2.8.2
      - pydantic-core==2.20.1
      - pydub==0.25.1
      - pygments==2.18.0
      - pyiqa==0.1.10
      - pyparsing==3.1.1
      - python-dateutil==2.8.2
      - python-dotenv==1.0.0
      - python-multipart==0.0.9
      - pytorch-lightning==2.1.0
      - pytz==2023.3.post1
      - qdrant-client==1.1.1
      - reformer-pytorch==1.4.4
      - regex==2023.10.3
      - requests-oauthlib==1.3.1
      - rich==13.7.1
      - rsa==4.9
      - ruff==0.6.2
      - safetensors==0.4.4
      - scikit-base==0.6.1
      - scikit-image==0.24.0
      - scikit-learn==1.2.2
      - scipy==1.10.1
      - seaborn==0.13.2
      - semantic-version==2.10.0
      - sentence-transformers==2.2.2
      - sentencepiece==0.1.99
      - sentry-sdk==2.10.0
      - setproctitle==1.3.3
      - shapely==2.0.6
      - shellingham==1.5.4
      - six==1.16.0
      - sktime==0.16.1
      - smmap==5.0.1
      - sniffio==1.3.0
      - sqlalchemy==2.0.22
      - starlette==0.38.2
      - statsmodels==0.14.1
      - sympy==1.11.1
      - tenacity==8.2.3
      - tensorboard==2.15.1
      - tensorboard-data-server==0.7.2
      - threadpoolctl==3.2.0
      - tifffile==2024.8.24
      - timm==0.9.7
      - tokenizers==0.15.2
      - tomli==2.0.1
      - tomlkit==0.12.0
      - torch==1.13.0  # NOTE: conflicts with conda pytorch=2.1.0 and torchvision==0.16.0, which both target torch 2.1.0
      - torchmetrics==1.4.1
      - torchvision==0.16.0
      - tqdm==4.64.1
      - transformers==4.37.2
      - triton==2.1.0
      - typer==0.12.5
      - typing-extensions==4.12.2
      - typing-inspect==0.8.0
      - tzdata==2023.3
      - urllib3==2.2.2
      - uvicorn==0.21.1
      - wandb==0.17.5
      - wcwidth==0.2.13
      - websockets==12.0
      - werkzeug==3.0.1
      - wrapt==1.16.0
      - yapf==0.40.2
      - yarl==1.9.2
      - zipp==3.17.0

--------------------------------------------------------------------------------
/evaluation.py:
--------------------------------------------------------------------------------
#!/usr/bin/env python
# coding=utf-8
# Copyright 2023 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import json
import os

import lpips
import torch
from PIL import Image
from torch.nn.functional import cosine_similarity
from torch.utils.data import Dataset
from torchvision.transforms import Compose, Normalize, Resize, ToTensor
from transformers import AutoProcessor, AutoTokenizer, CLIPModel, ViTFeatureExtractor, ViTModel


def get_prompt(subject_name, prompt_idx):
    """Return evaluation prompt number `prompt_idx` for a DreamBooth subject.

    "qwe" is used as the subject identifier token in every prompt.
    """
    subject_names = [
        "backpack", "backpack_dog", "bear_plushie", "berry_bowl", "can",
        "candle", "cat", "cat2", "clock", "colorful_sneaker",
        "dog", "dog2", "dog3", "dog5", "dog6",
        "dog7", "dog8", "duck_toy", "fancy_boot", "grey_sloth_plushie",
        "monster_toy", "pink_sunglasses", "poop_emoji", "rc_car", "red_cartoon",
        "robot_toy", "shiny_sneaker", "teapot", "vase", "wolf_plushie",
    ]

    class_tokens = [
        "backpack", "backpack", "stuffed animal", "bowl", "can",
        "candle", "cat", "cat", "clock", "sneaker",
        "dog", "dog", "dog", "dog", "dog",
        "dog", "dog", "toy", "boot", "stuffed animal",
        "toy", "glasses", "toy", "toy", "cartoon",
        "toy", "sneaker", "teapot", "vase", "stuffed animal",
    ]

    class_token = class_tokens[subject_names.index(subject_name)]

    prompt_list = [
        f"a qwe {class_token} in the jungle",
        f"a qwe {class_token} in the snow",
        f"a qwe {class_token} on the beach",
        f"a qwe {class_token} on a cobblestone street",
        f"a qwe {class_token} on top of pink fabric",
        f"a qwe {class_token} on top of a wooden floor",
        f"a qwe {class_token} with a city in the background",
        f"a qwe {class_token} with a mountain in the background",
        f"a qwe {class_token} with a blue house in the background",
        f"a qwe {class_token} on top of a purple rug in a forest",
        f"a qwe {class_token} wearing a red hat",
        f"a qwe {class_token} wearing a santa hat",
        f"a qwe {class_token} wearing a rainbow scarf",
        f"a qwe {class_token} wearing a black top hat and a monocle",
        f"a qwe {class_token} in a chef outfit",
        f"a qwe {class_token} in a firefighter outfit",
        f"a qwe {class_token} in a police outfit",
        f"a qwe {class_token} wearing pink glasses",
        f"a qwe {class_token} wearing a yellow shirt",
        f"a qwe {class_token} in a purple wizard outfit",
        f"a red qwe {class_token}",
        f"a purple qwe {class_token}",
        f"a shiny qwe {class_token}",
        f"a wet qwe {class_token}",
        f"a cube shaped qwe {class_token}",
    ]

    return prompt_list[int(prompt_idx)]


def is_black_image(image):
    """Return True if every channel of `image` is constantly 0 (an all-black image)."""
    return all(min_val == max_val == 0 for min_val, max_val in image.getextrema())


class PromptDatasetCLIP(Dataset):
    """Generated images of one subject/epoch paired with their prompt, preprocessed for CLIP."""

    def __init__(self, subject_name, data_dir_B, tokenizer, processor, epoch=None):
        self.data_dir_B = data_dir_B

        # Result folders are named '<subject_name>-<prompt_idx>'
        subject_name, prompt_idx = subject_name.split('-')

        data_dir_B = os.path.join(self.data_dir_B, str(epoch))
        self.image_lst = [os.path.join(data_dir_B, f) for f in os.listdir(data_dir_B) if f.endswith(".png")]
        self.prompt_lst = [get_prompt(subject_name, prompt_idx)] * len(self.image_lst)

        self.tokenizer = tokenizer
        self.processor = processor

    def __len__(self):
        return len(self.image_lst)

    def __getitem__(self, idx):
        image = Image.open(self.image_lst[idx])
        prompt = self.prompt_lst[idx]

        # All-black images are skipped by the metric functions
        if is_black_image(image):
            return None, None
        prompt_inputs = self.tokenizer([prompt], padding=True, return_tensors="pt")
        image_inputs = self.processor(images=image, return_tensors="pt")
        return image_inputs, prompt_inputs


class PairwiseImageDatasetCLIP(Dataset):
    """All (reference, generated) image pairs of one subject/epoch, preprocessed for CLIP."""

    def __init__(self, subject_name, data_dir_A, data_dir_B, processor, epoch):
        self.data_dir_A = data_dir_A
        self.data_dir_B = data_dir_B

        subject_name, prompt_idx = subject_name.split('-')

        # A: real DreamBooth reference images (.jpg); B: images generated at this epoch (.png)
        self.data_dir_A = os.path.join(self.data_dir_A, subject_name)
        self.image_files_A = [os.path.join(self.data_dir_A, f) for f in os.listdir(self.data_dir_A)
                              if f.endswith(".jpg")]

        data_dir_B = os.path.join(self.data_dir_B, str(epoch))
        self.image_files_B = [os.path.join(data_dir_B, f) for f in os.listdir(data_dir_B) if f.endswith(".png")]

        self.processor = processor

    def __len__(self):
        return len(self.image_files_A) * len(self.image_files_B)

    def __getitem__(self, index):
        # Flat index over the Cartesian product A x B
        index_A = index // len(self.image_files_B)
        index_B = index % len(self.image_files_B)

        image_A = Image.open(self.image_files_A[index_A])
        image_B = Image.open(self.image_files_B[index_B])

        if is_black_image(image_A) or is_black_image(image_B):
            return None, None
        inputs_A = self.processor(images=image_A, return_tensors="pt")
        inputs_B = self.processor(images=image_B, return_tensors="pt")
        return inputs_A, inputs_B


class PairwiseImageDatasetDINO(Dataset):
    """All (reference, generated) image pairs of one subject/epoch, preprocessed for DINO."""

    def __init__(self, subject_name, data_dir_A, data_dir_B, feature_extractor, epoch):
        self.data_dir_A = data_dir_A
        self.data_dir_B = data_dir_B

        subject_name, prompt_idx = subject_name.split('-')

        self.data_dir_A = os.path.join(self.data_dir_A, subject_name)
        self.image_files_A = [os.path.join(self.data_dir_A, f) for f in os.listdir(self.data_dir_A)
                              if f.endswith(".jpg")]

        data_dir_B = os.path.join(self.data_dir_B, str(epoch))
        self.image_files_B = [os.path.join(data_dir_B, f) for f in os.listdir(data_dir_B) if f.endswith(".png")]

        self.feature_extractor = feature_extractor

    def __len__(self):
        return len(self.image_files_A) * len(self.image_files_B)

    def __getitem__(self, index):
        index_A = index // len(self.image_files_B)
        index_B = index % len(self.image_files_B)

        image_A = Image.open(self.image_files_A[index_A])
        image_B = Image.open(self.image_files_B[index_B])

        if is_black_image(image_A) or is_black_image(image_B):
            return None, None
        inputs_A = self.feature_extractor(images=image_A, return_tensors="pt")
        inputs_B = self.feature_extractor(images=image_B, return_tensors="pt")
        return inputs_A, inputs_B


class PairwiseImageDatasetLPIPS(Dataset):
    """All (reference, generated) image pairs of one subject/epoch as tensors scaled to [-1, 1]."""

    def __init__(self, subject_name, data_dir_A, data_dir_B, epoch):
        self.data_dir_A = data_dir_A
        self.data_dir_B = data_dir_B

        subject_name, prompt_idx = subject_name.split('-')

        self.data_dir_A = os.path.join(self.data_dir_A, subject_name)
        self.image_files_A = [os.path.join(self.data_dir_A, f) for f in os.listdir(self.data_dir_A)
                              if f.endswith(".jpg")]

        data_dir_B = os.path.join(self.data_dir_B, str(epoch))
        self.image_files_B = [os.path.join(data_dir_B, f) for f in os.listdir(data_dir_B) if f.endswith(".png")]

        # Normalize(0.5, 0.5) maps [0, 1] to [-1, 1], the input range LPIPS expects
        self.transform = Compose([
            Resize((512, 512)),
            ToTensor(),
            Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
        ])

    def __len__(self):
        return len(self.image_files_A) * len(self.image_files_B)

    def __getitem__(self, index):
        index_A = index // len(self.image_files_B)
        index_B = index % len(self.image_files_B)

        image_A = Image.open(self.image_files_A[index_A])
        image_B = Image.open(self.image_files_B[index_B])

        if is_black_image(image_A) or is_black_image(image_B):
            return None, None
        return self.transform(image_A), self.transform(image_B)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# NOTE: the checkpoint, dataset and output paths in this script are hard-coded;
# edit them to point at your local copies.
clip_text_model = CLIPModel.from_pretrained("/data2/xxx/pretrained/clip-vit-large-patch14").to(device)
clip_text_tokenizer = AutoTokenizer.from_pretrained("/data2/xxx/pretrained/clip-vit-large-patch14")
clip_text_processor = AutoProcessor.from_pretrained("/data2/xxx/pretrained/clip-vit-large-patch14")


@torch.no_grad()
def clip_text(subject_name, image_dir):
    """CLIP-T: mean image-prompt cosine similarity of the generated images, per epoch."""
    criterion = 'clip_text'
    model = clip_text_model
    tokenizer = clip_text_tokenizer  # encodes the prompts
    processor = clip_text_processor  # preprocesses the images

    epochs = sorted([int(epoch) for epoch in os.listdir(image_dir)])
    best_mean_similarity = 0
    mean_similarity_list = []
    for epoch in epochs:
        similarity = []
        dataset = PromptDatasetCLIP(subject_name, image_dir, tokenizer, processor, epoch)
        for i in range(len(dataset)):
            image_inputs, prompt_inputs = dataset[i]
            if image_inputs is not None and prompt_inputs is not None:
                image_inputs['pixel_values'] = image_inputs['pixel_values'].to(device)
                prompt_inputs['input_ids'] = prompt_inputs['input_ids'].to(device)
                prompt_inputs['attention_mask'] = prompt_inputs['attention_mask'].to(device)

                image_features = model.get_image_features(**image_inputs)
                text_features = model.get_text_features(**prompt_inputs)
                sim = cosine_similarity(image_features, text_features)
                similarity.append(sim.item())

        if similarity:
            mean_similarity = torch.tensor(similarity).mean().item()
            mean_similarity_list.append(mean_similarity)
            best_mean_similarity = max(best_mean_similarity, mean_similarity)
            print(f'epoch: {epoch}, criterion: {criterion}, mean_similarity: {mean_similarity}({best_mean_similarity})')
        else:
            mean_similarity_list.append(0)
            print(f'epoch: {epoch}, criterion: {criterion}, mean_similarity: {0}({best_mean_similarity})')

    return mean_similarity_list


clip_image_model = CLIPModel.from_pretrained("/data2/xxx/pretrained/clip-vit-large-patch14").to(device)
clip_image_processor = AutoProcessor.from_pretrained("/data2/xxx/pretrained/clip-vit-large-patch14")


@torch.no_grad()
def clip_image(subject_name, image_dir, dreambooth_dir='/data2/xxx/dataset/dreambooth'):
    """CLIP-I: mean CLIP image-embedding similarity between reference and generated images, per epoch."""
    criterion = 'clip_image'
    model = clip_image_model
    processor = clip_image_processor

    epochs = sorted([int(epoch) for epoch in os.listdir(image_dir)])
    best_mean_similarity = 0
    mean_similarity_list = []
    for epoch in epochs:
        similarity = []
        dataset = PairwiseImageDatasetCLIP(subject_name, dreambooth_dir, image_dir, processor, epoch)
        for i in range(len(dataset)):
            inputs_A, inputs_B = dataset[i]
            if inputs_A is not None and inputs_B is not None:
                inputs_A['pixel_values'] = inputs_A['pixel_values'].to(device)
                inputs_B['pixel_values'] = inputs_B['pixel_values'].to(device)

                image_A_features = model.get_image_features(**inputs_A)
                image_B_features = model.get_image_features(**inputs_B)

                # Dot product of L2-normalized embeddings = cosine similarity
                image_A_features = image_A_features / image_A_features.norm(p=2, dim=-1, keepdim=True)
                image_B_features = image_B_features / image_B_features.norm(p=2, dim=-1, keepdim=True)
                sim = torch.matmul(image_A_features, image_B_features.t())
                similarity.append(sim.item())

        if similarity:
            mean_similarity = torch.tensor(similarity).mean().item()
            best_mean_similarity = max(best_mean_similarity, mean_similarity)
            mean_similarity_list.append(mean_similarity)
            print(f'epoch: {epoch}, criterion: {criterion}, mean_similarity: {mean_similarity}({best_mean_similarity})')
        else:
            mean_similarity_list.append(0)
            print(f'epoch: {epoch}, criterion: {criterion}, mean_similarity: {0}({best_mean_similarity})')

    return mean_similarity_list


dino_model = ViTModel.from_pretrained('/data2/xxx/pretrained/dino-vits16').to(device)
dino_feature_extractor = ViTFeatureExtractor.from_pretrained('/data2/xxx/pretrained/dino-vits16')


@torch.no_grad()
def dino(subject_name, image_dir, dreambooth_dir='/data2/xxx/dataset/dreambooth'):
    """DINO: mean [CLS]-embedding similarity between reference and generated images, per epoch."""
    criterion = 'dino'
    model = dino_model
    feature_extractor = dino_feature_extractor

    epochs = sorted([int(epoch) for epoch in os.listdir(image_dir)])
    best_mean_similarity = 0
    mean_similarity_list = []
    for epoch in epochs:
        similarity = []
        dataset = PairwiseImageDatasetDINO(subject_name, dreambooth_dir, image_dir, feature_extractor, epoch)
        for i in range(len(dataset)):
            inputs_A, inputs_B = dataset[i]
            if inputs_A is not None and inputs_B is not None:
                inputs_A['pixel_values'] = inputs_A['pixel_values'].to(device)
                inputs_B['pixel_values'] = inputs_B['pixel_values'].to(device)

                # Use the [CLS] token of the last hidden state as the image embedding
                image_A_features = model(**inputs_A).last_hidden_state[:, 0, :]
                image_B_features = model(**inputs_B).last_hidden_state[:, 0, :]

                image_A_features = image_A_features / image_A_features.norm(p=2, dim=-1, keepdim=True)
                image_B_features = image_B_features / image_B_features.norm(p=2, dim=-1, keepdim=True)
                sim = torch.matmul(image_A_features, image_B_features.t())
                similarity.append(sim.item())

        # Guard against epochs in which every image was skipped (the mean of [] is NaN)
        if similarity:
            mean_similarity = torch.tensor(similarity).mean().item()
            best_mean_similarity = max(best_mean_similarity, mean_similarity)
            mean_similarity_list.append(mean_similarity)
            print(f'epoch: {epoch}, criterion: {criterion}, mean_similarity: {mean_similarity}({best_mean_similarity})')
        else:
            mean_similarity_list.append(0)
            print(f'epoch: {epoch}, criterion: {criterion}, mean_similarity: {0}({best_mean_similarity})')

    return mean_similarity_list


# LPIPS with the VGG backbone
lpips_loss_fn = lpips.LPIPS(net='vgg').to(device)


@torch.no_grad()
def lpips_image(subject_name, image_dir, dreambooth_dir='/data2/xxx/dataset/dreambooth'):
    """LPIPS: mean perceptual distance between reference and generated images, per epoch."""
    criterion = 'lpips_image'
    loss_fn = lpips_loss_fn
    # Some epochs may not have finished running completely
    epochs = sorted([int(epoch) for epoch in os.listdir(image_dir)])
    mean_similarity_list = []
    best_mean_similarity = 0
    for epoch in epochs:
        similarity = []
        dataset = PairwiseImageDatasetLPIPS(subject_name, dreambooth_dir, image_dir, epoch)
        for i in range(len(dataset)):
            image_A, image_B = dataset[i]
            if image_A is not None and image_B is not None:
                distance = loss_fn(image_A.to(device), image_B.to(device))
                similarity.append(distance.item())

        if similarity:
            mean_similarity = torch.tensor(similarity).mean().item()
            best_mean_similarity = max(best_mean_similarity, mean_similarity)
            mean_similarity_list.append(mean_similarity)
            print(f'epoch: {epoch}, criterion: LPIPS distance, mean_similarity: {mean_similarity}({best_mean_similarity})')
        else:
            mean_similarity_list.append(0)
            print(f'epoch: {epoch}, criterion: LPIPS distance, mean_similarity: {0}({best_mean_similarity})')

    return mean_similarity_list


if __name__ == "__main__":
    # Root folder of the generated results: one sub-folder per '<subject_name>-<prompt_idx>' run
    image_dir = '/data2/xxx/ControlNet/oft-db/log_quant'

    subject_dirs, subject_names = [], []
    for name in os.listdir(image_dir):
        if os.path.isdir(os.path.join(image_dir, name)):
            subject_dirs.append(os.path.join(image_dir, name))
            subject_names.append(name)

    # Results are stored as JSON Lines: one {subject_name: metrics} object per line
    results_path = '/data2/xxx/ControlNet/oft-db/log_quant/results.json'

    results_dict = dict()
    if os.path.exists(results_path):
        with open(results_path, 'r') as f:
            for line in f:
                if line.strip():
                    results_dict.update(json.loads(line))
        print("finish extraction.")

    for subject_name, subject_dir in zip(subject_names, subject_dirs):
        print(f'evaluating {subject_dir}')
        dino_sim = dino(subject_name, subject_dir)
        clip_i_sim = clip_image(subject_name, subject_dir)
        clip_t_sim = clip_text(subject_name, subject_dir)
        lpips_sim = lpips_image(subject_name, subject_dir)

        subject_result = {'DINO': dino_sim, 'CLIP-I': clip_i_sim, 'CLIP-T': clip_t_sim, 'LPIPS': lpips_sim}
        print(subject_result)

        with open(results_path, 'a') as f:
            f.write(json.dumps({subject_name: subject_result}) + "\n")

--------------------------------------------------------------------------------
/get_results.py:
--------------------------------------------------------------------------------
#!/usr/bin/env python
# coding=utf-8
# Copyright 2023 The HuggingFace Inc. team. All rights reserved.
4 | # 5 | # Licensed under the Apache License, Version 2.0 (the "License"); 6 | # you may not use this file except in compliance with the License. 7 | # You may obtain a copy of the License at 8 | # 9 | # http://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | # See the License for the specific language governing permissions and 15 | # limitations under the License. 16 | import math 17 | import os 18 | from functools import reduce 19 | import numpy as np 20 | 21 | import json 22 | 23 | if __name__ == "__main__": 24 | results_path = '/data2/xxx/ControlNet/oft-db/log_quant/results.json' 25 | results_dict = dict() 26 | if os.path.exists(results_path): 27 | with open(results_path, 'r') as f: 28 | results = f.__iter__() 29 | while True: 30 | try: 31 | result_json = json.loads(next(results)) 32 | results_dict.update(result_json) 33 | except StopIteration: 34 | print("finish extraction.") 35 | break 36 | else: 37 | raise NotImplementedError 38 | total_result = np.zeros(4) 39 | metric_name_list = ['DINO', 'CLIP-I', 'CLIP-T', 'LPIPS'] 40 | except_list = [] 41 | num_samples = 0 42 | print(len(results_dict)) 43 | for subject_name, subject_results in results_dict.items(): 44 | if subject_name in except_list: 45 | continue 46 | num_samples += 1 47 | metric_results_percent = None 48 | for metric_name, metric_results in subject_results.items(): 49 | metric_results = [0 if np.isnan(r) else r for r in metric_results] 50 | try: 51 | metric_results_norm = np.array(metric_results) / (max(metric_results) - min(metric_results)) 52 | except ValueError: # max()/min() fail on an empty result list 53 | print(f'{subject_name}: empty {metric_name} results, skipped'); continue 54 | if metric_results_percent is None: 55 | metric_results_percent = metric_results_norm 56 | else: 57 | metric_results_percent += metric_results_norm 58 | 59 | subject_results_max_idx = np.argmax(metric_results_percent) 60 |
for idx, metric_name in enumerate(metric_name_list): 61 | total_result[idx] += subject_results[metric_name][subject_results_max_idx] 62 | total_result /= num_samples 63 | print(f'DINO: {total_result[0]}, CLIP-I: {total_result[1]}, CLIP-T: {total_result[2]}, LPIPS: {total_result[3]}') 64 | -------------------------------------------------------------------------------- /train_dreambooth_quant.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # coding=utf-8 3 | # Copyright 2023 The HuggingFace Inc. team. All rights reserved. 4 | # 5 | # Licensed under the Apache License, Version 2.0 (the "License"); 6 | # you may not use this file except in compliance with the License. 7 | # You may obtain a copy of the License at 8 | # 9 | # http://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
14 | # See the License for the specific language governing permissions and 15 | # limitations under the License. 16 | import argparse 17 | import hashlib 18 | import logging 19 | import math 20 | import os 21 | import warnings 22 | from pathlib import Path 23 | import torch.nn as nn 24 | import numpy as np 25 | import torch 26 | import torch.nn.functional as F 27 | import torch.utils.checkpoint 28 | import transformers 29 | from accelerate import Accelerator 30 | from accelerate.logging import get_logger 31 | from accelerate.utils import ProjectConfiguration, set_seed 32 | from huggingface_hub import create_repo, upload_folder 33 | from packaging import version 34 | from PIL import Image 35 | from torch.utils.data import Dataset 36 | from torchvision import transforms 37 | from tqdm.auto import tqdm 38 | from transformers import AutoTokenizer, PretrainedConfig 39 | 40 | import diffusers 41 | from diffusers import ( 42 | AutoencoderKL, 43 | DDPMScheduler, 44 | DiffusionPipeline, 45 | DPMSolverMultistepScheduler, 46 | UNet2DConditionModel, 47 | ) 48 | from utils.quant_model import QuantUnetWarp 49 | from diffusers.optimization import get_scheduler 50 | from diffusers.utils import check_min_version, is_wandb_available 51 | from diffusers.utils.import_utils import is_xformers_available 52 | from utils.intlora_shift import IntLoRA_SHIFT 53 | from utils.intlora_mul import IntLoRA_MUL 54 | 55 | # Will error if the minimal version of diffusers is not installed. Remove at your own risk.
56 | check_min_version("0.16.0.dev0") 57 | logger = get_logger(__name__) 58 | 59 | 60 | 61 | def import_model_class_from_model_name_or_path(pretrained_model_name_or_path: str, revision: str): 62 | text_encoder_config = PretrainedConfig.from_pretrained( 63 | pretrained_model_name_or_path, 64 | subfolder="text_encoder", 65 | revision=revision, 66 | ) 67 | model_class = text_encoder_config.architectures[0] 68 | 69 | if model_class == "CLIPTextModel": 70 | from transformers import CLIPTextModel 71 | 72 | return CLIPTextModel 73 | elif model_class == "RobertaSeriesModelWithTransformation": 74 | from diffusers.pipelines.alt_diffusion.modeling_roberta_series import RobertaSeriesModelWithTransformation 75 | 76 | return RobertaSeriesModelWithTransformation 77 | else: 78 | raise ValueError(f"{model_class} is not supported.") 79 | 80 | 81 | def parse_args(input_args=None): 82 | parser = argparse.ArgumentParser(description="Simple example of a training script.") 83 | parser.add_argument( 84 | "--pretrained_model_name_or_path", 85 | type=str, 86 | default=None, 87 | required=True, 88 | help="Path to pretrained model or model identifier from huggingface.co/models.", 89 | ) 90 | 91 | parser.add_argument( 92 | "--intlora", 93 | type=str, 94 | default='MUL', 95 | help="IntLoRA variant: 'MUL' for IntLoRA-MUL or 'SHIFT' for IntLoRA-SHIFT.", 96 | ) 97 | 98 | 99 | parser.add_argument( 100 | "--nbits", 101 | type=int, 102 | default=8, 103 | help="Bit-width of the pre-trained weight quantization.", 104 | ) 105 | 106 | parser.add_argument("--rank", type=int, default=4, help="The inner rank of the LoRA layer") 107 | 108 | parser.add_argument("--act_nbits", type=int, default=8, help="Bit-width of the activation quantization") 109 | 110 | parser.add_argument( 111 | "--use_activation_quant", 112 | default=False, 113 | action="store_true", 114 | help="Whether to use activation quantization") 115 | 116 | parser.add_argument( 117 | "--revision", 118 | type=str, 119 | default=None, 120 | required=False, 121 | help="Revision of pretrained model identifier from huggingface.co/models.", 122 | ) 123 |
| parser.add_argument( 124 | "--tokenizer_name", 125 | type=str, 126 | default=None, 127 | help="Pretrained tokenizer name or path if not the same as model_name", 128 | ) 129 | parser.add_argument( 130 | "--instance_data_dir", 131 | type=str, 132 | default=None, 133 | required=True, 134 | help="A folder containing the training data of instance images.", 135 | ) 136 | parser.add_argument( 137 | "--class_data_dir", 138 | type=str, 139 | default=None, 140 | required=False, 141 | help="A folder containing the training data of class images.", 142 | ) 143 | parser.add_argument( 144 | "--instance_prompt", 145 | type=str, 146 | default=None, 147 | required=True, 148 | help="The prompt with identifier specifying the instance", 149 | ) 150 | parser.add_argument( 151 | "--class_prompt", 152 | type=str, 153 | default=None, 154 | help="The prompt to specify images in the same class as provided instance images.", 155 | ) 156 | parser.add_argument( 157 | "--validation_prompt", 158 | type=str, 159 | default=None, 160 | help="A prompt that is used during validation to verify that the model is learning.", 161 | ) 162 | parser.add_argument( 163 | "--test_prompt", 164 | type=str, 165 | default=None, 166 | help="A prompt that is used during validation to verify that the model keeps the class prior.", 167 | ) 168 | parser.add_argument( 169 | "--num_validation_images", 170 | type=int, 171 | default=8, 172 | help="Number of images that should be generated during validation with `validation_prompt`.", 173 | ) 174 | parser.add_argument( 175 | "--validation_epochs", 176 | type=int, 177 | default=50, 178 | help=( 179 | "Run dreambooth validation every X epochs. Dreambooth validation consists of running the prompt" 180 | " `args.validation_prompt` multiple times: `args.num_validation_images`."
181 | ), 182 | ) 183 | parser.add_argument( 184 | "--with_prior_preservation", 185 | default=False, 186 | action="store_true", 187 | help="Flag to add prior preservation loss.", 188 | ) 189 | parser.add_argument("--prior_loss_weight", type=float, default=1.0, help="The weight of prior preservation loss.") 190 | parser.add_argument( 191 | "--num_class_images", 192 | type=int, 193 | default=100, 194 | help=( 195 | "Minimal class images for prior preservation loss. If there are not enough images already present in" 196 | " class_data_dir, additional images will be sampled with class_prompt." 197 | ), 198 | ) 199 | parser.add_argument( 200 | "--output_dir", 201 | type=str, 202 | default="lora-dreambooth-model", 203 | help="The output directory where the model predictions and checkpoints will be written.", 204 | ) 205 | parser.add_argument("--seed", type=int, default=None, help="A seed for reproducible training.") 206 | parser.add_argument( 207 | "--resolution", 208 | type=int, 209 | default=512, 210 | help=( 211 | "The resolution for input images, all the images in the train/validation dataset will be resized to this" 212 | " resolution" 213 | ), 214 | ) 215 | parser.add_argument( 216 | "--center_crop", 217 | default=False, 218 | action="store_true", 219 | help=( 220 | "Whether to center crop the input images to the resolution. If not set, the images will be randomly" 221 | " cropped. The images will be resized to the resolution first before cropping." 222 | ), 223 | ) 224 | parser.add_argument( 225 | "--train_batch_size", type=int, default=4, help="Batch size (per device) for the training dataloader." 226 | ) 227 | parser.add_argument( 228 | "--sample_batch_size", type=int, default=4, help="Batch size (per device) for sampling images." 229 | ) 230 | parser.add_argument("--num_train_epochs", type=int, default=1) 231 | parser.add_argument( 232 | "--max_train_steps", 233 | type=int, 234 | default=None, 235 | help="Total number of training steps to perform. 
If provided, overrides num_train_epochs.", 236 | ) 237 | parser.add_argument( 238 | "--checkpointing_steps", 239 | type=int, 240 | default=500, 241 | help=( 242 | "Save a checkpoint of the training state every X updates. These checkpoints can be used both as final" 243 | " checkpoints in case they are better than the last checkpoint, and are also suitable for resuming" 244 | " training using `--resume_from_checkpoint`." 245 | ), 246 | ) 247 | parser.add_argument( 248 | "--checkpoints_total_limit", 249 | type=int, 250 | default=None, 251 | help=( 252 | "Max number of checkpoints to store. Passed as `total_limit` to the `Accelerator` `ProjectConfiguration`." 253 | " See Accelerator::save_state https://huggingface.co/docs/accelerate/package_reference/accelerator#accelerate.Accelerator.save_state" 254 | " for more docs" 255 | ), 256 | ) 257 | parser.add_argument( 258 | "--resume_from_checkpoint", 259 | type=str, 260 | default=None, 261 | help=( 262 | "Whether training should be resumed from a previous checkpoint. Use a path saved by" 263 | ' `--checkpointing_steps`, or `"latest"` to automatically select the last available checkpoint.' 
264 | ), 265 | ) 266 | parser.add_argument( 267 | "--gradient_accumulation_steps", 268 | type=int, 269 | default=1, 270 | help="Number of update steps to accumulate before performing a backward/update pass.", 271 | ) 272 | parser.add_argument( 273 | "--gradient_checkpointing", 274 | action="store_true", 275 | help="Whether or not to use gradient checkpointing to save memory at the expense of a slower backward pass.", 276 | ) 277 | parser.add_argument( 278 | "--learning_rate", 279 | type=float, 280 | default=5e-4, 281 | help="Initial learning rate (after the potential warmup period) to use.", 282 | ) 283 | parser.add_argument( 284 | "--scale_lr", 285 | action="store_true", 286 | default=False, 287 | help="Scale the learning rate by the number of GPUs, gradient accumulation steps, and batch size.", 288 | ) 289 | parser.add_argument( 290 | "--lr_scheduler", 291 | type=str, 292 | default="constant", 293 | help=( 294 | 'The scheduler type to use. Choose between ["linear", "cosine", "cosine_with_restarts", "polynomial",' 295 | ' "constant", "constant_with_warmup"]' 296 | ), 297 | ) 298 | parser.add_argument( 299 | "--lr_warmup_steps", type=int, default=500, help="Number of steps for the warmup in the lr scheduler." 300 | ) 301 | parser.add_argument( 302 | "--lr_num_cycles", 303 | type=int, 304 | default=1, 305 | help="Number of hard resets of the lr in cosine_with_restarts scheduler.", 306 | ) 307 | parser.add_argument("--lr_power", type=float, default=1.0, help="Power factor of the polynomial scheduler.") 308 | parser.add_argument( 309 | "--dataloader_num_workers", 310 | type=int, 311 | default=0, 312 | help=( 313 | "Number of subprocesses to use for data loading. 0 means that the data will be loaded in the main process." 314 | ), 315 | ) 316 | parser.add_argument( 317 | "--use_8bit_adam", action="store_true", help="Whether or not to use 8-bit Adam from bitsandbytes."
318 | ) 319 | parser.add_argument("--adam_beta1", type=float, default=0.9, help="The beta1 parameter for the Adam optimizer.") 320 | parser.add_argument("--adam_beta2", type=float, default=0.999, help="The beta2 parameter for the Adam optimizer.") 321 | parser.add_argument("--adam_weight_decay", type=float, default=1e-2, help="Weight decay to use.") 322 | parser.add_argument("--adam_epsilon", type=float, default=1e-08, help="Epsilon value for the Adam optimizer") 323 | parser.add_argument("--max_grad_norm", default=1.0, type=float, help="Max gradient norm.") 324 | parser.add_argument("--push_to_hub", action="store_true", help="Whether or not to push the model to the Hub.") 325 | parser.add_argument("--hub_token", type=str, default=None, help="The token to use to push to the Model Hub.") 326 | parser.add_argument( 327 | "--hub_model_id", 328 | type=str, 329 | default=None, 330 | help="The name of the repository to keep in sync with the local `output_dir`.", 331 | ) 332 | parser.add_argument( 333 | "--logging_dir", 334 | type=str, 335 | default="logs", 336 | help=( 337 | "[TensorBoard](https://www.tensorflow.org/tensorboard) log directory. Will default to" 338 | " *output_dir/runs/**CURRENT_DATETIME_HOSTNAME***." 339 | ), 340 | ) 341 | parser.add_argument( 342 | "--allow_tf32", 343 | action="store_true", 344 | help=( 345 | "Whether or not to allow TF32 on Ampere GPUs. Can be used to speed up training. For more information, see" 346 | " https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices" 347 | ), 348 | ) 349 | parser.add_argument( 350 | "--report_to", 351 | type=str, 352 | default="tensorboard", 353 | help=( 354 | 'The integration to report the results and logs to. Supported platforms are `"tensorboard"`' 355 | ' (default), `"wandb"` and `"comet_ml"`. Use `"all"` to report to all integrations.' 
356 | ), 357 | ) 358 | parser.add_argument( 359 | "--mixed_precision", 360 | type=str, 361 | default=None, 362 | choices=["no", "fp16", "bf16"], 363 | help=( 364 | "Whether to use mixed precision. Choose between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >=" 365 | " 1.10 and an Nvidia Ampere GPU. Defaults to the value of accelerate config of the current system or the" 366 | " flag passed with the `accelerate launch` command. Use this argument to override the accelerate config." 367 | ), 368 | ) 369 | parser.add_argument( 370 | "--prior_generation_precision", 371 | type=str, 372 | default=None, 373 | choices=["no", "fp32", "fp16", "bf16"], 374 | help=( 375 | "Choose prior generation precision between fp32, fp16 and bf16 (bfloat16). Bf16 requires PyTorch >=" 376 | " 1.10 and an Nvidia Ampere GPU. Defaults to fp16 if a GPU is available else fp32." 377 | ), 378 | ) 379 | parser.add_argument("--local_rank", type=int, default=-1, help="For distributed training: local_rank") 380 | parser.add_argument( 381 | "--enable_xformers_memory_efficient_attention", action="store_true", help="Whether or not to use xformers."
382 | ) 383 | parser.add_argument( 384 | "--name", 385 | type=str, 386 | help=( 387 | "The name of the current experiment run (also used as the wandb run name), in the form [data]-[prompt]" 388 | ), 389 | ) 390 | 391 | 392 | 393 | 394 | if input_args is not None: 395 | args = parser.parse_args(input_args) 396 | else: 397 | args = parser.parse_args() 398 | 399 | env_local_rank = int(os.environ.get("LOCAL_RANK", -1)) 400 | if env_local_rank != -1 and env_local_rank != args.local_rank: 401 | args.local_rank = env_local_rank 402 | 403 | if args.with_prior_preservation: 404 | if args.class_data_dir is None: 405 | raise ValueError("You must specify a data directory for class images.") 406 | if args.class_prompt is None: 407 | raise ValueError("You must specify a prompt for class images.") 408 | else: 409 | # logger is not available yet 410 | if args.class_data_dir is not None: 411 | warnings.warn("You need not use --class_data_dir without --with_prior_preservation.") 412 | if args.class_prompt is not None: 413 | warnings.warn("You need not use --class_prompt without --with_prior_preservation.") 414 | 415 | return args 416 | 417 | 418 | class DreamBoothDataset(Dataset): 419 | """ 420 | A dataset to prepare the instance and class images with the prompts for fine-tuning the model. 421 | It pre-processes the images and tokenizes the prompts.
422 | """ 423 | 424 | def __init__( 425 | self, 426 | instance_data_root, 427 | instance_prompt, 428 | tokenizer, 429 | class_data_root=None, 430 | class_prompt=None, 431 | class_num=None, 432 | size=512, 433 | center_crop=False, 434 | ): 435 | self.size = size 436 | self.center_crop = center_crop 437 | self.tokenizer = tokenizer 438 | 439 | self.instance_data_root = Path(instance_data_root) 440 | if not self.instance_data_root.exists(): 441 | raise ValueError("Instance images root doesn't exist.") 442 | 443 | self.instance_images_path = list(Path(instance_data_root).iterdir()) 444 | self.num_instance_images = len(self.instance_images_path) 445 | self.instance_prompt = instance_prompt 446 | self._length = self.num_instance_images 447 | 448 | if class_data_root is not None: 449 | self.class_data_root = Path(class_data_root) 450 | self.class_data_root.mkdir(parents=True, exist_ok=True) 451 | self.class_images_path = list(self.class_data_root.iterdir()) 452 | if class_num is not None: 453 | self.num_class_images = min(len(self.class_images_path), class_num) 454 | else: 455 | self.num_class_images = len(self.class_images_path) 456 | self._length = max(self.num_class_images, self.num_instance_images) 457 | self.class_prompt = class_prompt 458 | else: 459 | self.class_data_root = None 460 | 461 | self.image_transforms = transforms.Compose( 462 | [ 463 | transforms.Resize(size, interpolation=transforms.InterpolationMode.BILINEAR), 464 | transforms.CenterCrop(size) if center_crop else transforms.RandomCrop(size), 465 | transforms.ToTensor(), 466 | transforms.Normalize([0.5], [0.5]), 467 | ] 468 | ) 469 | 470 | def __len__(self): 471 | return self._length 472 | 473 | def __getitem__(self, index): 474 | example = {} 475 | instance_image = Image.open(self.instance_images_path[index % self.num_instance_images]) 476 | if not instance_image.mode == "RGB": 477 | instance_image = instance_image.convert("RGB") 478 | example["instance_images"] =
self.image_transforms(instance_image) 479 | example["instance_prompt_ids"] = self.tokenizer( 480 | self.instance_prompt, 481 | truncation=True, 482 | padding="max_length", 483 | max_length=self.tokenizer.model_max_length, 484 | return_tensors="pt", 485 | ).input_ids 486 | 487 | if self.class_data_root: 488 | class_image = Image.open(self.class_images_path[index % self.num_class_images]) 489 | if not class_image.mode == "RGB": 490 | class_image = class_image.convert("RGB") 491 | example["class_images"] = self.image_transforms(class_image) 492 | example["class_prompt_ids"] = self.tokenizer( 493 | self.class_prompt, 494 | truncation=True, 495 | padding="max_length", 496 | max_length=self.tokenizer.model_max_length, 497 | return_tensors="pt", 498 | ).input_ids 499 | 500 | return example 501 | 502 | 503 | def collate_fn(examples, with_prior_preservation=False): 504 | input_ids = [example["instance_prompt_ids"] for example in examples] 505 | pixel_values = [example["instance_images"] for example in examples] 506 | 507 | # Concat class and instance examples for prior preservation. 508 | # We do this to avoid doing two forward passes. 509 | if with_prior_preservation: 510 | input_ids += [example["class_prompt_ids"] for example in examples] 511 | pixel_values += [example["class_images"] for example in examples] 512 | 513 | pixel_values = torch.stack(pixel_values) 514 | pixel_values = pixel_values.to(memory_format=torch.contiguous_format).float() 515 | 516 | input_ids = torch.cat(input_ids, dim=0) 517 | 518 | batch = { 519 | "input_ids": input_ids, 520 | "pixel_values": pixel_values, 521 | } 522 | return batch 523 | 524 | 525 | class PromptDataset(Dataset): 526 | "A simple dataset to prepare the prompts to generate class images on multiple GPUs." 
527 | 528 | def __init__(self, prompt, num_samples): 529 | self.prompt = prompt 530 | self.num_samples = num_samples 531 | 532 | def __len__(self): 533 | return self.num_samples 534 | 535 | def __getitem__(self, index): 536 | example = {} 537 | example["prompt"] = self.prompt 538 | example["index"] = index 539 | return example 540 | 541 | 542 | def main(args): 543 | logging_dir = Path(args.output_dir, args.logging_dir) 544 | 545 | accelerator_project_config = ProjectConfiguration(project_dir=args.output_dir, logging_dir=logging_dir) # total_limit=args.checkpoints_total_limit) 546 | 547 | wandb_init = { 548 | "wandb": { 549 | "name": args.name, 550 | # "project": args.project, 551 | } 552 | } 553 | 554 | accelerator = Accelerator( 555 | gradient_accumulation_steps=args.gradient_accumulation_steps, 556 | mixed_precision=args.mixed_precision, 557 | log_with=args.report_to, 558 | project_config=accelerator_project_config, 559 | ) 560 | 561 | if args.report_to == "wandb": 562 | if not is_wandb_available(): 563 | raise ImportError("Make sure to install wandb if you want to use it for logging during training.") 564 | import wandb 565 | 566 | # Currently, it's not possible to do gradient accumulation when training two models with accelerate.accumulate 567 | # This will be enabled soon in accelerate. For now, we don't allow gradient accumulation when training two models. 568 | # TODO (patil-suraj): Remove this check when gradient accumulation with two models is enabled in accelerate. 569 | # Make one log on every process with the configuration for debugging. 
570 | logging.basicConfig( 571 | format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", 572 | datefmt="%m/%d/%Y %H:%M:%S", 573 | level=logging.INFO, 574 | ) 575 | logger.info(accelerator.state, main_process_only=False) 576 | if accelerator.is_local_main_process: 577 | transformers.utils.logging.set_verbosity_warning() 578 | diffusers.utils.logging.set_verbosity_info() 579 | else: 580 | transformers.utils.logging.set_verbosity_error() 581 | diffusers.utils.logging.set_verbosity_error() 582 | 583 | # If passed along, set the training seed now. 584 | if args.seed is not None: 585 | set_seed(args.seed) 586 | 587 | # Generate class images if prior preservation is enabled. 588 | if args.with_prior_preservation: 589 | class_images_dir = Path(args.class_data_dir) 590 | if not class_images_dir.exists(): 591 | class_images_dir.mkdir(parents=True) 592 | cur_class_images = len(list(class_images_dir.iterdir())) 593 | 594 | if cur_class_images < args.num_class_images: 595 | torch_dtype = torch.float16 if accelerator.device.type == "cuda" else torch.float32 596 | if args.prior_generation_precision == "fp32": 597 | torch_dtype = torch.float32 598 | elif args.prior_generation_precision == "fp16": 599 | torch_dtype = torch.float16 600 | elif args.prior_generation_precision == "bf16": 601 | torch_dtype = torch.bfloat16 602 | pipeline = DiffusionPipeline.from_pretrained( 603 | args.pretrained_model_name_or_path, 604 | torch_dtype=torch_dtype, 605 | safety_checker=None, 606 | revision=args.revision, 607 | ) 608 | pipeline.set_progress_bar_config(disable=True) 609 | 610 | num_new_images = args.num_class_images - cur_class_images 611 | logger.info(f"Number of class images to sample: {num_new_images}.") 612 | 613 | sample_dataset = PromptDataset(args.class_prompt, num_new_images) 614 | sample_dataloader = torch.utils.data.DataLoader(sample_dataset, batch_size=args.sample_batch_size) 615 | 616 | sample_dataloader = accelerator.prepare(sample_dataloader) 617 | 
pipeline.to(accelerator.device) 618 | 619 | for example in tqdm( 620 | sample_dataloader, desc="Generating class images", disable=not accelerator.is_local_main_process 621 | ): 622 | images = pipeline(example["prompt"]).images 623 | 624 | for i, image in enumerate(images): 625 | hash_image = hashlib.sha1(image.tobytes()).hexdigest() 626 | image_filename = class_images_dir / f"{example['index'][i] + cur_class_images}-{hash_image}.jpg" 627 | image.save(image_filename) 628 | 629 | del pipeline 630 | if torch.cuda.is_available(): 631 | torch.cuda.empty_cache() 632 | 633 | # Handle the repository creation 634 | if accelerator.is_main_process: 635 | if args.output_dir is not None: 636 | os.makedirs(args.output_dir, exist_ok=True) 637 | 638 | if args.push_to_hub: 639 | repo_id = create_repo( 640 | repo_id=args.hub_model_id or Path(args.output_dir).name, exist_ok=True, token=args.hub_token 641 | ).repo_id 642 | 643 | # Load the tokenizer 644 | if args.tokenizer_name: 645 | tokenizer = AutoTokenizer.from_pretrained(args.tokenizer_name, revision=args.revision, use_fast=False) 646 | elif args.pretrained_model_name_or_path: 647 | tokenizer = AutoTokenizer.from_pretrained( 648 | args.pretrained_model_name_or_path, 649 | subfolder="tokenizer", 650 | revision=args.revision, 651 | use_fast=False, 652 | ) 653 | 654 | # import correct text encoder class 655 | text_encoder_cls = import_model_class_from_model_name_or_path(args.pretrained_model_name_or_path, args.revision) 656 | 657 | # Load scheduler and models 658 | noise_scheduler = DDPMScheduler.from_pretrained(args.pretrained_model_name_or_path, subfolder="scheduler") 659 | text_encoder = text_encoder_cls.from_pretrained( 660 | args.pretrained_model_name_or_path, subfolder="text_encoder", revision=args.revision 661 | ) 662 | vae = AutoencoderKL.from_pretrained(args.pretrained_model_name_or_path, subfolder="vae", revision=args.revision) 663 | unet = UNet2DConditionModel.from_pretrained( 664 | args.pretrained_model_name_or_path, 
subfolder="unet", revision=args.revision 665 | ) 666 | 667 | # Freeze the pre-trained VAE, text encoder and UNet; only the IntLoRA adapter parameters are trained 668 | vae.requires_grad_(False) 669 | text_encoder.requires_grad_(False) 670 | unet.requires_grad_(False) 671 | 672 | # For mixed precision training we cast the text_encoder and vae weights to half-precision 673 | # as these models are only used for inference, keeping weights in full precision is not required. 674 | weight_dtype = torch.float32 675 | if accelerator.mixed_precision == "fp16": 676 | weight_dtype = torch.float16 677 | elif accelerator.mixed_precision == "bf16": 678 | weight_dtype = torch.bfloat16 679 | 680 | # Move unet, vae and text_encoder to device and cast to weight_dtype 681 | unet.to(accelerator.device, dtype=weight_dtype) 682 | vae.to(accelerator.device, dtype=weight_dtype) 683 | text_encoder.to(accelerator.device, dtype=weight_dtype) 684 | 685 | if args.enable_xformers_memory_efficient_attention: 686 | if is_xformers_available(): 687 | import xformers 688 | 689 | xformers_version = version.parse(xformers.__version__) 690 | if xformers_version == version.parse("0.0.16"): 691 | logger.warning( 692 | "xFormers 0.0.16 cannot be used for training on some GPUs. If you observe problems during training, please update xFormers to at least 0.0.17. See https://huggingface.co/docs/diffusers/main/en/optimization/xformers for more details." 693 | ) 694 | unet.enable_xformers_memory_efficient_attention() 695 | else: 696 | raise ValueError("xformers is not available.
Make sure it is installed correctly") 697 | 698 | 699 | # --------------------------------------------------------------------------------------------------------------- 700 | unet = QuantUnetWarp(unet,args) 701 | opt_params = [] 702 | for name, module in unet.named_modules(): 703 | if isinstance(module, (IntLoRA_MUL,IntLoRA_SHIFT)): 704 | opt_params += list(module.loraA.parameters()) + list(module.loraB.parameters()) 705 | 706 | # --------------------------------------------------------------------------------------------------------------- 707 | # Enable TF32 for faster training on Ampere GPUs, 708 | # cf https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices 709 | if args.allow_tf32: 710 | torch.backends.cuda.matmul.allow_tf32 = True 711 | 712 | if args.scale_lr: 713 | args.learning_rate = ( 714 | args.learning_rate * args.gradient_accumulation_steps * args.train_batch_size * accelerator.num_processes 715 | ) 716 | 717 | # Use 8-bit Adam for lower memory usage or to fine-tune the model in 16GB GPUs 718 | if args.use_8bit_adam: 719 | try: 720 | import bitsandbytes as bnb 721 | except ImportError: 722 | raise ImportError( 723 | "To use 8-bit Adam, please install the bitsandbytes library: `pip install bitsandbytes`." 
724 | ) 725 | 726 | optimizer_class = bnb.optim.AdamW8bit 727 | else: 728 | optimizer_class = torch.optim.AdamW 729 | 730 | # Optimizer creation 731 | optimizer = optimizer_class( 732 | opt_params, # quant_layers.parameters(), #opt_params, 733 | lr=args.learning_rate, # 6e-5 734 | betas=(args.adam_beta1, args.adam_beta2), 735 | weight_decay=args.adam_weight_decay, # 0.01 736 | eps=args.adam_epsilon, # 1e-8 737 | ) 738 | 739 | # Dataset and DataLoaders creation: 740 | train_dataset = DreamBoothDataset( 741 | instance_data_root=args.instance_data_dir, 742 | instance_prompt=args.instance_prompt, 743 | class_data_root=args.class_data_dir if args.with_prior_preservation else None, 744 | class_prompt=args.class_prompt, 745 | class_num=args.num_class_images, 746 | tokenizer=tokenizer, 747 | size=args.resolution, 748 | center_crop=args.center_crop, 749 | ) 750 | 751 | train_dataloader = torch.utils.data.DataLoader( 752 | train_dataset, 753 | batch_size=args.train_batch_size, 754 | shuffle=True, 755 | collate_fn=lambda examples: collate_fn(examples, args.with_prior_preservation), 756 | num_workers=args.dataloader_num_workers, 757 | ) 758 | 759 | # Scheduler and math around the number of training steps. 760 | overrode_max_train_steps = False 761 | num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps) 762 | if args.max_train_steps is None: 763 | args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch 764 | overrode_max_train_steps = True 765 | 766 | lr_scheduler = get_scheduler( 767 | args.lr_scheduler, 768 | optimizer=optimizer, 769 | num_warmup_steps=args.lr_warmup_steps * args.gradient_accumulation_steps, 770 | num_training_steps=args.max_train_steps * args.gradient_accumulation_steps, 771 | num_cycles=args.lr_num_cycles, 772 | power=args.lr_power, 773 | ) 774 | 775 | # Prepare everything with our `accelerator`. 
776 | unet, optimizer, train_dataloader, lr_scheduler = accelerator.prepare( 777 | unet, optimizer, train_dataloader, lr_scheduler 778 | ) 779 | 780 | # We need to recalculate our total training steps as the size of the training dataloader may have changed. 781 | num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps) 782 | if overrode_max_train_steps: 783 | args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch 784 | # Afterwards we recalculate our number of training epochs 785 | args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch) 786 | 787 | # We need to initialize the trackers we use, and also store our configuration. 788 | # The trackers initialize automatically on the main process. 789 | if accelerator.is_main_process: 790 | accelerator.init_trackers("dreambooth-quant", config=vars(args), init_kwargs=wandb_init) 791 | 792 | # Train! 793 | total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps 794 | 795 | logger.info("***** Running training *****") 796 | logger.info(f" Num examples = {len(train_dataset)}") 797 | logger.info(f" Num batches each epoch = {len(train_dataloader)}") 798 | logger.info(f" Num Epochs = {args.num_train_epochs}") 799 | logger.info(f" Instantaneous batch size per device = {args.train_batch_size}") 800 | logger.info(f" Total train batch size (w.
parallel, distributed & accumulation) = {total_batch_size}") 801 | logger.info(f" Gradient Accumulation steps = {args.gradient_accumulation_steps}") 802 | logger.info(f" Total optimization steps = {args.max_train_steps}") 803 | global_step = 0 804 | first_epoch = 0 805 | 806 | # Potentially load in the weights and states from a previous save 807 | if args.resume_from_checkpoint: 808 | if args.resume_from_checkpoint != "latest": 809 | path = os.path.basename(args.resume_from_checkpoint) 810 | else: 811 | # Get the mos recent checkpoint 812 | dirs = os.listdir(args.output_dir) 813 | dirs = [d for d in dirs if d.startswith("checkpoint")] 814 | dirs = sorted(dirs, key=lambda x: int(x.split("-")[1])) 815 | path = dirs[-1] if len(dirs) > 0 else None 816 | 817 | if path is None: 818 | accelerator.print( 819 | f"Checkpoint '{args.resume_from_checkpoint}' does not exist. Starting a new training run." 820 | ) 821 | args.resume_from_checkpoint = None 822 | else: 823 | accelerator.print(f"Resuming from checkpoint {path}") 824 | accelerator.load_state(os.path.join(args.output_dir, path)) 825 | global_step = int(path.split("-")[1]) 826 | 827 | resume_global_step = global_step * args.gradient_accumulation_steps 828 | first_epoch = global_step // num_update_steps_per_epoch 829 | resume_step = resume_global_step % (num_update_steps_per_epoch * args.gradient_accumulation_steps) 830 | 831 | # Only show the progress bar once on each machine. 
    progress_bar = tqdm(range(global_step, args.max_train_steps), disable=not accelerator.is_local_main_process)
    progress_bar.set_description("Steps")


    for epoch in range(first_epoch, args.num_train_epochs):
        unet.train()

        for step, batch in enumerate(train_dataloader):
            # Skip steps until we reach the resumed step
            if args.resume_from_checkpoint and epoch == first_epoch and step < resume_step:
                if step % args.gradient_accumulation_steps == 0:
                    progress_bar.update(1)
                continue

            with accelerator.accumulate(unet):
                # Convert images to latent space
                latents = vae.encode(batch["pixel_values"].to(dtype=weight_dtype)).latent_dist.sample()
                latents = latents * vae.config.scaling_factor

                # Sample noise that we'll add to the latents
                noise = torch.randn_like(latents)
                bsz = latents.shape[0]
                # Sample a random timestep for each image
                timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, (bsz,), device=latents.device)
                timesteps = timesteps.long()

                # Add noise to the latents according to the noise magnitude at each timestep
                # (this is the forward diffusion process)
                noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)

                # Get the text embedding for conditioning
                encoder_hidden_states = text_encoder(batch["input_ids"])[0]

                # Predict the noise residual
                model_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample

                # Get the target for loss depending on the prediction type
                if noise_scheduler.config.prediction_type == "epsilon":
                    target = noise
                elif noise_scheduler.config.prediction_type == "v_prediction":
                    target = noise_scheduler.get_velocity(latents, noise, timesteps)
                else:
                    raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}")

                if args.with_prior_preservation:
                    # Chunk the noise and model_pred into two parts and compute the loss on each part separately.
                    model_pred, model_pred_prior = torch.chunk(model_pred, 2, dim=0)
                    target, target_prior = torch.chunk(target, 2, dim=0)

                    # Compute instance loss
                    loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean")

                    # Compute prior loss
                    prior_loss = F.mse_loss(model_pred_prior.float(), target_prior.float(), reduction="mean")

                    # Add the prior loss to the instance loss.
                    loss = loss + args.prior_loss_weight * prior_loss
                else:
                    loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean")

                accelerator.backward(loss)
                if accelerator.sync_gradients:
                    params_to_clip = opt_params # quant_layers.parameters() # opt_params
                    accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm)
                optimizer.step()
                lr_scheduler.step()
                optimizer.zero_grad()

            # Checks if the accelerator has performed an optimization step behind the scenes
            if accelerator.sync_gradients:
                progress_bar.update(1)
                global_step += 1

                if global_step % args.checkpointing_steps == 0:
                    if accelerator.is_main_process:
                        save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}")
                        accelerator.save_state(save_path)
                        logger.info(f"Saved state to {save_path}")

            logs = {"loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0]}
            progress_bar.set_postfix(**logs)
            accelerator.log(logs, step=global_step)

            if global_step >= args.max_train_steps:
                break

        if accelerator.is_main_process:
            if args.validation_prompt is not None and epoch % args.validation_epochs == 0: # and epoch > 1:
                logger.info(
                    f"Running validation... \n Generating {args.num_validation_images} images with prompt:"
                    f" {args.validation_prompt}."
                )

                # create pipeline
                pipeline = DiffusionPipeline.from_pretrained(
                    args.pretrained_model_name_or_path,
                    unet= accelerator.unwrap_model(unet.model) if isinstance(unet,QuantUnetWarp) else accelerator.unwrap_model(unet),
                    text_encoder=accelerator.unwrap_model(text_encoder),
                    revision=args.revision,
                    torch_dtype=weight_dtype,
                )
                pipeline.scheduler = DPMSolverMultistepScheduler.from_config(pipeline.scheduler.config)
                pipeline = pipeline.to(accelerator.device)
                pipeline.set_progress_bar_config(disable=True)

                # run inference
                generator = torch.Generator(device=accelerator.device).manual_seed(args.seed)
                images = [
                    pipeline(args.validation_prompt, num_inference_steps=25, generator=generator).images[0]
                    for _ in range(args.num_validation_images)
                ]

                for tracker in accelerator.trackers:
                    if tracker.name == "tensorboard":
                        np_images = np.stack([np.asarray(img) for img in images])
                        tracker.writer.add_images("validation", np_images, epoch, dataformats="NHWC")
                    if tracker.name == "wandb":
                        tracker.log(
                            {
                                "validation": [
                                    wandb.Image(image, caption=f"{i}: {args.validation_prompt}")
                                    for i, image in enumerate(images)
                                ]
                            }
                        )

                # Create the output directory if it doesn't exist
                tmp_dir = os.path.join(args.output_dir, str(epoch))
                if not os.path.exists(tmp_dir):
                    os.makedirs(tmp_dir)

                for i, image in enumerate(images):
                    np_image = np.array(image)
                    pil_image = Image.fromarray(np_image)
                    pil_image.save(os.path.join(args.output_dir, str(epoch), f"image_{i}.png"))

                del pipeline
                torch.cuda.empty_cache()


    # Save the oft layers
    accelerator.wait_for_everyone()
    if accelerator.is_main_process:
        unet = unet.to(torch.float32)
        unet.save_attn_procs(args.output_dir)

    # Final inference
    # Load previous pipeline
    pipeline = DiffusionPipeline.from_pretrained(
        args.pretrained_model_name_or_path, revision=args.revision, torch_dtype=weight_dtype
    )
    pipeline.scheduler = DPMSolverMultistepScheduler.from_config(pipeline.scheduler.config)
    pipeline = pipeline.to(accelerator.device)

    # load attention processors
    pipeline.unet.load_attn_procs(args.output_dir)

    # run inference
    if args.validation_prompt and args.num_validation_images > 0:
        generator = torch.Generator(device=accelerator.device).manual_seed(args.seed) if args.seed else None
        images = [
            pipeline(args.validation_prompt, num_inference_steps=25, generator=generator).images[0]
            for _ in range(args.num_validation_images)
        ]

        for tracker in accelerator.trackers:
            if tracker.name == "tensorboard":
                np_images = np.stack([np.asarray(img) for img in images])
                tracker.writer.add_images("test", np_images, epoch, dataformats="NHWC")
            if tracker.name == "wandb":
                tracker.log(
                    {
                        "test": [
                            wandb.Image(image, caption=f"{i}: {args.validation_prompt}")
                            for i, image in enumerate(images)
                        ]
                    }
                )

    accelerator.end_training()


if __name__ == "__main__":
    args = parse_args()
    main(args)

--------------------------------------------------------------------------------
/train_dreambooth_quant.sh:
--------------------------------------------------------------------------------
export MODEL_NAME="/data/guohang/pretrained/stable-diffusion-v1-5"

for idx in {0..749}
do
    prompt_idx=$((idx % 25))
    class_idx=$((idx / 25))


    # Define the unique_token, class_tokens, and subject_names
    unique_token="qwe"
    subject_names=(
        "backpack" "backpack_dog" "bear_plushie" "berry_bowl" "can"
        "candle" "cat" "cat2" "clock" "colorful_sneaker"
        "dog" "dog2" "dog3" "dog5" "dog6"
        "dog7" "dog8" "duck_toy" "fancy_boot" "grey_sloth_plushie"
        "monster_toy" "pink_sunglasses" "poop_emoji" "rc_car" "red_cartoon"
        "robot_toy" "shiny_sneaker" "teapot" "vase" "wolf_plushie"
    )

    class_tokens=(
        "backpack" "backpack" "stuffed animal" "bowl" "can"
        "candle" "cat" "cat" "clock" "sneaker"
        "dog" "dog" "dog" "dog" "dog"
        "dog" "dog" "toy" "boot" "stuffed animal"
        "toy" "glasses" "toy" "toy" "cartoon"
        "toy" "sneaker" "teapot" "vase" "stuffed animal"
    )

    class_token=${class_tokens[$class_idx]}
    selected_subject=${subject_names[$class_idx]}

    if [[ $class_idx =~ ^(0|1|2|3|4|5|8|9|17|18|19|20|21|22|23|24|25|26|27|28|29)$ ]]; then
        prompt_list=(
            "a ${unique_token} ${class_token} in the jungle"
            "a ${unique_token} ${class_token} in the snow"
            "a ${unique_token} ${class_token} on the beach"
            "a ${unique_token} ${class_token} on a cobblestone street"
            "a ${unique_token} ${class_token} on top of pink fabric"
            "a ${unique_token} ${class_token} on top of a wooden floor"
            "a ${unique_token} ${class_token} with a city in the background"
            "a ${unique_token} ${class_token} with a mountain in the background"
            "a ${unique_token} ${class_token} with a blue house in the background"
            "a ${unique_token} ${class_token} on top of a purple rug in a forest"
            "a ${unique_token} ${class_token} with a wheat field in the background"
            "a ${unique_token} ${class_token} with a tree and autumn leaves in the background"
            "a ${unique_token} ${class_token} with the Eiffel Tower in the background"
            "a ${unique_token} ${class_token} floating on top of water"
            "a ${unique_token} ${class_token} floating in an ocean of milk"
            "a ${unique_token} ${class_token} on top of green grass with sunflowers around it"
            "a ${unique_token} ${class_token} on top of a mirror"
            "a ${unique_token} ${class_token} on top of the sidewalk in a crowded street"
            "a ${unique_token} ${class_token} on top of a dirt road"
            "a ${unique_token} ${class_token} on top of a white rug"
            "a red ${unique_token} ${class_token}"
            "a purple ${unique_token} ${class_token}"
            "a shiny ${unique_token} ${class_token}"
            "a wet ${unique_token} ${class_token}"
            "a cube shaped ${unique_token} ${class_token}"
        )

        prompt_test_list=(
            "a ${class_token} in the jungle"
            "a ${class_token} in the snow"
            "a ${class_token} on the beach"
            "a ${class_token} on a cobblestone street"
            "a ${class_token} on top of pink fabric"
            "a ${class_token} on top of a wooden floor"
            "a ${class_token} with a city in the background"
            "a ${class_token} with a mountain in the background"
            "a ${class_token} with a blue house in the background"
            "a ${class_token} on top of a purple rug in a forest"
            "a ${class_token} with a wheat field in the background"
            "a ${class_token} with a tree and autumn leaves in the background"
            "a ${class_token} with the Eiffel Tower in the background"
            "a ${class_token} floating on top of water"
            "a ${class_token} floating in an ocean of milk"
            "a ${class_token} on top of green grass with sunflowers around it"
            "a ${class_token} on top of a mirror"
            "a ${class_token} on top of the sidewalk in a crowded street"
            "a ${class_token} on top of a dirt road"
            "a ${class_token} on top of a white rug"
            "a red ${class_token}"
            "a purple ${class_token}"
            "a shiny ${class_token}"
            "a wet ${class_token}"
            "a cube shaped ${class_token}"
        )

    else
        prompt_list=(
            "a ${unique_token} ${class_token} in the jungle"
            "a ${unique_token} ${class_token} in the snow"
            "a ${unique_token} ${class_token} on the beach"
            "a ${unique_token} ${class_token} on a cobblestone street"
            "a ${unique_token} ${class_token} on top of pink fabric"
            "a ${unique_token} ${class_token} on top of a wooden floor"
            "a ${unique_token} ${class_token} with a city in the background"
            "a ${unique_token} ${class_token} with a mountain in the background"
            "a ${unique_token} ${class_token} with a blue house in the background"
            "a ${unique_token} ${class_token} on top of a purple rug in a forest"
            "a ${unique_token} ${class_token} wearing a red hat"
            "a ${unique_token} ${class_token} wearing a santa hat"
            "a ${unique_token} ${class_token} wearing a rainbow scarf"
            "a ${unique_token} ${class_token} wearing a black top hat and a monocle"
            "a ${unique_token} ${class_token} in a chef outfit"
            "a ${unique_token} ${class_token} in a firefighter outfit"
            "a ${unique_token} ${class_token} in a police outfit"
            "a ${unique_token} ${class_token} wearing pink glasses"
            "a ${unique_token} ${class_token} wearing a yellow shirt"
            "a ${unique_token} ${class_token} in a purple wizard outfit"
            "a red ${unique_token} ${class_token}"
            "a purple ${unique_token} ${class_token}"
            "a shiny ${unique_token} ${class_token}"
            "a wet ${unique_token} ${class_token}"
            "a cube shaped ${unique_token} ${class_token}"
        )

        prompt_test_list=(
            "a ${class_token} in the jungle"
            "a ${class_token} in the snow"
            "a ${class_token} on the beach"
            "a ${class_token} on a cobblestone street"
            "a ${class_token} on top of pink fabric"
            "a ${class_token} on top of a wooden floor"
            "a ${class_token} with a city in the background"
            "a ${class_token} with a mountain in the background"
            "a ${class_token} with a blue house in the background"
            "a ${class_token} on top of a purple rug in a forest"
            "a ${class_token} wearing a red hat"
            "a ${class_token} wearing a santa hat"
            "a ${class_token} wearing a rainbow scarf"
            "a ${class_token} wearing a black top hat and a monocle"
            "a ${class_token} in a chef outfit"
            "a ${class_token} in a firefighter outfit"
            "a ${class_token} in a police outfit"
            "a ${class_token} wearing pink glasses"
            "a ${class_token} wearing a yellow shirt"
            "a ${class_token} in a purple wizard outfit"
            "a red ${class_token}"
            "a purple ${class_token}"
            "a shiny ${class_token}"
            "a wet ${class_token}"
            "a cube shaped ${class_token}"
        )
    fi


    validation_prompt=${prompt_list[$prompt_idx]}
    test_prompt=${prompt_test_list[$prompt_idx]}
    name="${selected_subject}-${prompt_idx}"
    instance_prompt="a photo of ${unique_token} ${class_token}"
    class_prompt="a photo of ${class_token}"

    export OUTPUT_DIR="log_quant/${name}"
    export INSTANCE_DIR="/data/guohang/dataset/dreambooth/${selected_subject}"
    export CLASS_DIR="/data/guohang/ControlNet/oft-db/data/class_data/${class_token}"



    python train_dreambooth_quant.py \
        --pretrained_model_name_or_path=$MODEL_NAME \
        --instance_data_dir=$INSTANCE_DIR \
        --rank=4 \
        --intlora=MUL \
        --nbits=8 \
        --use_activation_quant \
        --act_nbits=8 \
        --class_data_dir="$CLASS_DIR" \
        --output_dir=$OUTPUT_DIR \
        --instance_prompt="$instance_prompt" \
        --with_prior_preservation \
        --prior_loss_weight=1.0 \
        --class_prompt="$class_prompt" \
        --resolution=512 \
        --train_batch_size=1 \
        --gradient_accumulation_steps=1 \
        --checkpointing_steps=5000 \
        --learning_rate=6e-5 \
        --report_to="wandb" \
        --lr_scheduler="constant" \
        --lr_warmup_steps=0 \
        --max_train_steps=2000 \
        --validation_prompt="$validation_prompt" \
        --num_validation_images=5 \
        --validation_epochs=1 \
        --seed="0" \
        --name="$name" \
        --num_class_images=200 \
        #--gradient_checkpointing #one can uncomment the GC to reduce the GPU memory
done


--------------------------------------------------------------------------------
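The per-element merge that `IntLoRA_MUL.forward` in `utils/intlora_mul.py` performs can be checked on plain scalars. The sketch below is a hedged illustration, not the repository's implementation: it assumes a single per-channel scale `delta`, omits the zero point and clamping, and follows the `quant_lora_weights=False` branch, so the second quantization of `weight_updates` is skipped. The helpers `quantize_code` and `intlora_mul_merge` are hypothetical names. Since `(delta + (R + BA) / W_int) * W_int = delta * W_int + R + BA` and round-to-nearest keeps `delta * W_int` within `delta / 2` of `W - R`, the merged weight stays within `delta / 2` of `W + BA` wherever the INT code `W_int` is nonzero.

```python
# Hypothetical scalar sketch of the IntLoRA-MUL merge identity (not the repository code).
# Assumptions: one scale `delta` per output channel, no zero point, no clamping,
# and no second quantization of the weight update (quant_lora_weights=False path).


def quantize_code(x: float, delta: float) -> int:
    """Round-to-nearest integer code of x on a uniform grid with step delta."""
    return round(x / delta)


def intlora_mul_merge(w: float, r: float, ba: float, delta: float) -> float:
    """Merged weight for one element, mirroring the arithmetic of IntLoRA_MUL.forward."""
    w_int = quantize_code(w - r, delta)    # INT code of the shifted weight W - R
    divisor = w_int if w_int != 0 else 1   # zero codes are divided by 1, like the torch.where(...) guard
    lora_weight = (r + ba) / divisor       # auxiliary matrix plus LoRA product, scaled by 1 / W_int
    weight_update = delta + lora_weight    # the per-channel scale absorbs the adapter term
    return weight_update * w_int           # a single multiply by the INT code gives the merged weight


if __name__ == "__main__":
    delta = 0.01
    # (W, R, BA) triples for one output channel; the values are made up for illustration.
    row = [(0.30, 0.02, 0.004), (-0.12, -0.01, -0.002), (0.137, 0.0, 0.0), (0.041, 0.04, 0.001)]
    for w, r, ba in row:
        code = quantize_code(w - r, delta)
        merged = intlora_mul_merge(w, r, ba, delta)
        print(f"W={w:+.3f} code={code:+d} merged={merged:+.4f} target W+BA={w + ba:+.4f}")
```

For the last triple the INT code is 0, so the merged element is 0 regardless of `R + BA`; the original forward behaves the same way in both of its branches, because the result is always multiplied by `ori_weight_int`.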
/utils/intlora_mul.py:
--------------------------------------------------------------------------------
import torch
import torch.nn as nn
import torch.nn.functional as F
import math
from utils.quant_utils import batch_mse,batch_max,lp_loss,round_ste
from utils.quant_layer import UniformAffineQuantizer



class IntLoRA_MUL(nn.Module):
    # The implementation of our IntLoRA-MUL
    def __init__(self, org_module, n_bits=8, lora_bits=8, symmetric=False,channel_wise=True, rank=4, activation_params=None):
        super(IntLoRA_MUL, self).__init__()

        self.n_bits = n_bits
        self.lora_bits = lora_bits
        self.sym = symmetric
        self.scale_method = 'mse'
        self.always_zero = False
        self.n_levels = 2 ** self.n_bits if not self.sym else 2 ** (self.n_bits - 1) - 1
        self.channel_wise = channel_wise
        self.inited=False
        self.lora_levels = self.n_levels

        self.fwd_kwargs = dict()
        self.fwd_func = F.linear
        self.in_features = org_module.in_features
        self.out_features = org_module.out_features


        # save original weights and bias and keep them intact
        self.ori_weight_shape = org_module.weight.shape

        self.ori_weight = org_module.weight.view(self.out_features,-1).data.clone() # reshape here
        self.ori_bias = None if org_module.bias is None else org_module.bias.data.clone()

        self.act_quant_params = activation_params
        if self.act_quant_params is not None:
            self.act_quantizer = UniformAffineQuantizer(**self.act_quant_params)


        # quant lora quant here ===========================
        self.quant_lora_weights = True
        self.double_inited = False
        self.double_quant_delta = torch.nn.Parameter(torch.zeros(self.out_features,1))
        self.double_quant_zero_point = torch.nn.Parameter(torch.zeros(self.out_features,1))

        r = rank
        self.alpha = 1.5
        self.loraA = nn.Linear(org_module.in_features, r, bias=False)
        self.loraB = nn.Linear(r, org_module.out_features, bias=False)
        nn.init.kaiming_uniform_(self.loraA.weight, a=math.sqrt(5))
        nn.init.zeros_(self.loraB.weight)

        # init the auxiliary matrix R in IntLoRA
        m = torch.distributions.laplace.Laplace(loc=torch.tensor([0.]),scale=torch.tensor([0.5]))
        aux_R = m.sample((org_module.out_features,org_module.in_features))[:,:,0]
        self.register_buffer('aux_R', aux_R)
        self.aux_R = self.aux_R.to(self.ori_weight.device).detach()


    def forward(self, input: torch.Tensor):
        if self.inited is False:
            aux_R_abs_max = torch.minimum(self.aux_R.max(dim=-1,keepdim=True)[0].abs(),self.aux_R.min(dim=-1,keepdim=True)[0].abs()).detach()
            ori_weight_abs_max = torch.maximum(self.ori_weight.max(dim=-1,keepdim=True)[0].abs(),self.ori_weight.min(dim=-1,keepdim=True)[0].abs()).detach()
            self.aux_R = ((ori_weight_abs_max)**self.alpha/(aux_R_abs_max+1e-8)**self.alpha)*self.aux_R

            ori_weight = self.ori_weight - self.aux_R
            delta, zero_point = self.init_quantization_scale(ori_weight, self.channel_wise,self.n_bits,self.sym)
            self.register_buffer('weight_quant_delta', delta)
            self.register_buffer('weight_quant_zero_point', zero_point)
            ori_weight_round = round_ste(ori_weight / self.weight_quant_delta) + self.weight_quant_zero_point
            if self.sym:
                ori_weight_round = torch.clamp(ori_weight_round, -self.n_levels - 1, self.n_levels)
            else:
                ori_weight_round = torch.clamp(ori_weight_round, 0, self.n_levels - 1)

            # delete the FP weights and save the int weights
            self.register_buffer('ori_weight_round', ori_weight_round) # int weight and keep it intact
            self.ori_weight = None
            torch.cuda.empty_cache()
            self.inited = True

        ori_weight_int = self.ori_weight_round - self.weight_quant_zero_point

        lora_weight = (self.aux_R + (self.loraB.weight @ self.loraA.weight)) / \
            torch.where(ori_weight_int == 0, torch.tensor(1).to(ori_weight_int.device), ori_weight_int)
        weight_updates = self.weight_quant_delta + lora_weight # broad-cast


        if self.quant_lora_weights:
            if self.double_inited is False:
                delta, zero_point = self.init_quantization_scale(weight_updates, True, self.lora_bits)
                with torch.no_grad():
                    self.double_quant_delta.copy_(delta)
                    self.double_quant_zero_point.copy_(zero_point)
                self.double_inited = True

            weight_updates_round = round_ste(weight_updates / self.double_quant_delta) + self.double_quant_zero_point
            if self.sym:
                weight_updates_round = torch.clamp(weight_updates_round, -self.lora_levels - 1, self.lora_levels)
            else:
                weight_updates_round = torch.clamp(weight_updates_round, 0, self.lora_levels - 1)

            weight_updates_int = (weight_updates_round - self.double_quant_zero_point)
            weight_int_mul = weight_updates_int * ori_weight_int # simulated INT multiply
            weight = self.double_quant_delta * weight_int_mul

        else:
            weight = weight_updates*ori_weight_int

        bias = self.ori_bias

        # do activation quantization
        if self.act_quant_params is not None:
            input = self.act_quantizer(input)

        out = self.fwd_func(input, weight, bias, **self.fwd_kwargs)

        return out



    def init_quantization_scale(self, x: torch.Tensor, channel_wise: bool = False, n_bits: int = 8, sym: bool= False):
        n_levels = 2 ** n_bits if not sym else 2 ** (n_bits - 1) - 1
        delta, zero_point = None, None
        if channel_wise:
            x_clone = x.clone().detach()
            n_channels = x_clone.shape[0]
            if len(x.shape) == 4:
                x_max = x_clone.abs().max(dim=-1)[0].max(dim=-1)[0].max(dim=-1)[0]
            elif len(x.shape) == 3:
                x_max = x_clone.abs().max(dim=-1)[0].max(dim=-1)[0]
            else:
                x_max = x_clone.abs().max(dim=-1)[0]
            delta = x_max.clone()
            zero_point = x_max.clone()
            # determine the scale and zero point channel-by-channel
            if 'max' in self.scale_method:
                delta, zero_point = batch_max(x_clone.view(n_channels, -1), sym, 2 ** n_bits,
                                              self.always_zero)

            elif 'mse' in self.scale_method:
                delta, zero_point = batch_mse(x_clone.view(n_channels, -1), sym, 2 ** n_bits,
                                              self.always_zero)

            if len(x.shape) == 4:
                delta = delta.view(-1, 1, 1, 1)
                zero_point = zero_point.view(-1, 1, 1, 1)
            elif len(x.shape) == 3:
                delta = delta.view(-1, 1, 1)
                zero_point = zero_point.view(-1, 1, 1)
            else:
                delta = delta.view(-1, 1)
                zero_point = zero_point.view(-1, 1)
        else:
            if 'max' in self.scale_method:
                x_min = min(x.min().item(), 0)
                x_max = max(x.max().item(), 0)
                if 'scale' in self.scale_method:
                    x_min = x_min * (n_bits + 2) / 8
                    x_max = x_max * (n_bits + 2) / 8

                x_absmax = max(abs(x_min), x_max)
                if sym:
                    # x_min, x_max = -x_absmax if x_min < 0 else 0, x_absmax
                    delta = x_absmax / n_levels
                else:
                    delta = float(x.max().item() - x.min().item()) / (n_levels - 1)
                if delta < 1e-8:
                    delta = 1e-8

                zero_point = round(-x_min / delta) if not (sym or self.always_zero) else 0
                delta = torch.tensor(delta).type_as(x)

            elif self.scale_method == 'mse':
                x_max = x.max()
                x_min = x.min()
                best_score = 1e+10
                for i in range(80):
                    new_max = x_max * (1.0 - (i * 0.01))
                    new_min = x_min * (1.0 - (i * 0.01))
                    x_q = self.quantize(x, new_max, new_min,n_bits,sym)
                    score = lp_loss(x, x_q, p=2.4, reduction='all')
                    if score < best_score:
                        best_score = score
                        delta = (new_max - new_min) / (2 ** n_bits - 1) \
                            if not self.always_zero else new_max / (2 ** n_bits - 1)
                        zero_point = (- new_min / delta).round() if not self.always_zero else 0
            else:
                raise NotImplementedError

        return delta, zero_point

    def quantize(self, x, max, min,n_bits,sym):
        n_levels = 2 ** n_bits if not sym else 2 ** (n_bits - 1) - 1
        delta = (max - min) / (2 ** n_bits - 1) if not self.always_zero else max / (2 ** n_bits - 1)
        zero_point = (- min / delta).round() if not self.always_zero else 0
        x_int = torch.round(x / delta)
        x_quant = torch.clamp(x_int + zero_point, 0,n_levels - 1)
        x_float_q = (x_quant - zero_point) * delta
        return x_float_q

--------------------------------------------------------------------------------
/utils/intlora_shift.py:
--------------------------------------------------------------------------------
import torch
import torch.nn as nn
import torch.nn.functional as F
import math
from utils.quant_utils import batch_mse, batch_max, lp_loss, round_ste
import numpy as np
from utils.quant_layer import UniformAffineQuantizer


def round(x, rounding='deterministic'):
    assert (rounding in ['deterministic', 'stochastic'])
    if rounding == 'stochastic':
        x_floor = x.floor()
        return x_floor + torch.bernoulli(x - x_floor)
    else:
        return x.round()


def get_shift_and_sign(x, rounding='deterministic'):
    sign = torch.sign(x)

    x_abs = torch.abs(x)
    if rounding == "floor":
        shift = torch.floor(torch.log(x_abs) / np.log(2))
    else:
        shift = round_ste(torch.log(x_abs) / np.log(2))

    return shift, sign


def round_power_of_2(x, rounding='deterministic', q_bias=None, scale=None):
    if q_bias is not None:
        q_bias = q_bias.unsqueeze(1).expand_as(x)
        x = x - q_bias
    if scale is not None:
        scale = scale.unsqueeze(1).expand_as(x)
        x = x / scale
    shift, sign = get_shift_and_sign(x, rounding)
    x_rounded = (2.0 ** shift) * sign
    if scale is not None:
        x_rounded = x_rounded * scale
    if q_bias is not None:
        x_rounded = x_rounded + q_bias
    return x_rounded


def additive_power_of_2(x, log_s):
    sign = torch.sign(x)
    x_abs = torch.abs(x)

    shift = round_ste(torch.log(x_abs) / np.log(2) + log_s)

    x_rounded = (2.0 ** shift) * sign

    return x_rounded


class StraightThrough(nn.Module):
    def __init__(self, channel_num: int = 1):
        super().__init__()

    def forward(self, input):
        return input




class IntLoRA_SHIFT(nn.Module):
    # The implementation of our IntLoRA-SHIFT
    def __init__(self, org_module, n_bits=8, lora_bits=8, symmetric=False, channel_wise=True, rank=4, activation_params=None):
        super(IntLoRA_SHIFT, self).__init__()

        self.n_bits = n_bits
        self.lora_bits = lora_bits
        self.sym = symmetric
        self.scale_method = 'mse'
        self.always_zero = False
        self.n_levels = 2 ** self.n_bits if not self.sym else 2 ** (self.n_bits - 1) - 1
        self.channel_wise = channel_wise
        self.inited = False
        self.lora_levels = self.n_levels

        self.fwd_kwargs = dict()
        self.fwd_func = F.linear
        self.in_features = org_module.in_features
        self.out_features = org_module.out_features

        # save original weights and bias and keep them intact
        self.ori_weight_shape = org_module.weight.shape

        self.ori_weight = org_module.weight.view(self.out_features, -1).data.clone() # reshape here
        self.ori_bias = None if org_module.bias is None else org_module.bias.data.clone()

        self.act_quant_params = activation_params
        if self.act_quant_params is not None:
            self.act_quantizer = UniformAffineQuantizer(**self.act_quant_params)

        # quant lora quant here ===========================
        self.quant_lora_weights = True
        r= rank
        self.alpha = 1.5

        self.loraA = nn.Linear(org_module.in_features, r, bias=False)
        self.loraB = nn.Linear(r, org_module.out_features, bias=False)
        nn.init.kaiming_uniform_(self.loraA.weight, a=math.sqrt(5))
        nn.init.zeros_(self.loraB.weight)

        # init the auxiliary matrix R in IntLoRA
        m = torch.distributions.laplace.Laplace(loc=torch.tensor([0.]),scale=torch.tensor([0.5]))
        aux_R = m.sample((org_module.out_features,org_module.in_features))[:,:,0]
        self.register_buffer('aux_R', aux_R)
        self.aux_R = self.aux_R.to(self.ori_weight.device).detach()


    def forward(self, input: torch.Tensor):
        if self.inited is False:
            aux_R_abs_max = torch.minimum(self.aux_R.max(dim=-1, keepdim=True)[0].abs(),
                                          self.aux_R.min(dim=-1, keepdim=True)[0].abs()).detach()
            ori_weight_abs_max = torch.maximum(self.ori_weight.max(dim=-1, keepdim=True)[0].abs(),
                                               self.ori_weight.min(dim=-1, keepdim=True)[0].abs()).detach()
            self.aux_R = ((ori_weight_abs_max) ** self.alpha / (aux_R_abs_max + 1e-8) ** self.alpha) * self.aux_R

            ori_weight = self.ori_weight - self.aux_R
            delta, zero_point = self.init_quantization_scale(ori_weight, self.channel_wise, self.n_bits, self.sym)
            self.register_buffer('weight_quant_delta', delta)
            self.register_buffer('weight_quant_zero_point', zero_point)
            ori_weight_round = round_ste(ori_weight / self.weight_quant_delta) + self.weight_quant_zero_point
            if self.sym:
                ori_weight_round = torch.clamp(ori_weight_round, -self.n_levels - 1, self.n_levels)
            else:
                ori_weight_round = torch.clamp(ori_weight_round, 0, self.n_levels - 1)

            # delete the FP weights and save the int weights
            self.register_buffer('ori_weight_round', ori_weight_round) # int weight and keep it intact
            self.ori_weight = None
            torch.cuda.empty_cache()
            self.inited = True

        ori_weight_int = self.ori_weight_round - self.weight_quant_zero_point

        # PETL for quant scale here ==================================
        lora_weight = (self.aux_R + (self.loraB.weight @ self.loraA.weight)) / \
            torch.where(ori_weight_int == 0, torch.tensor(1).to(ori_weight_int.device), ori_weight_int)
        weight_updates = self.weight_quant_delta + lora_weight # broad-cast


        if self.quant_lora_weights:
            weight_updates_sign = weight_updates.sign()
            weight_updates_abs = torch.abs(weight_updates)
            lora_shift = round_ste(torch.log2(weight_updates_abs + 1e-16))
            lora_rounded = 2.0 ** lora_shift
            weight = weight_updates_sign * lora_rounded * ori_weight_int
            if torch.any(torch.isnan(weight_updates)):
                print('There is nan in the weight-updates for log2 quantization')
                raise NotImplementedError

        else:
            weight = weight_updates * ori_weight_int

        # do activation quantization
        if self.act_quant_params is not None:
            input = self.act_quantizer(input)

        bias = self.ori_bias
        out = self.fwd_func(input, weight, bias, **self.fwd_kwargs)

        return out

    def init_quantization_scale(self, x: torch.Tensor, channel_wise: bool = False, n_bits: int = 8, sym: bool = False):
        n_levels = 2 ** n_bits if not sym else 2 ** (n_bits - 1) - 1
        delta, zero_point = None, None
        if channel_wise:
            x_clone = x.clone().detach()
            n_channels = x_clone.shape[0]
            if len(x.shape) == 4:
                x_max = x_clone.abs().max(dim=-1)[0].max(dim=-1)[0].max(dim=-1)[0]
            elif len(x.shape) == 3:
                x_max = x_clone.abs().max(dim=-1)[0].max(dim=-1)[0]
            else:
                x_max = x_clone.abs().max(dim=-1)[0]
            delta = x_max.clone()
            zero_point = x_max.clone()
            # determine the scale and zero point channel-by-channel
            if 'max' in self.scale_method:
                delta, zero_point = batch_max(x_clone.view(n_channels, -1), sym, 2 ** n_bits,
                                              self.always_zero)

            elif 'mse' in self.scale_method:
                delta, zero_point = batch_mse(x_clone.view(n_channels, -1), sym, 2 ** n_bits,
                                              self.always_zero)

            if len(x.shape) == 4:
                delta = delta.view(-1, 1, 1, 1)
                zero_point = zero_point.view(-1, 1, 1, 1)
            elif len(x.shape) == 3:
                delta = delta.view(-1, 1, 1)
                zero_point = zero_point.view(-1, 1, 1)
            else:
                delta = delta.view(-1, 1)
                zero_point = zero_point.view(-1, 1)
        else:
            # if self.leaf_param:
            #     self.x_min = x.data.min()
            #     self.x_max = x.data.max()

            if 'max' in self.scale_method:
                x_min = min(x.min().item(), 0)
                x_max = max(x.max().item(), 0)
                if 'scale' in self.scale_method:
                    x_min = x_min * (n_bits + 2) / 8
                    x_max = x_max * (n_bits + 2) / 8

                x_absmax = max(abs(x_min), x_max)
                if sym:
                    # x_min, x_max = -x_absmax if x_min < 0 else 0, x_absmax
                    delta = x_absmax / n_levels
                else:
                    delta = float(x.max().item() - x.min().item()) / (n_levels - 1)
                if delta < 1e-8:
                    delta = 1e-8

                zero_point = round(-x_min / delta) if not (sym or self.always_zero) else 0
                delta = torch.tensor(delta).type_as(x)

            elif self.scale_method == 'mse':
                x_max = x.max()
                x_min = x.min()
                best_score = 1e+10
                for i in range(80):
                    new_max = x_max * (1.0 - (i * 0.01))
                    new_min = x_min * (1.0 - (i * 0.01))
                    x_q = self.quantize(x, new_max, new_min, n_bits, sym)
                    score = lp_loss(x, x_q, p=2.4, reduction='all')
                    if score < best_score:
                        best_score = score
                        delta = (new_max - new_min) / (2 ** n_bits - 1) \
                            if not self.always_zero else new_max / (2 ** n_bits - 1)
                        zero_point = (- new_min / delta).round() if not self.always_zero else 0
            else:
                raise NotImplementedError

        return delta, zero_point

    def quantize(self, x, max, min, n_bits, sym):
        n_levels = 2 ** n_bits if not sym else 2 ** (n_bits - 1) - 1
        delta = (max - min) / (2 ** n_bits - 1) if not self.always_zero else max / (2 ** n_bits - 1)
        zero_point = (- min / delta).round() if not self.always_zero else 0
        # we assume weight quantization is always signed
        x_int = torch.round(x / delta)
        x_quant = torch.clamp(x_int + zero_point, 0, n_levels - 1)
        x_float_q = (x_quant - zero_point) * delta
        return x_float_q

--------------------------------------------------------------------------------
/utils/quant_layer.py:
--------------------------------------------------------------------------------
import logging
import warnings
import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import Union
from utils.quant_utils import batch_mse,batch_max
logger = logging.getLogger(__name__)


class StraightThrough(nn.Module):
    def __init__(self, channel_num: int = 1):
        super().__init__()

    def forward(self, input):
        return input


def round_ste(x: torch.Tensor):
    """
    Implement Straight-Through Estimator for rounding operation.
    """
    return (x.round() - x).detach() + x


def lp_loss(pred, tgt, p=2.0, reduction='none'):
    """
    loss function measured in L_p Norm
    """
    if reduction == 'none':
        return (pred-tgt).abs().pow(p).sum(1).mean()
    else:
        return (pred-tgt).abs().pow(p).mean()



class UniformAffineQuantizer(nn.Module):
    """
    PyTorch Function that can be used for asymmetric quantization (also called uniform affine
    quantization). Quantizes its argument in the forward pass, passes the gradient 'straight
    through' on the backward pass, ignoring the quantization that occurred.
    Based on https://arxiv.org/abs/1806.08342.

    :param n_bits: number of bits used for quantization
    :param symmetric: if True, quantize to a signed range with zero_point fixed at 0
    :param channel_wise: if True, compute scale and zero_point for each output channel
    :param scale_method: calibration method for the scale and zero point (e.g. 'max' or 'mse')
    """

    def __init__(self, n_bits: int = 8, symmetric: bool = False, channel_wise: bool = False, scale_method: str = 'max',
                 leaf_param: bool = False, always_zero: bool = False, **kwargs):
        super(UniformAffineQuantizer, self).__init__()
        self.sym = symmetric
        # assert 2 <= n_bits <= 8, 'bitwidth not supported'
        self.n_bits = n_bits
        self.n_levels = 2 ** self.n_bits if not self.sym else 2 ** (self.n_bits - 1) - 1
        self.delta = None
        self.zero_point = None
        self.inited = False
        self.leaf_param = leaf_param
        self.channel_wise = channel_wise
        self.scale_method = scale_method
        self.running_stat = False
        self.always_zero = always_zero
        if self.leaf_param:
            self.x_min, self.x_max = None, None

    def forward(self, x: torch.Tensor):
        if self.inited is False:
            if self.leaf_param:
                delta, zero_point = self.init_quantization_scale(x.detach(), self.channel_wise)
                self.delta = torch.nn.Parameter(delta)
                self.zero_point = torch.nn.Parameter(zero_point)

            else:
                self.delta, self.zero_point = self.init_quantization_scale(x, self.channel_wise)
            self.inited = True

        # start quantization
        x_int = round_ste(x / self.delta) + self.zero_point
        if self.sym:
            x_quant = torch.clamp(x_int, -self.n_levels - 1, self.n_levels)
        else:
            x_quant = torch.clamp(x_int, 0, self.n_levels - 1)
        x_dequant = (x_quant - self.zero_point) * self.delta
        return x_dequant

    def act_momentum_update(self, x: torch.Tensor, act_range_momentum: float = 0.95):
        assert (self.inited)
        assert (self.leaf_param)

        x_min = x.data.min()
        x_max = x.data.max()
        self.x_min = self.x_min * act_range_momentum + x_min * (1 - act_range_momentum)
        self.x_max = self.x_max * act_range_momentum + x_max * (1 - act_range_momentum)

        if self.sym:
            # x_min, x_max = -x_absmax if x_min < 0 else 0, x_absmax
            delta = torch.max(self.x_min.abs(), self.x_max.abs()) / self.n_levels
        else:
            delta = (self.x_max - self.x_min) / (self.n_levels - 1) if not self.always_zero \
                else self.x_max / (self.n_levels - 1)

        delta = torch.clamp(delta, min=1e-8)
        if not self.sym:
            self.zero_point = (-self.x_min / delta).round() if not self.always_zero else 0
        self.delta = torch.nn.Parameter(delta)

    def init_quantization_scale(self, x: torch.Tensor, channel_wise: bool = False):
        delta, zero_point = None, None
        if channel_wise:
            x_clone = x.clone().detach()
            n_channels = x_clone.shape[0]
            if len(x.shape) == 4:
                x_max = x_clone.abs().max(dim=-1)[0].max(dim=-1)[0].max(dim=-1)[0]
            elif len(x.shape) == 3:
                x_max = x_clone.abs().max(dim=-1)[0].max(dim=-1)[0]
            else:
                x_max = x_clone.abs().max(dim=-1)[0]
            delta = x_max.clone()
            zero_point = x_max.clone()
            # determine the scale and zero point channel-by-channel
            if 'max' in self.scale_method:
                delta, zero_point = batch_max(x_clone.view(n_channels, -1), self.sym, 2 ** self.n_bits,
                                              self.always_zero)

            elif 'mse' in self.scale_method:
                delta, zero_point = batch_mse(x_clone.view(n_channels, -1), self.sym, 2 ** self.n_bits,
                                              self.always_zero)

            if len(x.shape) == 4:
                delta = delta.view(-1, 1, 1, 1)
                zero_point = zero_point.view(-1, 1, 1, 1)
            elif len(x.shape) == 3:
                delta = delta.view(-1, 1, 1)
                zero_point = zero_point.view(-1, 1, 1)
            else:
                delta = delta.view(-1, 1)
                zero_point = zero_point.view(-1, 1)
        else:
            if self.leaf_param:  # for momentum update
                self.x_min = x.data.min()
                self.x_max = x.data.max()

            if 'max' in self.scale_method:
                x_min = min(x.min().item(), 0)
                x_max = max(x.max().item(), 0)
                if 'scale' in self.scale_method:
                    x_min = x_min * (self.n_bits + 2) / 8
                    x_max = x_max * (self.n_bits + 2) / 8

                x_absmax = max(abs(x_min), x_max)
                if self.sym:
                    # x_min, x_max = -x_absmax if x_min < 0 else 0, x_absmax
                    delta = x_absmax / self.n_levels
                else:
                    delta = float(x.max().item() - x.min().item()) / (self.n_levels - 1)
                if delta < 1e-8:
                    warnings.warn('Quantization range close to zero: [{}, {}]'.format(x_min, x_max))
                    delta = 1e-8

                zero_point = round(-x_min / delta) if not (self.sym or self.always_zero) else 0
                delta = torch.tensor(delta).type_as(x)

            elif self.scale_method == 'mse':
                x_max = x.max()
                x_min = x.min()
                best_score = 1e+10
                for i in range(80):
                    new_max = x_max * (1.0 - (i * 0.01))
                    new_min = x_min * (1.0 - (i * 0.01))
                    x_q = self.quantize(x, new_max, new_min)
                    # L_p norm minimization as described in LAPQ
                    # https://arxiv.org/abs/1911.07190
                    score = lp_loss(x, x_q, p=2.4, reduction='all')
                    if score < best_score:
                        best_score = score
                        delta = (new_max - new_min) / (2 ** self.n_bits - 1) \
                            if not self.always_zero else new_max / (2 ** self.n_bits - 1)
                        zero_point = (- new_min / delta).round() if not self.always_zero else 0
            else:
                raise NotImplementedError

        return delta, zero_point

    def quantize(self, x, max, min):
        delta = (max - min) / (2 ** self.n_bits - 1) if not self.always_zero else max / (2 ** self.n_bits - 1)
        zero_point = (- min / delta).round() if not self.always_zero else 0
        # integer codes are clamped to the unsigned range [0, n_levels - 1]
        x_int = torch.round(x / delta)
        x_quant = torch.clamp(x_int + zero_point, 0, self.n_levels - 1)
        x_float_q = (x_quant - zero_point) * delta
        return x_float_q

    def bitwidth_refactor(self, refactored_bit: int):
        # assert 2 <= refactored_bit <= 8, 'bitwidth not supported'
        self.n_bits = refactored_bit
        # keep n_levels consistent with __init__ for both signed and unsigned ranges
        self.n_levels = 2 ** self.n_bits if not self.sym else 2 ** (self.n_bits - 1) - 1

    def extra_repr(self):
        s = 'bit={n_bits}, scale_method={scale_method}, symmetric={sym}, channel_wise={channel_wise},' \
            ' leaf_param={leaf_param}'
        return s.format(**self.__dict__)


class QuantLayerNormal(nn.Module):
    """
    Quantized wrapper around an nn.Conv2d, nn.Conv1d or nn.Linear layer. Weights are
    fake-quantized when weight_quant_params is given and inputs are fake-quantized when
    activation_params is given; otherwise the original floating-point layer is run.
    """
    def __init__(self, org_module: Union[nn.Conv2d, nn.Linear, nn.Conv1d], weight_quant_params=None,
                 activation_params=None):
        super(QuantLayerNormal, self).__init__()

        self.weight_quant_params = weight_quant_params
        self.act_quant_params = activation_params

        if isinstance(org_module, nn.Conv2d):
            self.fwd_kwargs = dict(stride=org_module.stride, padding=org_module.padding,
                                   dilation=org_module.dilation, groups=org_module.groups)
            self.fwd_func = F.conv2d
        elif isinstance(org_module, nn.Conv1d):
            self.fwd_kwargs = dict(stride=org_module.stride, padding=org_module.padding,
                                   dilation=org_module.dilation, groups=org_module.groups)
            self.fwd_func = F.conv1d
        else:
            self.fwd_kwargs = dict()
            self.fwd_func = F.linear
        self.weight = org_module.weight
        self.org_weight = org_module.weight.data
        if org_module.bias is not None:
            self.bias = org_module.bias
            self.org_bias = org_module.bias.data
        else:
            self.bias = None
            self.org_bias = None
        # the quantized forward is enabled whenever the corresponding params are provided
        self.use_weight_quant = self.weight_quant_params is not None
        self.use_act_quant = self.act_quant_params is not None

        # initialize quantizers
        if self.use_weight_quant:
            self.weight_quantizer = UniformAffineQuantizer(**self.weight_quant_params)
        if self.use_act_quant:
            self.act_quantizer = UniformAffineQuantizer(**self.act_quant_params)
        self.activation_function = StraightThrough()
        self.extra_repr = org_module.extra_repr

    def forward(self, input: torch.Tensor):
        if self.use_act_quant:  # Activation Quant
            input = self.act_quantizer(input)

        if self.use_weight_quant:  # Weight Quant
            with torch.no_grad():
                weight = self.weight_quantizer(self.weight)
            bias = self.bias
        else:
            weight = self.org_weight
            bias = self.org_bias
        out = self.fwd_func(input, weight, bias, **self.fwd_kwargs)
        out = self.activation_function(out)

        return out

--------------------------------------------------------------------------------
/utils/quant_model.py:
--------------------------------------------------------------------------------
import torch
import torch.nn.functional as F
from torch import nn
from utils.intlora_mul import IntLoRA_MUL
from utils.intlora_shift import IntLoRA_SHIFT
from utils.quant_layer import QuantLayerNormal



class StraightThrough(nn.Module):
    def __init__(self, channel_num: int = 1):
        super().__init__()

    def forward(self, input):
        return input



def check_in_special(ss, special_list):
    # True if any of the special substrings occurs in the module name ss
    for itm in special_list:
        if itm in ss:
            return True
    return False



class QuantUnetWarp(nn.Module):
    def __init__(self, model: nn.Module, args):
        super().__init__()
        self.model = model
        self.intlora_type = args.intlora
        lora_quant_params = {'n_bits': args.nbits, 'lora_bits': args.nbits, 'symmetric': False, 'channel_wise': True, 'rank': args.rank}
        other_quant_params = {'n_bits': args.nbits, 'symmetric': False, 'channel_wise': True, 'scale_method': 'mse'}
        activation_quant_params = {'n_bits': args.act_nbits, 'symmetric': False, 'channel_wise': False, 'scale_method': 'mse', 'leaf_param': True} if args.use_activation_quant else None
        special_list = ['to_q', 'to_k', 'to_v', 'to_out']
        assert self.intlora_type in ['MUL', 'SHIFT']
        self.IntLoRALayer = IntLoRA_MUL if self.intlora_type == 'MUL' else IntLoRA_SHIFT

        self.quant_module_refactor(self.model, lora_quant_params, other_quant_params, activation_quant_params, special_list)


    def quant_module_refactor(self, module: nn.Module, lora_quant_params, other_quant_params, activation_quant_params, special_list, prev_name=''):
        for name, child_module in module.named_children():
            tmp_name = prev_name + '_' + name
            if isinstance(child_module, (nn.Conv2d, nn.Conv1d, nn.Linear)) and not ('downsample' in name and 'conv' in name):
                if check_in_special(tmp_name, special_list):
                    # attention projections get an IntLoRA adapter
                    setattr(module, name, self.IntLoRALayer(child_module, **lora_quant_params, activation_params=activation_quant_params))
                else:
                    # all other Conv/Linear layers are only quantized
                    setattr(module, name, QuantLayerNormal(child_module, other_quant_params, activation_params=activation_quant_params))
            elif isinstance(child_module, StraightThrough):
                continue
            else:
                self.quant_module_refactor(child_module, lora_quant_params, other_quant_params, activation_quant_params, special_list, prev_name=tmp_name)


    def forward(self, image, t, context=None):
        return self.model(image, t, context)



--------------------------------------------------------------------------------
/utils/quant_utils.py:
--------------------------------------------------------------------------------
import torch
import torch.nn.functional as F
from torch import nn
from typing import Tuple
from diffusers.utils import deprecate, logging
from diffusers.utils.import_utils import is_xformers_available
logger = logging.get_logger(__name__)  # pylint: disable=invalid-name

if is_xformers_available():
    import xformers
    import xformers.ops
else:
    xformers = None


def batch_mse(x: torch.Tensor,
              symmetric: bool = False,
              level: int = 256,
              always_zero: bool = False
              ) -> Tuple[torch.Tensor, torch.Tensor]:
    # per-row (per output channel) search over 80 shrinking ranges, minimizing mean |x_dq - x|^2.4
    x_min, x_max = torch.min(x, dim=-1, keepdim=True)[0], torch.max(x, dim=-1, keepdim=True)[0]  # [d_out]
    delta, zero_point = torch.zeros_like(x_min), torch.zeros_like(x_min)
    s = torch.full((x.shape[0], 1), 30000, dtype=x.dtype, device=x.device)
    for i in range(80):
        new_min = x_min * (1. - (i * 0.01))
        new_max = x_max * (1. - (i * 0.01))
        new_delta = (new_max - new_min) / (level - 1)
        new_zero_point = torch.round(-new_min / new_delta) if not (symmetric or always_zero) else torch.zeros_like(new_delta)
        NB, PB = -level // 2 if symmetric and not always_zero else 0, \
            level // 2 - 1 if symmetric and not always_zero else level - 1
        x_q = torch.clamp(torch.round(x / new_delta) + new_zero_point, NB, PB)
        x_dq = new_delta * (x_q - new_zero_point)
        new_s = (x_dq - x).abs().pow(2.4).mean(dim=-1, keepdim=True)

        update_mask = new_s < s
        delta[update_mask] = new_delta[update_mask]
        zero_point[update_mask] = new_zero_point[update_mask]
        s[update_mask] = new_s[update_mask]

    return delta.squeeze(), zero_point.squeeze()


def batch_max(x: torch.Tensor,
              symmetric: bool = False,
              level: int = 256,
              always_zero: bool = False
              ) -> Tuple[torch.Tensor, torch.Tensor]:
    x_min, x_max = torch.min(x, dim=-1, keepdim=True)[0], torch.max(x, dim=-1, keepdim=True)[0]  # [d_out]

    x_absmax = torch.max(torch.cat([x_min.abs(), x_max], dim=-1), dim=-1, keepdim=True)[0]

    if symmetric:
        # x_min, x_max = -x_absmax if x_min < 0 else 0, x_absmax
        delta = x_absmax / level
    else:
        delta = (x_max - x_min) / (level - 1)

    delta = torch.clamp(delta, min=1e-8)
    # use a tensor (not the int 0) so that .squeeze() below also works in the symmetric case
    zero_point = torch.round(-x_min / delta) if not (symmetric or always_zero) else torch.zeros_like(delta)

    return delta.squeeze(), zero_point.squeeze()



def round_ste(x: torch.Tensor):
    return (x.round() - x).detach() + x



def lp_loss(pred, tgt, p=2.0, reduction='none'):
    if reduction == 'none':
        return (pred-tgt).abs().pow(p).sum(1).mean()
    else:
        return (pred-tgt).abs().pow(p).mean()



--------------------------------------------------------------------------------
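When `quant_lora_weights` is set, `IntLoRA_SHIFT.forward` rounds each entry `u` of `weight_updates` to `sign(u) * 2**k`, where `k = round(log2|u|)`. It then multiplies the result by the integer weight `ori_weight_int`. For this reason the merge can be carried out as a bit-shift.

The sketch below reproduces that arithmetic on plain Python scalars. It makes a few assumptions:

* The helper names `pow2_round` and `merged_weight` are hypothetical.
* Using `math.ldexp` for negative exponents is an illustrative choice. The repository code instead multiplies by `2.0 ** lora_shift` in floating point.

```python
import math


def pow2_round(u: float) -> tuple[int, int]:
    """Approximate a scale u by sign(u) * 2**k with k = round(log2(|u|))."""
    sign = -1 if u < 0 else 1
    # the small epsilon mirrors the `+ 1e-16` guard against log2(0) in the original
    k = round(math.log2(abs(u) + 1e-16))
    return sign, k


def merged_weight(w_int: int, u: float) -> float:
    """Merge an integer weight with a power-of-two scale: sign * 2**k * w_int."""
    sign, k = pow2_round(u)
    if k >= 0:
        # non-negative exponent: a plain integer left shift (Python's << is
        # arithmetic, so it is exact for negative integers as well)
        return float(sign * (w_int << k))
    # negative exponent: adjust the binary exponent exactly with ldexp
    return sign * math.ldexp(w_int, k)


if __name__ == "__main__":
    for w_int, u in [(12, 0.3), (12, -5.0), (-3, 4.0)]:
        sign, k = pow2_round(u)
        print(w_int, u, "->", sign, k, merged_weight(w_int, u))
```

Python's `round` and `torch.round` both round halves to even, so the exponent chosen here matches the tensor version for the same inputs.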
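`UniformAffineQuantizer.forward` performs asymmetric fake quantization in three steps:

1. Compute `round(x / delta) + zero_point`.
2. Clamp the result to `[0, n_levels - 1]`.
3. Map it back with `(q - zero_point) * delta`.

The pure-Python sketch below reproduces those steps for a list of floats. Its calibration is simplified and should be treated as an assumption, not the repository's exact rule:

* It includes 0 in the range and sets `delta = (x_max - x_min) / (n_levels - 1)`.
* The repository's `'max'` branch differs in detail. For example, its asymmetric `delta` uses the raw `x.max() - x.min()`.
* The names `minmax_params` and `fake_quant` are hypothetical.

```python
def minmax_params(xs: list[float], n_bits: int = 8) -> tuple[float, int, int]:
    """Asymmetric min/max calibration whose range is forced to include 0."""
    x_min = min(min(xs), 0.0)
    x_max = max(max(xs), 0.0)
    n_levels = 2 ** n_bits
    delta = max((x_max - x_min) / (n_levels - 1), 1e-8)  # floor mirrors the 1e-8 guard
    zero_point = round(-x_min / delta)
    return delta, zero_point, n_levels


def fake_quant(xs: list[float], delta: float, zero_point: int, n_levels: int) -> list[float]:
    """Quantize to integer codes in [0, n_levels - 1], then dequantize."""
    out = []
    for x in xs:
        q = round(x / delta) + zero_point
        q = min(max(q, 0), n_levels - 1)  # values outside the range saturate
        out.append((q - zero_point) * delta)
    return out


if __name__ == "__main__":
    d, z, n = minmax_params([-1.0, 0.0, 2.0], n_bits=2)
    print(d, z, n, fake_quant([-1.0, 0.4, 2.0, 5.0], d, z, n))
```

Inside the calibrated range the reconstruction error is at most `delta / 2`. Values beyond it, such as `5.0` above, are clipped to the largest code.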
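`batch_mse` calibrates each output channel as follows:

* It shrinks the channel's `[min, max]` range by 1% per step over 80 steps.
* For each candidate it measures the mean `|x_dq - x| ** 2.4`.
* It keeps the candidate with the smallest error.

Below is a per-row, pure-Python sketch of the asymmetric case, with some assumptions:

* The names `lp_error` and `search_range` are hypothetical.
* Skipping candidates with a non-positive `delta` (a constant row) is an added guard that the tensor version does not have.

```python
def lp_error(row: list[float], delta: float, zero_point: int, levels: int, p: float = 2.4) -> float:
    """Mean |dequant(quant(x)) - x| ** p over one channel, using asymmetric codes."""
    err = 0.0
    for x in row:
        q = min(max(round(x / delta) + zero_point, 0), levels - 1)
        err += abs((q - zero_point) * delta - x) ** p
    return err / len(row)


def search_range(row: list[float], n_bits: int = 8, steps: int = 80, p: float = 2.4):
    """Shrink [min, max] by 1% per step and keep the (delta, zero_point) with the lowest error."""
    levels = 2 ** n_bits
    x_min, x_max = min(row), max(row)
    best_err, best = float("inf"), None
    for i in range(steps):
        lo = x_min * (1.0 - 0.01 * i)
        hi = x_max * (1.0 - 0.01 * i)
        delta = (hi - lo) / (levels - 1)
        if delta <= 0:  # constant channel: no valid scale at this step
            continue
        zero_point = round(-lo / delta)
        err = lp_error(row, delta, zero_point, levels, p)
        if err < best_err:
            best_err, best = err, (delta, zero_point)
    return best


if __name__ == "__main__":
    print(search_range([-0.2, -0.05, 0.0, 0.1, 0.15, 3.0], n_bits=4))
```

Step `i = 0` is the full min/max range, so the selected candidate never has a larger error than plain min/max calibration. Clipping an outlier can trade a large error on one value for finer resolution on all the others.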
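`QuantUnetWarp.quant_module_refactor` decides how each `nn.Conv2d`, `nn.Conv1d` or `nn.Linear` layer is wrapped, using two name checks:

* **Accumulated name.** The name is built as `prev_name + '_' + name`. If it contains any of `to_q`, `to_k`, `to_v` or `to_out`, the layer gets an IntLoRA adapter. Every other such layer becomes a `QuantLayerNormal`.
* **Child's own name.** A layer whose own name contains both `downsample` and `conv` falls through to the recursive branch. Because it has no children, it stays in full precision.

Since only the child's own name is checked there, a layer named just `conv` inside a `downsamplers` container is not excluded.

The sketch below mirrors this routing on name lists, with some assumptions:

* The diffusers-style paths used in the example are illustrative.
* `joined_name` and `route` are hypothetical helpers.
* Every path is assumed to end in a Conv/Linear layer.

```python
SPECIAL = ["to_q", "to_k", "to_v", "to_out"]


def joined_name(path: list[str]) -> str:
    """Rebuild the name the way quant_module_refactor does: prev_name + '_' + name."""
    name = ""
    for part in path:
        name = name + "_" + part
    return name


def route(path: list[str]) -> str:
    """Decide how a Conv/Linear layer reached via `path` would be wrapped."""
    child = path[-1]
    if "downsample" in child and "conv" in child:
        # falls through to the recursive branch; a leaf layer is left unchanged
        return "unchanged"
    if any(key in joined_name(path) for key in SPECIAL):
        return "IntLoRA"
    return "QuantLayerNormal"


if __name__ == "__main__":
    attn = ["down_blocks", "0", "attentions", "0", "transformer_blocks", "0", "attn1"]
    for path in (attn + ["to_q"], attn + ["to_out", "0"], ["conv_in"]):
        print(joined_name(path), "->", route(path))
```

The match is a plain substring test on the accumulated name. So `to_out` also matches the `..._to_out_0` linear layer nested inside the `to_out` container.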