├── .gitignore ├── LICENSE ├── README.md ├── assets ├── example_inputs │ ├── health.jpg │ └── mask.png ├── paper │ ├── pie.jpg │ └── pipeline.jpg └── progression │ ├── confidence.png │ └── progression.gif ├── checkpoints └── checkpoints.txt ├── metric └── confidence_model │ ├── README.md │ ├── finetune.py │ └── pretrain.py ├── pipeline_stable_diffusion_pie.py ├── plot_confidence.py ├── requirements.txt ├── run_pie.py └── training ├── README.md ├── chexpert.py ├── retinopahty.py ├── skin.py └── train.py /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | pip-wheel-metadata/ 24 | share/python-wheels/ 25 | *.egg-info/ 26 | .installed.cfg 27 | *.egg 28 | MANIFEST 29 | 30 | # PyInstaller 31 | # Usually these files are written by a python script from a template 32 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 33 | *.manifest 34 | *.spec 35 | 36 | # Installer logs 37 | pip-log.txt 38 | pip-delete-this-directory.txt 39 | 40 | # Unit test / coverage reports 41 | htmlcov/ 42 | .tox/ 43 | .nox/ 44 | .coverage 45 | .coverage.* 46 | .cache 47 | nosetests.xml 48 | coverage.xml 49 | *.cover 50 | *.py,cover 51 | .hypothesis/ 52 | .pytest_cache/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | target/ 76 | 77 | # Jupyter Notebook 78 | .ipynb_checkpoints 79 | 80 | # IPython 81 | profile_default/ 82 | ipython_config.py 83 | 84 | # pyenv 85 | .python-version 86 | 87 | # pipenv 88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 91 | # install all needed dependencies. 92 | #Pipfile.lock 93 | 94 | # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow 95 | __pypackages__/ 96 | 97 | # Celery stuff 98 | celerybeat-schedule 99 | celerybeat.pid 100 | 101 | # SageMath parsed files 102 | *.sage.py 103 | 104 | # Environments 105 | .env 106 | .venv 107 | env/ 108 | venv/ 109 | ENV/ 110 | env.bak/ 111 | venv.bak/ 112 | 113 | # Spyder project settings 114 | .spyderproject 115 | .spyproject 116 | 117 | # Rope project settings 118 | .ropeproject 119 | 120 | # mkdocs documentation 121 | /site 122 | 123 | # mypy 124 | .mypy_cache/ 125 | .dmypy.json 126 | dmypy.json 127 | 128 | # Pyre type checker 129 | .pyre/ 130 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2023 Anonymous 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # PIE: Simulating Disease Progression via Progressive Image Editing 2 | 3 | [[Project Homepage](https://www.irohxucao.com/PIE/)] | [[Preprint](https://arxiv.org/abs/2309.11745)] | [[HuggingFace](https://huggingface.co/papers/2309.11745)] 4 | 5 | Official Implementation of "Simulating Disease Progression via Progressive Image Editing". 6 | 7 | 8 | ![](./assets/paper/pie.jpg) 9 | 10 | Disease progression simulation is a crucial area of research that has significant implications for clinical diagnosis, prognosis, and treatment. One major challenge in this field is the lack of continuous medical imaging monitoring of individual patients over time. To address this issue, we develop a novel framework termed Progressive Image Editing (PIE) that enables controlled manipulation of disease-related image features, facilitating precise and realistic disease progression simulation. Specifically, we leverage recent advancements in text-to-image generative models to simulate disease progression accurately and personalize it for each patient. 11 | 12 | To our best knowledge, PIE is the first of its kind to generate disease progression images meeting real-world standards. It is a promising tool for medical research and clinical practice, potentially allowing healthcare providers to model disease trajectories over time, predict future treatment responses, and improve patient outcomes. 
13 | 14 | ![](./assets/progression/progression.gif) 15 | 16 | ## Requirements 17 | 18 | Install the newest PyTorch. 19 | 20 | ``` 21 | conda install pytorch torchvision torchaudio pytorch-cuda=11.8 -c pytorch -c nvidia 22 | ``` 23 | 24 | ``` 25 | pip install -r requirements.txt 26 | ``` 27 | 28 | ## Inference 29 | 30 | 31 | ### Sampling Script 32 | 33 | Try new pretrained weight from MIMIC-CXR dataset [here](https://huggingface.co/IrohXu/stable-diffusion-mimic-cxr-v0.1) 34 | 35 | ``` 36 | python run_pie.py \ 37 | --pretrained_model_name_or_path="CompVis/stable-diffusion-v1-4" \ 38 | --finetuned_path="path-to-finetune-stable-diffusion-checkpoint" \ 39 | --image_path="./assets/example_inputs/health.jpg" \ 40 | --mask_path="./assets/example_inputs/mask.png" \ 41 | --prompt="clinical-reports-about-any-diseases" \ 42 | --step=10 \ 43 | --strength=0.5 \ 44 | --guidance_scale=27.5 \ 45 | --seed=42 \ 46 | --resolution=512 47 | ``` 48 | 49 | ## Reference 50 | 51 | ``` 52 | @misc{liang2023pie, 53 | title={PIE: Simulating Disease Progression via Progressive Image Editing}, 54 | author={Kaizhao Liang and Xu Cao and Kuei-Da Liao and Tianren Gao and Wenqian Ye and Zhengyu Chen and Jianguo Cao and Tejas Nama and Jimeng Sun}, 55 | year={2023}, 56 | eprint={2309.11745}, 57 | archivePrefix={arXiv}, 58 | primaryClass={eess.IV} 59 | } 60 | ``` 61 | 62 | ## Development Timeline 63 | 64 | `10/01/2023` Release a new pretrained weight for MIMIC-CXR dataset [here](https://huggingface.co/IrohXu/stable-diffusion-mimic-cxr-v0.1). 65 | `11/15/2023` Embed PIE with GPT-4V or LLaVA and release checkpoint demo. Update inference pipeline. 66 | `11/28/2023` Kaizhao Liang and Wenqian Ye will present PIE at [UVa AIML Seminar](https://uvaml.github.io/). 67 | `2024` Release PIE v2 and new checkpoint. 
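## Plot Confidence Scores

After generating a progression with `run_pie.py`, the disease-confidence curve (as in `assets/progression/confidence.png`) can be drawn with `plot_confidence.py`. A minimal sketch, assuming the progression frames `0.png` ... `10.png` were saved to the default `./simulation` folder and a finetuned confidence classifier checkpoint (see `metric/confidence_model/`) is available:

```
python plot_confidence.py \
    --model_path="path-to-finetuned-confidence-model.pth" \
    --image_dir="./simulation" \
    --plot_path="./plot.png" \
    --num=10 \
    --resolution=224
```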
68 | 69 | 70 | 71 | -------------------------------------------------------------------------------- /assets/example_inputs/health.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PediaMedAI/PIE/3a36495b4c933c89a7b14bce7857b04c9afcb307/assets/example_inputs/health.jpg -------------------------------------------------------------------------------- /assets/example_inputs/mask.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PediaMedAI/PIE/3a36495b4c933c89a7b14bce7857b04c9afcb307/assets/example_inputs/mask.png -------------------------------------------------------------------------------- /assets/paper/pie.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PediaMedAI/PIE/3a36495b4c933c89a7b14bce7857b04c9afcb307/assets/paper/pie.jpg -------------------------------------------------------------------------------- /assets/paper/pipeline.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PediaMedAI/PIE/3a36495b4c933c89a7b14bce7857b04c9afcb307/assets/paper/pipeline.jpg -------------------------------------------------------------------------------- /assets/progression/confidence.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PediaMedAI/PIE/3a36495b4c933c89a7b14bce7857b04c9afcb307/assets/progression/confidence.png -------------------------------------------------------------------------------- /assets/progression/progression.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PediaMedAI/PIE/3a36495b4c933c89a7b14bce7857b04c9afcb307/assets/progression/progression.gif -------------------------------------------------------------------------------- /checkpoints/checkpoints.txt: -------------------------------------------------------------------------------- 1 | Put Stable Diffusion checkpoints here. -------------------------------------------------------------------------------- /metric/confidence_model/README.md: -------------------------------------------------------------------------------- 1 | The code is copied from the LibAUC code repository.
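The intended workflow, sketched from the scripts in this folder: `pretrain.py` trains a multi-label DenseNet121 on CheXpert with cross-entropy and saves `ce_pretrained_model.pth`; `finetune.py` then loads that checkpoint and finetunes a single-disease classifier with the AUC-margin loss, saving `finetuned_model.pth`, which can be passed to `plot_confidence.py` via `--model_path`. Both scripts hard-code `root = './CheXpert-v1.0'`, so adjust it to point at your local CheXpert copy first.

```
python pretrain.py   # multi-label CE pretraining, saves ce_pretrained_model.pth
python finetune.py   # AUC-margin finetuning on one class, saves finetuned_model.pth
```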
-------------------------------------------------------------------------------- /metric/confidence_model/finetune.py: -------------------------------------------------------------------------------- 1 | from libauc.losses import AUCMLoss, CrossEntropyLoss 2 | from libauc.optimizers import PESG, Adam 3 | from libauc.models import densenet121 as DenseNet121 4 | from libauc.datasets import CheXpert 5 | 6 | import torch 7 | from PIL import Image 8 | import numpy as np 9 | import torchvision.transforms as transforms 10 | from torch.utils.data import Dataset 11 | from sklearn.metrics import roc_auc_score 12 | 13 | def set_all_seeds(SEED): 14 | # REPRODUCIBILITY 15 | torch.manual_seed(SEED) 16 | np.random.seed(SEED) 17 | torch.backends.cudnn.deterministic = True 18 | torch.backends.cudnn.benchmark = False 19 | 20 | # parameters 21 | class_id = 0 # 0:Cardiomegaly, 1:Edema, 2:Consolidation, 3:Atelectasis, 4:Pleural Effusion 22 | root = './CheXpert-v1.0' 23 | 24 | # paramaters 25 | SEED = 123 26 | BATCH_SIZE = 128 27 | lr = 0.01 # using smaller learning rate is better 28 | epoch_decay = 2e-3 29 | weight_decay = 1e-5 30 | margin = 1.0 31 | 32 | # You can set use_upsampling=True and pass the class name by upsampling_cols=['Cardiomegaly'] to do upsampling. This may improve the performance 33 | traindSet = CheXpert(csv_path=root+'train.csv', image_root_path=root, use_upsampling=True, use_frontal=True, image_size=224, mode='train', class_index=class_id) 34 | testSet = CheXpert(csv_path=root+'valid.csv', image_root_path=root, use_upsampling=False, use_frontal=True, image_size=224, mode='valid', class_index=class_id) 35 | trainloader = torch.utils.data.DataLoader(traindSet, batch_size=BATCH_SIZE, num_workers=2, shuffle=True) 36 | testloader = torch.utils.data.DataLoader(testSet, batch_size=BATCH_SIZE, num_workers=2, shuffle=False) 37 | 38 | imratio = traindSet.imratio 39 | 40 | # model 41 | set_all_seeds(SEED) 42 | model = DenseNet121(pretrained=False, last_activation=None, activations='relu', num_classes=1) 43 | model = model.cuda() 44 | 45 | 46 | # load pretrained model 47 | if True: 48 | PATH = './ce_pretrained_model.pth' 49 | state_dict = torch.load(PATH) 50 | state_dict.pop('classifier.weight', None) 51 | state_dict.pop('classifier.bias', None) 52 | model.load_state_dict(state_dict, strict=False) 53 | 54 | # define loss & optimizer 55 | loss_fn = AUCMLoss() 56 | optimizer = PESG(model, 57 | loss_fn=loss_fn, 58 | lr=lr, 59 | margin=margin, 60 | epoch_decay=epoch_decay, 61 | weight_decay=weight_decay) 62 | 63 | best_val_auc = 0 64 | for epoch in range(10): 65 | if epoch > 0: 66 | optimizer.update_regularizer(decay_factor=2) 67 | for idx, data in enumerate(trainloader): 68 | train_data, train_labels = data 69 | train_data, train_labels = train_data.cuda(), train_labels.cuda() 70 | y_pred = model(train_data) 71 | y_pred = torch.sigmoid(y_pred) 72 | loss = loss_fn(y_pred, train_labels) 73 | optimizer.zero_grad() 74 | loss.backward() 75 | optimizer.step() 76 | 77 | # validation 78 | if idx % 100 == 0: 79 | model.eval() 80 | with torch.no_grad(): 81 | test_pred = [] 82 | test_true = [] 83 | for jdx, data in enumerate(testloader): 84 | test_data, test_label = data 85 | test_data = test_data.cuda() 86 | y_pred = model(test_data) 87 | test_pred.append(y_pred.cpu().detach().numpy()) 88 | test_true.append(test_label.numpy()) 89 | 90 | test_true = np.concatenate(test_true) 91 | test_pred = np.concatenate(test_pred) 92 | val_auc = roc_auc_score(test_true, test_pred) 93 | model.train() 94 | 95 | if best_val_auc < 
val_auc: 96 | best_val_auc = val_auc 97 | torch.save(model.state_dict(), 'finetuned_model.pth') 98 | 99 | print ('Epoch=%s, BatchID=%s, Val_AUC=%.4f, lr=%.4f'%(epoch, idx, val_auc, optimizer.lr)) 100 | 101 | print ('Best Val_AUC is %.4f'%best_val_auc) -------------------------------------------------------------------------------- /metric/confidence_model/pretrain.py: -------------------------------------------------------------------------------- 1 | from libauc.losses import AUCMLoss, CrossEntropyLoss 2 | from libauc.optimizers import PESG, Adam 3 | from libauc.models import densenet121 as DenseNet121 4 | from libauc.datasets import CheXpert 5 | 6 | import torch 7 | from PIL import Image 8 | import numpy as np 9 | import torchvision.transforms as transforms 10 | from torch.utils.data import Dataset 11 | from sklearn.metrics import roc_auc_score 12 | 13 | 14 | def set_all_seeds(SEED): 15 | # REPRODUCIBILITY 16 | torch.manual_seed(SEED) 17 | np.random.seed(SEED) 18 | torch.backends.cudnn.deterministic = True 19 | torch.backends.cudnn.benchmark = False 20 | 21 | # paramaters 22 | SEED = 123 23 | BATCH_SIZE = 128 24 | lr = 1e-4 25 | weight_decay = 1e-5 26 | 27 | # dataloader 28 | root = './CheXpert-v1.0' 29 | # Index: -1 denotes multi-label mode including 5 diseases 30 | traindSet = CheXpert(csv_path=root+'train.csv', image_root_path=root, use_upsampling=False, use_frontal=True, image_size=224, mode='train', class_index=-1) 31 | testSet = CheXpert(csv_path=root+'valid.csv', image_root_path=root, use_upsampling=False, use_frontal=True, image_size=224, mode='valid', class_index=-1) 32 | trainloader = torch.utils.data.DataLoader(traindSet, batch_size=BATCH_SIZE, num_workers=2, shuffle=True) 33 | testloader = torch.utils.data.DataLoader(testSet, batch_size=BATCH_SIZE, num_workers=2, shuffle=False) 34 | 35 | # model 36 | set_all_seeds(SEED) 37 | model = DenseNet121(pretrained=True, last_activation=None, activations='relu', num_classes=5) 38 | model = model.cuda() 39 | 40 | # define loss & optimizer 41 | CELoss = CrossEntropyLoss() 42 | optimizer = Adam(model.parameters(), lr=lr, weight_decay=weight_decay) 43 | 44 | # training 45 | best_val_auc = 0 46 | for epoch in range(10): 47 | for idx, data in enumerate(trainloader): 48 | train_data, train_labels = data 49 | train_data, train_labels = train_data.cuda(), train_labels.cuda() 50 | y_pred = model(train_data) 51 | loss = CELoss(y_pred, train_labels) 52 | optimizer.zero_grad() 53 | loss.backward() 54 | optimizer.step() 55 | 56 | # validation 57 | if idx % 100 == 0: 58 | model.eval() 59 | with torch.no_grad(): 60 | test_pred = [] 61 | test_true = [] 62 | for jdx, data in enumerate(testloader): 63 | test_data, test_labels = data 64 | test_data = test_data.cuda() 65 | y_pred = model(test_data) 66 | test_pred.append(y_pred.cpu().detach().numpy()) 67 | test_true.append(test_labels.numpy()) 68 | 69 | test_true = np.concatenate(test_true) 70 | test_pred = np.concatenate(test_pred) 71 | val_auc_mean = roc_auc_score(test_true, test_pred) 72 | model.train() 73 | 74 | if best_val_auc < val_auc_mean: 75 | best_val_auc = val_auc_mean 76 | torch.save(model.state_dict(), 'ce_pretrained_model.pth') 77 | 78 | print ('Epoch=%s, BatchID=%s, Val_AUC=%.4f, Best_Val_AUC=%.4f'%(epoch, idx, val_auc_mean, best_val_auc )) -------------------------------------------------------------------------------- /pipeline_stable_diffusion_pie.py: -------------------------------------------------------------------------------- 1 | # Copyright 2023 Anonymous. All rights reserved. 
2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | import inspect 16 | from typing import Callable, List, Optional, Union 17 | import copy 18 | 19 | import numpy as np 20 | import cv2 21 | import PIL 22 | import torch 23 | from packaging import version 24 | from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer 25 | 26 | from diffusers.configuration_utils import FrozenDict 27 | from diffusers.models import AutoencoderKL, UNet2DConditionModel 28 | from diffusers.schedulers import KarrasDiffusionSchedulers 29 | from diffusers.utils import ( 30 | PIL_INTERPOLATION, 31 | deprecate, 32 | is_accelerate_available, 33 | is_accelerate_version, 34 | logging, 35 | randn_tensor, 36 | replace_example_docstring, 37 | ) 38 | from diffusers.pipelines.pipeline_utils import DiffusionPipeline 39 | from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput 40 | from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker 41 | 42 | 43 | logger = logging.get_logger(__name__) # pylint: disable=invalid-name 44 | 45 | EXAMPLE_DOC_STRING = """ 46 | Examples: 47 | ```py 48 | >>> import requests 49 | >>> import torch 50 | >>> from PIL import Image 51 | >>> from io import BytesIO 52 | 53 | >>> from diffusers import StableDiffusionImg2ImgPipeline 54 | 55 | >>> device = "cuda" 56 | >>> model_id_or_path = "runwayml/stable-diffusion-v1-5" 57 | >>> pipe = StableDiffusionImg2ImgPipeline.from_pretrained(model_id_or_path, torch_dtype=torch.float16) 58 | >>> pipe = pipe.to(device) 59 | 60 | >>> url = "https://raw.githubusercontent.com/CompVis/stable-diffusion/main/assets/stable-samples/img2img/sketch-mountains-input.jpg" 61 | 62 | >>> response = requests.get(url) 63 | >>> init_image = Image.open(BytesIO(response.content)).convert("RGB") 64 | >>> init_image = init_image.resize((768, 512)) 65 | 66 | >>> prompt = "A fantasy landscape, trending on artstation" 67 | 68 | >>> images = pipe(prompt=prompt, image=init_image, strength=0.75, guidance_scale=7.5).images 69 | >>> images[0].save("fantasy_landscape.png") 70 | ``` 71 | """ 72 | 73 | def preprocess(image): 74 | if isinstance(image, torch.Tensor): 75 | return image 76 | elif isinstance(image, PIL.Image.Image): 77 | image = [image] 78 | 79 | if isinstance(image[0], PIL.Image.Image): 80 | w, h = image[0].size 81 | w, h = map(lambda x: x - x % 8, (w, h)) # resize to integer multiple of 8 82 | 83 | image = [np.array(i.resize((w, h), resample=PIL_INTERPOLATION["lanczos"]))[None, :] for i in image] 84 | image = np.concatenate(image, axis=0) 85 | image = np.array(image).astype(np.float32) / 255.0 86 | image = image.transpose(0, 3, 1, 2) 87 | image = 2.0 * image - 1.0 88 | image = torch.from_numpy(image) 89 | elif isinstance(image[0], torch.Tensor): 90 | image = torch.cat(image, dim=0) 91 | return image 92 | 93 | 94 | def preprocess_mask(image): 95 | if isinstance(image, torch.Tensor): 96 | return image 97 | elif isinstance(image, PIL.Image.Image): 98 | image = [image] 99 | 
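    # Unlike preprocess(), the mask is resized down to the latent resolution (w // 8, h // 8) and a
    # copy of its first channel is appended, so an RGB mask becomes 4 channels matching the VAE
    # latents; values stay in [0, 1] rather than being shifted to [-1, 1].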
100 | if isinstance(image[0], PIL.Image.Image): 101 | w, h = image[0].size 102 | w, h = map(lambda x: x - x % 8, (w, h)) # resize to integer multiple of 8 103 | 104 | image = [np.array(i.resize((w, h), resample=PIL_INTERPOLATION["lanczos"]))[None, :] for i in image] 105 | 106 | image2 = [] 107 | for i in image: 108 | tmp = cv2.resize(i[0], (w // 8, h // 8), interpolation = cv2.INTER_LINEAR) 109 | tmp = np.expand_dims(tmp, axis=0) 110 | additional_channel = np.expand_dims(tmp[:,:,:,0], axis=-1) 111 | tmp = [tmp, additional_channel] 112 | tmp = np.concatenate(tmp, axis=-1) 113 | image2.append(tmp) 114 | 115 | image = np.concatenate(image2, axis=0) 116 | image = np.array(image).astype(np.float32) / 255.0 117 | image = image.transpose(0, 3, 1, 2) 118 | # image = 2.0 * image - 1.0 119 | image = torch.from_numpy(image) 120 | elif isinstance(image[0], torch.Tensor): 121 | image = torch.cat(image, dim=0) 122 | return image 123 | 124 | 125 | class StableDiffusionPIEPipeline(DiffusionPipeline): 126 | r""" 127 | Pipeline for Progressive Image Editing (PIE) disease progression generation using Stable Diffusion. 128 | 129 | This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the 130 | library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.) 131 | 132 | Args: 133 | vae ([`AutoencoderKL`]): 134 | Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations. 135 | text_encoder ([`CLIPTextModel`]): 136 | Frozen text-encoder. Stable Diffusion uses the text portion of 137 | [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically 138 | the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant. 139 | tokenizer (`CLIPTokenizer`): 140 | Tokenizer of class 141 | [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer). 142 | unet ([`UNet2DConditionModel`]): Conditional U-Net architecture to denoise the encoded image latents. 143 | scheduler ([`SchedulerMixin`]): 144 | A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of 145 | [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`]. 146 | safety_checker ([`StableDiffusionSafetyChecker`]): 147 | Classification module that estimates whether generated images could be considered offensive or harmful. 148 | Please, refer to the [model card](https://huggingface.co/runwayml/stable-diffusion-v1-5) for details. 149 | feature_extractor ([`CLIPFeatureExtractor`]): 150 | Model that extracts features from generated images to be used as inputs for the `safety_checker`. 151 | """ 152 | _optional_components = ["safety_checker", "feature_extractor"] 153 | 154 | # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.__init__ 155 | def __init__( 156 | self, 157 | vae: AutoencoderKL, 158 | text_encoder: CLIPTextModel, 159 | tokenizer: CLIPTokenizer, 160 | unet: UNet2DConditionModel, 161 | scheduler: KarrasDiffusionSchedulers, 162 | safety_checker: StableDiffusionSafetyChecker, 163 | feature_extractor: CLIPFeatureExtractor, 164 | requires_safety_checker: bool = True, 165 | ): 166 | super().__init__() 167 | 168 | if hasattr(scheduler.config, "steps_offset") and scheduler.config.steps_offset != 1: 169 | deprecation_message = ( 170 | f"The configuration file of this scheduler: {scheduler} is outdated. 
`steps_offset`" 171 | f" should be set to 1 instead of {scheduler.config.steps_offset}. Please make sure " 172 | "to update the config accordingly as leaving `steps_offset` might led to incorrect results" 173 | " in future versions. If you have downloaded this checkpoint from the Hugging Face Hub," 174 | " it would be very nice if you could open a Pull request for the `scheduler/scheduler_config.json`" 175 | " file" 176 | ) 177 | deprecate("steps_offset!=1", "1.0.0", deprecation_message, standard_warn=False) 178 | new_config = dict(scheduler.config) 179 | new_config["steps_offset"] = 1 180 | scheduler._internal_dict = FrozenDict(new_config) 181 | 182 | if hasattr(scheduler.config, "clip_sample") and scheduler.config.clip_sample is True: 183 | deprecation_message = ( 184 | f"The configuration file of this scheduler: {scheduler} has not set the configuration `clip_sample`." 185 | " `clip_sample` should be set to False in the configuration file. Please make sure to update the" 186 | " config accordingly as not setting `clip_sample` in the config might lead to incorrect results in" 187 | " future versions. If you have downloaded this checkpoint from the Hugging Face Hub, it would be very" 188 | " nice if you could open a Pull request for the `scheduler/scheduler_config.json` file" 189 | ) 190 | deprecate("clip_sample not set", "1.0.0", deprecation_message, standard_warn=False) 191 | new_config = dict(scheduler.config) 192 | new_config["clip_sample"] = False 193 | scheduler._internal_dict = FrozenDict(new_config) 194 | 195 | if safety_checker is None and requires_safety_checker: 196 | logger.warning( 197 | f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure" 198 | " that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered" 199 | " results in services or applications open to the public. Both the diffusers team and Hugging Face" 200 | " strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling" 201 | " it only for use-cases that involve analyzing network behavior or auditing its results. For more" 202 | " information, please have a look at https://github.com/huggingface/diffusers/pull/254 ." 203 | ) 204 | 205 | if safety_checker is not None and feature_extractor is None: 206 | raise ValueError( 207 | "Make sure to define a feature extractor when loading {self.__class__} if you want to use the safety" 208 | " checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead." 209 | ) 210 | 211 | is_unet_version_less_0_9_0 = hasattr(unet.config, "_diffusers_version") and version.parse( 212 | version.parse(unet.config._diffusers_version).base_version 213 | ) < version.parse("0.9.0.dev0") 214 | is_unet_sample_size_less_64 = hasattr(unet.config, "sample_size") and unet.config.sample_size < 64 215 | if is_unet_version_less_0_9_0 and is_unet_sample_size_less_64: 216 | deprecation_message = ( 217 | "The configuration file of the unet has set the default `sample_size` to smaller than" 218 | " 64 which seems highly unlikely. If your checkpoint is a fine-tuned version of any of the" 219 | " following: \n- CompVis/stable-diffusion-v1-4 \n- CompVis/stable-diffusion-v1-3 \n-" 220 | " CompVis/stable-diffusion-v1-2 \n- CompVis/stable-diffusion-v1-1 \n- runwayml/stable-diffusion-v1-5" 221 | " \n- runwayml/stable-diffusion-inpainting \n you should change 'sample_size' to 64 in the" 222 | " configuration file. 
Please make sure to update the config accordingly as leaving `sample_size=32`" 223 | " in the config might lead to incorrect results in future versions. If you have downloaded this" 224 | " checkpoint from the Hugging Face Hub, it would be very nice if you could open a Pull request for" 225 | " the `unet/config.json` file" 226 | ) 227 | deprecate("sample_size<64", "1.0.0", deprecation_message, standard_warn=False) 228 | new_config = dict(unet.config) 229 | new_config["sample_size"] = 64 230 | unet._internal_dict = FrozenDict(new_config) 231 | 232 | self.register_modules( 233 | vae=vae, 234 | text_encoder=text_encoder, 235 | tokenizer=tokenizer, 236 | unet=unet, 237 | scheduler=scheduler, 238 | safety_checker=safety_checker, 239 | feature_extractor=feature_extractor, 240 | ) 241 | self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) 242 | self.register_to_config(requires_safety_checker=requires_safety_checker) 243 | 244 | # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_sequential_cpu_offload 245 | def enable_sequential_cpu_offload(self, gpu_id=0): 246 | r""" 247 | Offloads all models to CPU using accelerate, significantly reducing memory usage. When called, unet, 248 | text_encoder, vae and safety checker have their state dicts saved to CPU and then are moved to a 249 | `torch.device('meta') and loaded to GPU only when their specific submodule has its `forward` method called. 250 | Note that offloading happens on a submodule basis. Memory savings are higher than with 251 | `enable_model_cpu_offload`, but performance is lower. 252 | """ 253 | if is_accelerate_available() and is_accelerate_version(">=", "0.14.0"): 254 | from accelerate import cpu_offload 255 | else: 256 | raise ImportError("`enable_sequential_cpu_offload` requires `accelerate v0.14.0` or higher") 257 | 258 | device = torch.device(f"cuda:{gpu_id}") 259 | 260 | if self.device.type != "cpu": 261 | self.to("cpu", silence_dtype_warnings=True) 262 | torch.cuda.empty_cache() # otherwise we don't see the memory savings (but they probably exist) 263 | 264 | for cpu_offloaded_model in [self.unet, self.text_encoder, self.vae]: 265 | cpu_offload(cpu_offloaded_model, device) 266 | 267 | if self.safety_checker is not None: 268 | cpu_offload(self.safety_checker, execution_device=device, offload_buffers=True) 269 | 270 | # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_model_cpu_offload 271 | def enable_model_cpu_offload(self, gpu_id=0): 272 | r""" 273 | Offloads all models to CPU using accelerate, reducing memory usage with a low impact on performance. Compared 274 | to `enable_sequential_cpu_offload`, this method moves one whole model at a time to the GPU when its `forward` 275 | method is called, and the model remains in GPU until the next model runs. Memory savings are lower than with 276 | `enable_sequential_cpu_offload`, but performance is much better due to the iterative execution of the `unet`. 
277 | """ 278 | if is_accelerate_available() and is_accelerate_version(">=", "0.17.0.dev0"): 279 | from accelerate import cpu_offload_with_hook 280 | else: 281 | raise ImportError("`enable_model_offload` requires `accelerate v0.17.0` or higher.") 282 | 283 | device = torch.device(f"cuda:{gpu_id}") 284 | 285 | if self.device.type != "cpu": 286 | self.to("cpu", silence_dtype_warnings=True) 287 | torch.cuda.empty_cache() # otherwise we don't see the memory savings (but they probably exist) 288 | 289 | hook = None 290 | for cpu_offloaded_model in [self.text_encoder, self.unet, self.vae]: 291 | _, hook = cpu_offload_with_hook(cpu_offloaded_model, device, prev_module_hook=hook) 292 | 293 | if self.safety_checker is not None: 294 | _, hook = cpu_offload_with_hook(self.safety_checker, device, prev_module_hook=hook) 295 | 296 | # We'll offload the last model manually. 297 | self.final_offload_hook = hook 298 | 299 | @property 300 | # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline._execution_device 301 | def _execution_device(self): 302 | r""" 303 | Returns the device on which the pipeline's models will be executed. After calling 304 | `pipeline.enable_sequential_cpu_offload()` the execution device can only be inferred from Accelerate's module 305 | hooks. 306 | """ 307 | if not hasattr(self.unet, "_hf_hook"): 308 | return self.device 309 | for module in self.unet.modules(): 310 | if ( 311 | hasattr(module, "_hf_hook") 312 | and hasattr(module._hf_hook, "execution_device") 313 | and module._hf_hook.execution_device is not None 314 | ): 315 | return torch.device(module._hf_hook.execution_device) 316 | return self.device 317 | 318 | # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline._encode_prompt 319 | def _encode_prompt( 320 | self, 321 | prompt, 322 | device, 323 | num_images_per_prompt, 324 | do_classifier_free_guidance, 325 | negative_prompt=None, 326 | prompt_embeds: Optional[torch.FloatTensor] = None, 327 | negative_prompt_embeds: Optional[torch.FloatTensor] = None, 328 | ): 329 | r""" 330 | Encodes the prompt into text encoder hidden states. 331 | 332 | Args: 333 | prompt (`str` or `List[str]`, *optional*): 334 | prompt to be encoded 335 | device: (`torch.device`): 336 | torch device 337 | num_images_per_prompt (`int`): 338 | number of images that should be generated per prompt 339 | do_classifier_free_guidance (`bool`): 340 | whether to use classifier free guidance or not 341 | negative_prompt (`str` or `List[str]`, *optional*): 342 | The prompt or prompts not to guide the image generation. If not defined, one has to pass 343 | `negative_prompt_embeds`. instead. If not defined, one has to pass `negative_prompt_embeds`. instead. 344 | Ignored when not using guidance (i.e., ignored if `guidance_scale` is less than `1`). 345 | prompt_embeds (`torch.FloatTensor`, *optional*): 346 | Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not 347 | provided, text embeddings will be generated from `prompt` input argument. 348 | negative_prompt_embeds (`torch.FloatTensor`, *optional*): 349 | Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt 350 | weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input 351 | argument. 
352 | """ 353 | if prompt is not None and isinstance(prompt, str): 354 | batch_size = 1 355 | elif prompt is not None and isinstance(prompt, list): 356 | batch_size = len(prompt) 357 | else: 358 | batch_size = prompt_embeds.shape[0] 359 | 360 | if prompt_embeds is None: 361 | text_inputs = self.tokenizer( 362 | prompt, 363 | padding="max_length", 364 | max_length=self.tokenizer.model_max_length, 365 | truncation=True, 366 | return_tensors="pt", 367 | ) 368 | text_input_ids = text_inputs.input_ids 369 | untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids 370 | 371 | if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal( 372 | text_input_ids, untruncated_ids 373 | ): 374 | removed_text = self.tokenizer.batch_decode( 375 | untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1] 376 | ) 377 | logger.warning( 378 | "The following part of your input was truncated because CLIP can only handle sequences up to" 379 | f" {self.tokenizer.model_max_length} tokens: {removed_text}" 380 | ) 381 | 382 | if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask: 383 | attention_mask = text_inputs.attention_mask.to(device) 384 | else: 385 | attention_mask = None 386 | 387 | prompt_embeds = self.text_encoder( 388 | text_input_ids.to(device), 389 | attention_mask=attention_mask, 390 | ) 391 | prompt_embeds = prompt_embeds[0] 392 | 393 | prompt_embeds = prompt_embeds.to(dtype=self.text_encoder.dtype, device=device) 394 | 395 | bs_embed, seq_len, _ = prompt_embeds.shape 396 | # duplicate text embeddings for each generation per prompt, using mps friendly method 397 | prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) 398 | prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1) 399 | 400 | # get unconditional embeddings for classifier free guidance 401 | if do_classifier_free_guidance and negative_prompt_embeds is None: 402 | uncond_tokens: List[str] 403 | if negative_prompt is None: 404 | uncond_tokens = [""] * batch_size 405 | elif type(prompt) is not type(negative_prompt): 406 | raise TypeError( 407 | f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" 408 | f" {type(prompt)}." 409 | ) 410 | elif isinstance(negative_prompt, str): 411 | uncond_tokens = [negative_prompt] 412 | elif batch_size != len(negative_prompt): 413 | raise ValueError( 414 | f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" 415 | f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" 416 | " the batch size of `prompt`." 
417 | ) 418 | else: 419 | uncond_tokens = negative_prompt 420 | 421 | max_length = prompt_embeds.shape[1] 422 | uncond_input = self.tokenizer( 423 | uncond_tokens, 424 | padding="max_length", 425 | max_length=max_length, 426 | truncation=True, 427 | return_tensors="pt", 428 | ) 429 | 430 | if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask: 431 | attention_mask = uncond_input.attention_mask.to(device) 432 | else: 433 | attention_mask = None 434 | 435 | negative_prompt_embeds = self.text_encoder( 436 | uncond_input.input_ids.to(device), 437 | attention_mask=attention_mask, 438 | ) 439 | negative_prompt_embeds = negative_prompt_embeds[0] 440 | 441 | if do_classifier_free_guidance: 442 | # duplicate unconditional embeddings for each generation per prompt, using mps friendly method 443 | seq_len = negative_prompt_embeds.shape[1] 444 | 445 | negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.text_encoder.dtype, device=device) 446 | 447 | negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1) 448 | negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) 449 | 450 | # For classifier free guidance, we need to do two forward passes. 451 | # Here we concatenate the unconditional and text embeddings into a single batch 452 | # to avoid doing two forward passes 453 | prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds]) 454 | 455 | return prompt_embeds 456 | 457 | # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.run_safety_checker 458 | def run_safety_checker(self, image, device, dtype): 459 | if self.safety_checker is not None: 460 | safety_checker_input = self.feature_extractor(self.numpy_to_pil(image), return_tensors="pt").to(device) 461 | image, has_nsfw_concept = self.safety_checker( 462 | images=image, clip_input=safety_checker_input.pixel_values.to(dtype) 463 | ) 464 | else: 465 | has_nsfw_concept = None 466 | return image, has_nsfw_concept 467 | 468 | # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.decode_latents 469 | def decode_latents(self, latents): 470 | latents = 1 / self.vae.config.scaling_factor * latents 471 | image = self.vae.decode(latents).sample 472 | image = (image / 2 + 0.5).clamp(0, 1) 473 | # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16 474 | image = image.cpu().permute(0, 2, 3, 1).float().numpy() 475 | return image 476 | 477 | # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs 478 | def prepare_extra_step_kwargs(self, generator, eta): 479 | # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature 480 | # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers. 
481 | # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502 482 | # and should be between [0, 1] 483 | 484 | accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys()) 485 | extra_step_kwargs = {} 486 | if accepts_eta: 487 | extra_step_kwargs["eta"] = eta 488 | 489 | # check if the scheduler accepts generator 490 | accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys()) 491 | if accepts_generator: 492 | extra_step_kwargs["generator"] = generator 493 | return extra_step_kwargs 494 | 495 | def check_inputs( 496 | self, prompt, strength, callback_steps, negative_prompt=None, prompt_embeds=None, negative_prompt_embeds=None 497 | ): 498 | if strength < 0 or strength > 1: 499 | raise ValueError(f"The value of strength should in [0.0, 1.0] but is {strength}") 500 | 501 | if (callback_steps is None) or ( 502 | callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0) 503 | ): 504 | raise ValueError( 505 | f"`callback_steps` has to be a positive integer but is {callback_steps} of type" 506 | f" {type(callback_steps)}." 507 | ) 508 | 509 | if prompt is not None and prompt_embeds is not None: 510 | raise ValueError( 511 | f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" 512 | " only forward one of the two." 513 | ) 514 | elif prompt is None and prompt_embeds is None: 515 | raise ValueError( 516 | "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." 517 | ) 518 | elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): 519 | raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") 520 | 521 | if negative_prompt is not None and negative_prompt_embeds is not None: 522 | raise ValueError( 523 | f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:" 524 | f" {negative_prompt_embeds}. Please make sure to only forward one of the two." 525 | ) 526 | 527 | if prompt_embeds is not None and negative_prompt_embeds is not None: 528 | if prompt_embeds.shape != negative_prompt_embeds.shape: 529 | raise ValueError( 530 | "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but" 531 | f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`" 532 | f" {negative_prompt_embeds.shape}." 533 | ) 534 | 535 | def get_timesteps(self, num_inference_steps, strength, device): 536 | # get the original timestep using init_timestep 537 | init_timestep = min(int(num_inference_steps * strength), num_inference_steps) 538 | 539 | t_start = max(num_inference_steps - init_timestep, 0) 540 | timesteps = self.scheduler.timesteps[t_start:] 541 | 542 | return timesteps, num_inference_steps - t_start 543 | 544 | def prepare_latents(self, image, mask, timestep, batch_size, num_images_per_prompt, dtype, device, generator=None): 545 | if not isinstance(image, (torch.Tensor, PIL.Image.Image, list)): 546 | raise ValueError( 547 | f"`image` has to be of type `torch.Tensor`, `PIL.Image.Image` or list but is {type(image)}" 548 | ) 549 | 550 | image = image.to(device=device, dtype=dtype) 551 | 552 | batch_size = batch_size * num_images_per_prompt 553 | if isinstance(generator, list) and len(generator) != batch_size: 554 | raise ValueError( 555 | f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" 556 | f" size of {batch_size}. 
Make sure the batch size matches the length of the generators." 557 | ) 558 | 559 | if isinstance(generator, list): 560 | init_latents = [ 561 | self.vae.encode(image[i : i + 1]).latent_dist.sample(generator[i]) for i in range(batch_size) 562 | ] 563 | init_latents = torch.cat(init_latents, dim=0) 564 | else: 565 | init_latents = self.vae.encode(image).latent_dist.sample(generator) 566 | 567 | init_latents = self.vae.config.scaling_factor * init_latents 568 | pre_latents = copy.deepcopy(init_latents) 569 | 570 | if batch_size > init_latents.shape[0] and batch_size % init_latents.shape[0] == 0: 571 | # expand init_latents for batch_size 572 | deprecation_message = ( 573 | f"You have passed {batch_size} text prompts (`prompt`), but only {init_latents.shape[0]} initial" 574 | " images (`image`). Initial images are now duplicating to match the number of text prompts. Note" 575 | " that this behavior is deprecated and will be removed in a version 1.0.0. Please make sure to update" 576 | " your script to pass as many initial images as text prompts to suppress this warning." 577 | ) 578 | deprecate("len(prompt) != len(image)", "1.0.0", deprecation_message, standard_warn=False) 579 | additional_image_per_prompt = batch_size // init_latents.shape[0] 580 | init_latents = torch.cat([init_latents] * additional_image_per_prompt, dim=0) 581 | elif batch_size > init_latents.shape[0] and batch_size % init_latents.shape[0] != 0: 582 | raise ValueError( 583 | f"Cannot duplicate `image` of batch size {init_latents.shape[0]} to {batch_size} text prompts." 584 | ) 585 | else: 586 | init_latents = torch.cat([init_latents], dim=0) 587 | 588 | shape = init_latents.shape 589 | 590 | if mask == None: 591 | noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype) 592 | else: 593 | mask = mask.to(device=device, dtype=dtype) 594 | noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype) * mask 595 | 596 | # get latents 597 | init_latents = self.scheduler.add_noise(init_latents, noise, timestep) 598 | latents = init_latents 599 | 600 | return latents, pre_latents 601 | 602 | @torch.no_grad() 603 | @replace_example_docstring(EXAMPLE_DOC_STRING) 604 | def __call__( 605 | self, 606 | prompt: Union[str, List[str]] = None, 607 | image: Union[torch.FloatTensor, PIL.Image.Image] = None, 608 | mask: Union[torch.FloatTensor, PIL.Image.Image] = None, 609 | init_image: Union[torch.FloatTensor, PIL.Image.Image] = None, 610 | strength: float = 0.5, 611 | num_inference_steps: Optional[int] = 50, 612 | beta_1: float = 0.01, 613 | beta_2: float = 0.75, 614 | guidance_scale: Optional[float] = 7.5, 615 | negative_prompt: Optional[Union[str, List[str]]] = None, 616 | num_images_per_prompt: Optional[int] = 1, 617 | eta: Optional[float] = 0.0, 618 | generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, 619 | prompt_embeds: Optional[torch.FloatTensor] = None, 620 | negative_prompt_embeds: Optional[torch.FloatTensor] = None, 621 | output_type: Optional[str] = "pil", 622 | return_dict: bool = True, 623 | callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None, 624 | callback_steps: int = 1, 625 | ): 626 | r""" 627 | Function invoked when calling the pipeline for generation. 628 | 629 | Args: 630 | prompt (`str` or `List[str]`, *optional*): 631 | The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`. 632 | instead. 
633 | image (`torch.FloatTensor` or `PIL.Image.Image`): 634 | `Image`, or tensor representing an image batch, that will be used as the starting point for the 635 | process. 636 | strength (`float`, *optional*, defaults to 0.8): 637 | Conceptually, indicates how much to transform the reference `image`. Must be between 0 and 1. `image` 638 | will be used as a starting point, adding more noise to it the larger the `strength`. The number of 639 | denoising steps depends on the amount of noise initially added. When `strength` is 1, added noise will 640 | be maximum and the denoising process will run for the full number of iterations specified in 641 | `num_inference_steps`. A value of 1, therefore, essentially ignores `image`. 642 | num_inference_steps (`int`, *optional*, defaults to 50): 643 | The number of denoising steps. More denoising steps usually lead to a higher quality image at the 644 | expense of slower inference. This parameter will be modulated by `strength`. 645 | beta 1 (`float`, *optional*, defaults to 0.01): 646 | indicates how much to control the smooth and realism of output edited images. 647 | beta 2 (`float`, *optional*, defaults to 0.75): 648 | indicates how much to control the smooth and realism of output edited images. 649 | guidance_scale (`float`, *optional*, defaults to 7.5): 650 | Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). 651 | `guidance_scale` is defined as `w` of equation 2. of [Imagen 652 | Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale > 653 | 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`, 654 | usually at the expense of lower image quality. 655 | negative_prompt (`str` or `List[str]`, *optional*): 656 | The prompt or prompts not to guide the image generation. If not defined, one has to pass 657 | `negative_prompt_embeds`. instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` 658 | is less than `1`). 659 | num_images_per_prompt (`int`, *optional*, defaults to 1): 660 | The number of images to generate per prompt. 661 | eta (`float`, *optional*, defaults to 0.0): 662 | Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to 663 | [`schedulers.DDIMScheduler`], will be ignored for others. 664 | generator (`torch.Generator`, *optional*): 665 | One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) 666 | to make generation deterministic. 667 | prompt_embeds (`torch.FloatTensor`, *optional*): 668 | Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not 669 | provided, text embeddings will be generated from `prompt` input argument. 670 | negative_prompt_embeds (`torch.FloatTensor`, *optional*): 671 | Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt 672 | weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input 673 | argument. 674 | output_type (`str`, *optional*, defaults to `"pil"`): 675 | The output format of the generate image. Choose between 676 | [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. 677 | return_dict (`bool`, *optional*, defaults to `True`): 678 | Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a 679 | plain tuple. 
680 | callback (`Callable`, *optional*): 681 | A function that will be called every `callback_steps` steps during inference. The function will be 682 | called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`. 683 | callback_steps (`int`, *optional*, defaults to 1): 684 | The frequency at which the `callback` function will be called. If not specified, the callback will be 685 | called at every step. 686 | Examples: 687 | 688 | Returns: 689 | [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`: 690 | [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a `tuple. 691 | When returning a tuple, the first element is a list with the generated images, and the second element is a 692 | list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work" 693 | (nsfw) content, according to the `safety_checker`. 694 | """ 695 | # 1. Check inputs. Raise error if not correct 696 | self.check_inputs(prompt, strength, callback_steps, negative_prompt, prompt_embeds, negative_prompt_embeds) 697 | 698 | # 2. Define call parameters 699 | if prompt is not None and isinstance(prompt, str): 700 | batch_size = 1 701 | elif prompt is not None and isinstance(prompt, list): 702 | batch_size = len(prompt) 703 | else: 704 | batch_size = prompt_embeds.shape[0] 705 | device = self._execution_device 706 | # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) 707 | # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` 708 | # corresponds to doing no classifier free guidance. 709 | do_classifier_free_guidance = guidance_scale > 1.0 710 | 711 | # 3. Encode input prompt 712 | prompt_embeds = self._encode_prompt( 713 | prompt, 714 | device, 715 | num_images_per_prompt, 716 | do_classifier_free_guidance, 717 | negative_prompt, 718 | prompt_embeds=prompt_embeds, 719 | negative_prompt_embeds=negative_prompt_embeds, 720 | ) 721 | 722 | # 4. Preprocess previous step image, the first step image, and the mask. 723 | image = preprocess(image) 724 | init_image = preprocess(init_image) 725 | if mask != None: 726 | mask = preprocess_mask(mask) 727 | 728 | # 5. set timesteps 729 | self.scheduler.set_timesteps(num_inference_steps, device=device) 730 | timesteps, num_inference_steps = self.get_timesteps(num_inference_steps, strength, device) 731 | latent_timestep = timesteps[:1].repeat(batch_size * num_images_per_prompt) 732 | 733 | # 6. Prepare latent variables 734 | latents, pre_latents = self.prepare_latents( 735 | image, None, latent_timestep, batch_size, num_images_per_prompt, prompt_embeds.dtype, device, generator 736 | ) 737 | _, init_latents = self.prepare_latents( 738 | init_image, None, latent_timestep, batch_size, num_images_per_prompt, prompt_embeds.dtype, device, generator 739 | ) 740 | 741 | inter_latents, _ = self.prepare_latents( 742 | init_image, None, timesteps, batch_size, num_images_per_prompt, prompt_embeds.dtype, device, generator 743 | ) 744 | 745 | if mask != None: 746 | mask = mask.to(device=device, dtype=prompt_embeds.dtype) 747 | 748 | # 7. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline 749 | extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) 750 | 751 | # 8. 
Denoising loop 752 | num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order 753 | with self.progress_bar(total=num_inference_steps) as progress_bar: 754 | for i, t in enumerate(timesteps): 755 | 756 | if mask != None and 0 < i < len(timesteps - 1): 757 | latents = inter_latents[i:i+1] * (1 - mask) + latents * (mask) 758 | # ((latents - init_latents) * beta + init_latents) * (1 - mask) + ((latents - init_latents) * alpha + init_latents) * mask 759 | 760 | # expand the latents if we are doing classifier free guidance 761 | latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents 762 | latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) 763 | 764 | # predict the noise residual 765 | noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=prompt_embeds).sample 766 | # noise_pred = noise_pred * mask 767 | 768 | # perform guidance 769 | if do_classifier_free_guidance: 770 | noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) 771 | noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) 772 | 773 | # compute the previous noisy sample x_t -> x_t-1 774 | # pre_latents = copy.deepcopy(latents) 775 | latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample 776 | # latents = latents * (mask) + pre_latents * (1 - mask) 777 | 778 | # call the callback, if provided 779 | if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): 780 | progress_bar.update() 781 | if callback is not None and i % callback_steps == 0: 782 | callback(i, t, latents) 783 | 784 | # 9. Post-processing 785 | if mask != None: 786 | if init_latents != None: 787 | latents = ((latents - init_latents) * beta_1 + init_latents) * (1 - mask) + ((latents - init_latents) * beta_2 + init_latents) * mask 788 | image = self.decode_latents(latents) 789 | else: 790 | latents = ((latents - pre_latents) * beta_1 + pre_latents) * (1 - mask) + ((latents - pre_latents) * beta_2 + pre_latents) * mask 791 | image = self.decode_latents(latents) 792 | else: 793 | image = self.decode_latents(latents) 794 | 795 | # 10. Run safety checker 796 | image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype) 797 | 798 | # 11. 
Convert to PIL 799 | if output_type == "pil": 800 | image = self.numpy_to_pil(image) 801 | 802 | # Offload last model to CPU 803 | if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None: 804 | self.final_offload_hook.offload() 805 | 806 | if not return_dict: 807 | return (image, has_nsfw_concept) 808 | 809 | return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept) -------------------------------------------------------------------------------- /plot_confidence.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | import numpy as np 4 | from PIL import Image 5 | import cv2 6 | import torch 7 | import torch.nn as nn 8 | from torch.utils.data import Dataset 9 | from sklearn.metrics import roc_auc_score 10 | import matplotlib.pyplot as plt 11 | 12 | from libauc.datasets import CheXpert 13 | from libauc.models import densenet121 as DenseNet121 14 | 15 | def parse_args(input_args=None): 16 | parser = argparse.ArgumentParser(description="Plot the classification confidence score results") 17 | parser.add_argument( 18 | "--model_path", 19 | type=str, 20 | default=None, 21 | required=True, 22 | help="Path to confidence model.", 23 | ) 24 | parser.add_argument( 25 | "--image_dir", 26 | type=str, 27 | default=None, 28 | required=True, 29 | help="Path to input image sequence folder", 30 | ) 31 | parser.add_argument( 32 | "--plot_path", 33 | type=str, 34 | default="./plot.png", 35 | help="The output path.", 36 | ) 37 | parser.add_argument( 38 | "--num", 39 | type=int, 40 | default=10, 41 | help="The number of images to draw the plot.", 42 | ) 43 | parser.add_argument( 44 | "--resolution", 45 | type=int, 46 | default=224, 47 | help=( 48 | "The resolution for input size for confidence model" 49 | ), 50 | ) 51 | args = parser.parse_args() 52 | return args 53 | 54 | def main(args): 55 | IMG_PATH = args.image_dir 56 | count = args.num 57 | image_size = args.resolution 58 | sigmoid = nn.Sigmoid() 59 | 60 | model = DenseNet121(pretrained=False, last_activation=None, activations='relu', num_classes=1) 61 | checkpoint = torch.load(args.model_path) 62 | model.load_state_dict(checkpoint) 63 | model = model.cuda() 64 | model.eval() 65 | 66 | index = [] 67 | value = [] 68 | 69 | for i in range(0, count+1): 70 | image_path = os.path.join(IMG_PATH, str(i) + ".png") 71 | image = cv2.imread(image_path, 0) 72 | 73 | image = Image.fromarray(image) 74 | image = np.array(image) 75 | image = cv2.cvtColor(image, cv2.COLOR_GRAY2RGB) 76 | 77 | image = cv2.resize(image, dsize=(image_size, image_size), interpolation=cv2.INTER_LINEAR) 78 | image = image / 255.0 79 | mean = np.array([[[0.485, 0.456, 0.406]]]) 80 | std = np.array([[[0.229, 0.224, 0.225]]]) 81 | image = (image-mean)/std 82 | image = image.transpose((2, 0, 1)).astype(np.float32) 83 | image = torch.from_numpy(image) 84 | image = image.unsqueeze(0) 85 | image = image.cuda() 86 | 87 | with torch.no_grad(): 88 | index.append(i) 89 | value.append(sigmoid(model(image)).cpu().numpy()[0][0]) 90 | 91 | plt.plot(index, value) 92 | plt.savefig(args.plot_path) 93 | 94 | if __name__ == "__main__": 95 | args = parse_args() 96 | main(args) -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | accelerate 2 | matplotlib 3 | transformers==4.26.1 4 | tokenizers==0.13.2 5 | huggingface-hub==0.13.2 6 | diffusers==0.15.0 7 | clip==0.2.0 
8 | libauc==1.2.0 9 | openai -------------------------------------------------------------------------------- /run_pie.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | import torch 4 | from PIL import Image 5 | import numpy as np 6 | from torchvision import transforms 7 | 8 | from pipeline_stable_diffusion_pie import StableDiffusionPIEPipeline 9 | from diffusers import AutoencoderKL, DDPMScheduler, DiffusionPipeline, UNet2DConditionModel 10 | 11 | def set_all_seeds(SEED): 12 | # REPRODUCIBILITY 13 | torch.manual_seed(SEED) 14 | np.random.seed(SEED) 15 | torch.backends.cudnn.deterministic = True 16 | torch.backends.cudnn.benchmark = False 17 | 18 | def parse_args(input_args=None): 19 | parser = argparse.ArgumentParser(description="Simple example of a PIE inference script.") 20 | parser.add_argument( 21 | "--pretrained_model_name_or_path", 22 | type=str, 23 | default="runwayml/stable-diffusion-v1-5", 24 | required=True, 25 | help="Path to pretrained model or model identifier from huggingface.co/models.", 26 | ) 27 | parser.add_argument( 28 | "--revision", 29 | type=str, 30 | default=None, 31 | required=False, 32 | help=( 33 | "Revision of pretrained model identifier from huggingface.co/models. Trainable model components should be" 34 | " float32 precision." 35 | ), 36 | ) 37 | parser.add_argument( 38 | "--finetuned_path", 39 | type=str, 40 | default=None, 41 | required=False, 42 | help="Path to a domain-specific finetuned UNet from any healthcare text-to-image dataset.", 43 | ) 44 | parser.add_argument( 45 | "--image_path", 46 | type=str, 47 | default=None, 48 | required=True, 49 | help="Path to the input instance images.", 50 | ) 51 | parser.add_argument( 52 | "--mask_path", 53 | type=str, 54 | default=None, 55 | required=False, 56 | help="Path to mask.", 57 | ) 58 | parser.add_argument( 59 | "--prompt", 60 | type=str, 61 | default=None, 62 | required=True, 63 | help="The prompt with identifier specifying the instance", 64 | ) 65 | parser.add_argument("--step", type=int, default=10, help="N in the paper, the number of images / steps for PIE generation") 66 | parser.add_argument("--strength", type=float, default=0.5, help="Roll-back ratio gamma") 67 | parser.add_argument("--guidance_scale", type=float, default=7.5, help="guidance scale") 68 | parser.add_argument( 69 | "--output_dir", 70 | type=str, 71 | default="./simulation", 72 | help="The output directory.", 73 | ) 74 | parser.add_argument("--seed", type=int, default=None, help="A seed for reproducible generation.") 75 | parser.add_argument( 76 | "--resolution", 77 | type=int, 78 | default=512, 79 | help=( 80 | "The resolution for input images, all the images in the train/validation dataset will be resized to this" 81 | " resolution" 82 | ), 83 | ) 84 | 85 | args = parser.parse_args() 86 | return args 87 | 88 | 89 | def main(args): 90 | seed = args.seed 91 | if seed is not None: set_all_seeds(seed) 92 | 93 | image_path = args.image_path 94 | mask_path = args.mask_path 95 | prompt = args.prompt 96 | 97 | model_id_or_path = args.pretrained_model_name_or_path 98 | finetuned_path = args.finetuned_path 99 | resolution = args.resolution 100 | ddim_times = args.step 101 | strength = args.strength 102 | guidance_scale = args.guidance_scale 103 | 104 | output_dir = args.output_dir 105 | 106 | if not os.path.exists(output_dir): 107 | os.mkdir(output_dir) 108 | 109 | device = "cuda" 110 | pipe = StableDiffusionPIEPipeline.from_pretrained(model_id_or_path, torch_dtype=torch.float32,
cache_dir="./checkpoints", safety_checker=None) 111 | if finetuned_path != None: 112 | unet = UNet2DConditionModel.from_pretrained( 113 | finetuned_path, subfolder="text_encoder" 114 | ) 115 | pipe.unet = unet 116 | pipe = pipe.to(device) 117 | 118 | image_transforms = transforms.Compose( 119 | [ 120 | transforms.Resize(resolution, interpolation=transforms.InterpolationMode.BILINEAR), 121 | transforms.CenterCrop(resolution) 122 | ] 123 | ) 124 | 125 | images = [] 126 | step_i = 0 127 | init_image = Image.open(image_path).convert("RGB") # The unedited image 128 | init_image = image_transforms(init_image) 129 | init_image.save(os.path.join(output_dir, str(step_i) + ".png")) 130 | 131 | if mask_path != None: 132 | mask = Image.open(mask_path).convert("RGB") 133 | mask = image_transforms(mask) 134 | mask.save(os.path.join(output_dir, "mask" + ".png")) 135 | else: 136 | mask = None 137 | 138 | step_i += 1 139 | img = init_image 140 | images.append(img) 141 | 142 | while step_i <= ddim_times: 143 | img = pipe(prompt=prompt, image=img, mask=mask, init_image=init_image, strength=strength, guidance_scale=guidance_scale).images[0] 144 | images.append(img) 145 | img.save(os.path.join(output_dir, str(step_i) + ".png")) 146 | step_i += 1 147 | 148 | duration = 1000 149 | images[0].save('output.gif', save_all=True, append_images=images[1:], duration=duration) 150 | 151 | if __name__ == "__main__": 152 | args = parse_args() 153 | main(args) 154 | -------------------------------------------------------------------------------- /training/README.md: -------------------------------------------------------------------------------- 1 | ## Description 2 | 3 | Finetuning stable diffusion base model on each of the three medical imaging dataset 4 | 5 | 6 | ## Usage 7 | Script for finetuning on Chexpert: 8 | ```bash 9 | #!/bin/bash 10 | source ~/.bashrc 11 | conda activate diffusion 12 | export HF_HOME=./cache/ 13 | torchrun --nproc_per_node=8 train.py \ 14 | --pretrained_model_name_or_path="CompVis/stable-diffusion-v1-4" \ 15 | --instance_data_dir="" \ 16 | --output_dir="" \ 17 | --instance_prompt="A chest X-ray image" \ 18 | --resolution=512 \ 19 | --train_batch_size=8 \ 20 | --gradient_accumulation_steps=2 \ 21 | --learning_rate=5e-5 \ 22 | --lr_warmup_steps=1000 \ 23 | --max_train_steps=20000 \ 24 | --lr_scheduler "cosine" \ 25 | --checkpoints_total_limit 2 \ 26 | --gradient_checkpointing \ 27 | --mixed_precision bf16 \ 28 | --center_crop \ 29 | --instance_dataset chexpert \ 30 | --checkpointing_steps 5000 31 | ``` 32 | Script for finetuning on Retinopathy: 33 | ```bash 34 | #!/bin/bash 35 | source ~/.bashrc 36 | conda activate diffusion 37 | echo $CUDA_VISIBLE_DEVICES 38 | export HF_HOME=./cache/ 39 | torchrun --nproc_per_node=8 --master_port 0 train_dreambooth.py \ 40 | --pretrained_model_name_or_path="CompVis/stable-diffusion-v1-4" \ 41 | --instance_data_dir="" \ 42 | --output_dir="" \ 43 | --instance_prompt="A retinopathy image" \ 44 | --resolution=512 \ 45 | --train_batch_size=8 \ 46 | --gradient_accumulation_steps=2 \ 47 | --learning_rate=5e-5 \ 48 | --lr_warmup_steps=1000 \ 49 | --max_train_steps=20000 \ 50 | --lr_scheduler "cosine" \ 51 | --checkpoints_total_limit 2 \ 52 | --gradient_checkpointing \ 53 | --mixed_precision bf16 \ 54 | --center_crop \ 55 | --instance_dataset retinopathy \ 56 | --checkpointing_steps 5000 57 | ``` 58 | 59 | Script for finetuning on ISIC: 60 | ```bash 61 | #!/bin/bash 62 | source ~/.bashrc 63 | conda activate diffusion 64 | export HF_HOME=./cache/ 65 | torchrun 
--nproc_per_node=8 train_dreambooth.py \ 66 | --pretrained_model_name_or_path="CompVis/stable-diffusion-v1-4" \ 67 | --instance_data_dir="" \ 68 | --output_dir="" \ 69 | --instance_prompt="An image of skin" \ 70 | --resolution=512 \ 71 | --train_batch_size=8 \ 72 | --gradient_accumulation_steps=2 \ 73 | --learning_rate=5e-5 \ 74 | --lr_warmup_steps=1000 \ 75 | --max_train_steps=20000 \ 76 | --lr_scheduler "cosine" \ 77 | --checkpoints_total_limit 2 \ 78 | --gradient_checkpointing \ 79 | --mixed_precision bf16 \ 80 | --center_crop \ 81 | --instance_dataset isic \ 82 | --checkpointing_steps 5000 83 | ``` 84 | 85 | ## License 86 | 87 | This project is licensed under the [MIT License](LICENSE). 88 | -------------------------------------------------------------------------------- /training/chexpert.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import albumentations as A 4 | import cv2 5 | import numpy as np 6 | import pandas as pd 7 | import torch 8 | from torch.utils.data import Dataset 9 | from PIL import Image 10 | 11 | def cohen_aug(img): 12 | # Follow https://arxiv.org/pdf/2002.02497.pdf, page 4 13 | # "Data augmentation was used to improve generalization. According to best results inCohen et al. (2019) (and replicated by us) 14 | # each image was rotated up to 45 degrees, translatedup to 15% and scaled larger of smaller up to 10%" 15 | aug_ = A.Compose([ 16 | A.ShiftScaleRotate(p=1.0, shift_limit=0.25, rotate_limit=45, scale_limit=0.1), 17 | A.HorizontalFlip(p=0.5), 18 | ]) 19 | return aug_(image=img[0])["image"].reshape(img.shape) 20 | 21 | 22 | class ChexPert(Dataset): 23 | """ 24 | CheXpert: A Large Chest Radiograph Dataset with Uncertainty Labels and Expert Comparison. 25 | Jeremy Irvin *, Pranav Rajpurkar *, Michael Ko, Yifan Yu, Silviana Ciurea-Ilcus, Chris Chute, 26 | Henrik Marklund, Behzad Haghgoo, Robyn Ball, Katie Shpanskaya, Jayne Seekins, David A. Mong, 27 | Safwan S. Halabi, Jesse K. Sandberg, Ricky Jones, David B. Larson, Curtis P. Langlotz, 28 | Bhavik N. Patel, Matthew P. Lungren, Andrew Y. Ng. https://arxiv.org/abs/1901.07031 29 | 30 | Dataset website here: 31 | https://stanfordmlgroup.github.io/competitions/chexpert/ 32 | """ 33 | 34 | # def __init__(self, imgpath, csvpath, views=["PA"], transform=None, data_aug=None, 35 | # flat_dir=True, seed=0, unique_patients=True): 36 | def __init__(self, path, split="train", aug=None, transform=None, views=["AP", "PA"], unique_patients=False): 37 | super().__init__() 38 | if split == "train": 39 | csvpath = os.path.join(path, 'train.csv') 40 | elif split == "valid": 41 | csvpath = os.path.join(path, 'valid.csv') 42 | else: 43 | raise ValueError(csvpath) 44 | self.data_aug = aug 45 | # np.random.seed(seed) # Reset the seed so all runs are the same. 
46 | self.MAXVAL = 255 47 | 48 | self.pathologies = [ 49 | "Enlarged Cardiomediastinum", "Cardiomegaly", "Lung Opacity", "Lung Lesion", "Edema", "Consolidation", 50 | "Pneumonia", "Atelectasis", "Pneumothorax", "Pleural Effusion", "Pleural Other", "Fracture", 51 | "Support Devices" 52 | ] 53 | 54 | self.pathologies = sorted(self.pathologies) 55 | 56 | self.imgpath = os.path.dirname(path) 57 | self.transform = transform 58 | self.data_aug = aug 59 | self.csvpath = csvpath 60 | self.csv = pd.read_csv(self.csvpath) 61 | self.views = views 62 | 63 | # To list 64 | if type(self.views) is not list: 65 | views = [views] 66 | self.views = views 67 | 68 | self.csv["view"] = self.csv["Frontal/Lateral"] # Assign view column 69 | self.csv.loc[(self.csv["view"] == "Frontal"), "view"] = self.csv[ 70 | "AP/PA"] # If Frontal change with the corresponding value in the AP/PA column otherwise remains Lateral 71 | self.csv["view"] = self.csv["view"].replace({'Lateral': "L"}) # Rename Lateral with L 72 | self.csv = self.csv[self.csv["view"].isin(self.views)] # Select the view 73 | 74 | if unique_patients: 75 | self.csv["PatientID"] = self.csv["Path"].str.extract(pat='(patient\d+)') 76 | self.csv = self.csv.groupby("PatientID").first().reset_index() 77 | 78 | # Get our classes. 79 | healthy = self.csv["No Finding"] == 1 80 | self.labels = [] 81 | for pathology in self.pathologies: 82 | if pathology in self.csv.columns: 83 | self.csv.loc[healthy, pathology] = 0 84 | mask = self.csv[pathology] 85 | 86 | self.labels.append(mask.values) 87 | self.labels = np.asarray(self.labels).T 88 | self.labels = self.labels.astype(np.float32) 89 | 90 | # make all the -1 values into nans to keep things simple 91 | self.labels[self.labels == -1] = np.nan 92 | 93 | # rename pathologies 94 | self.pathologies = list(np.char.replace(self.pathologies, "Pleural Effusion", "Effusion")) 95 | print(self.pathologies) 96 | 97 | ########## add consistent csv values 98 | 99 | # offset_day_int 100 | 101 | # patientid 102 | if 'train' in csvpath: 103 | patientid = self.csv.Path.str.split("train/", expand=True)[1] 104 | elif 'valid' in csvpath: 105 | patientid = self.csv.Path.str.split("valid/", expand=True)[1] 106 | else: 107 | raise NotImplemented 108 | 109 | patientid = patientid.str.split("/study", expand=True)[0] 110 | patientid = patientid.str.replace("patient", "") 111 | self.csv["patientid"] = patientid 112 | 113 | def string(self): 114 | return self.__class__.__name__ + " num_samples={} views={} data_aug={}".format( 115 | len(self), self.views, self.data_aug) 116 | 117 | def __len__(self): 118 | return len(self.labels) 119 | 120 | def __getitem__(self, idx): 121 | imgid = self.csv['Path'].iloc[idx] 122 | img_path = os.path.join(self.imgpath, imgid) 123 | img = Image.open(img_path) 124 | if not img.mode == "RGB": 125 | img = img.convert("RGB") 126 | 127 | if self.transform is not None: 128 | img = self.transform(img) 129 | 130 | if self.data_aug is not None: 131 | img = self.data_aug(img) 132 | 133 | target = self.labels[idx] 134 | pathologies = [] 135 | for i in range(len(target)): 136 | if target[i] == 1: 137 | pathologies.append(self.pathologies[i]) 138 | if len(pathologies) > 0: 139 | pathologies = ",".join(pathologies) 140 | else: 141 | pathologies = "" 142 | target = torch.from_numpy(target).float() 143 | metadata = {'img_path': img_path, "pathologies": pathologies} 144 | return img, target, metadata -------------------------------------------------------------------------------- /training/retinopahty.py: 
-------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | import pandas as pd 4 | import torch 5 | from torch.utils.data import Dataset 6 | from PIL import Image 7 | from torchvision import transforms 8 | 9 | 10 | class Retinopathy(Dataset): 11 | def __init__(self, path, split="train", transform=None): 12 | super().__init__() 13 | if split == "train": 14 | csvpath = os.path.join(path, 'trainLabels.csv') 15 | self.csvpath = csvpath 16 | self.imgpath = os.path.join(path, 'train') 17 | 18 | self.transform = transform 19 | self.csv = pd.read_csv(self.csvpath) 20 | self.pathologies = ["No DR", "Mild", "Moderate", "Severe", "Proliferative DR"] 21 | 22 | def __len__(self): 23 | return len(self.csv.index) 24 | 25 | def __getitem__(self, idx): 26 | imgid = self.csv['image'].iloc[idx] 27 | img_path = os.path.join(self.imgpath, imgid+'.jpeg') 28 | img = Image.open(img_path) 29 | if not img.mode == "RGB": 30 | img = img.convert("RGB") 31 | if self.transform is not None: 32 | img = self.transform(img) 33 | target = self.csv['level'].iloc[idx] 34 | if target == 0: 35 | metadata = {'img_path': img_path, "pathologies": "healthy"} 36 | else: 37 | metadata = {'img_path': img_path, "pathologies": "diabetic"} 38 | return img, target, metadata -------------------------------------------------------------------------------- /training/skin.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import albumentations as A 4 | import cv2 5 | import numpy as np 6 | import pandas as pd 7 | import torch 8 | from torch.utils.data import Dataset 9 | from PIL import Image 10 | 11 | def cohen_aug(img): 12 | # Follow https://arxiv.org/pdf/2002.02497.pdf, page 4 13 | # "Data augmentation was used to improve generalization. According to best results inCohen et al. 
(2019) (and replicated by us) 14 | # each image was rotated up to 45 degrees, translatedup to 15% and scaled larger of smaller up to 10%" 15 | aug_ = A.Compose([ 16 | A.ShiftScaleRotate(p=1.0, shift_limit=0.25, rotate_limit=45, scale_limit=0.1), 17 | A.HorizontalFlip(p=0.5), 18 | ]) 19 | return aug_(image=img[0])["image"].reshape(img.shape) 20 | 21 | 22 | class ISIC(Dataset): 23 | def __init__(self, path, aug=None, transform=None): 24 | super().__init__() 25 | self.data_path1 = os.path.join(path, "ham10000_images_part_1") 26 | self.data_path2 = os.path.join(path, "ham10000_images_part_2") 27 | self.csvpath = os.path.join(path, "HAM10000_metadata.csv") 28 | self.data_aug = aug 29 | self.MAXVAL = 255 30 | 31 | self.pathologies = { 32 | "bkl": "benign keratosis-like lesions", 33 | "mel": "melanoma", 34 | "nv": "melanocytic nevi", 35 | "akiec": "actinic keratoses and intraepithelial carcinoma", 36 | "bcc": "basal cell carcinoma", 37 | "vasc": "vascular lesions", 38 | "df": "dermatofibroma" 39 | } 40 | self.imgpath = os.path.dirname(path) 41 | self.transform = transform 42 | self.data_aug = aug 43 | self.csv = pd.read_csv(self.csvpath) 44 | 45 | def __len__(self): 46 | return len(self.csv["image_id"]) 47 | 48 | def __getitem__(self, idx): 49 | imgid = self.csv['image_id'].iloc[idx] 50 | img_path = os.path.join(self.data_path1, str(imgid)+".jpg") 51 | if not os.path.exists(img_path): 52 | img_path = os.path.join(self.data_path2, str(imgid)+".jpg") 53 | img = Image.open(img_path) 54 | if not img.mode == "RGB": 55 | img = img.convert("RGB") 56 | 57 | if self.transform is not None: 58 | img = self.transform(img) 59 | if self.data_aug is not None: 60 | img = self.data_aug(img) 61 | metadata = {'img_path': img_path, "target": self.csv['dx'].iloc[idx] ,"pathologies": self.pathologies[self.csv['dx'].iloc[idx]]} 62 | return img, None, metadata -------------------------------------------------------------------------------- /training/train.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # coding=utf-8 3 | # Copyright 2023 The HuggingFace Inc. team. All rights reserved. 4 | # 5 | # Licensed under the Apache License, Version 2.0 (the "License"); 6 | # you may not use this file except in compliance with the License. 7 | # You may obtain a copy of the License at 8 | # 9 | # http://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
14 | # See the License for the specific language governing permissions and 15 | 16 | import argparse 17 | import hashlib 18 | import itertools 19 | import logging 20 | import math 21 | import os 22 | import warnings 23 | from pathlib import Path 24 | 25 | import accelerate 26 | import numpy as np 27 | import torch 28 | import torch.nn.functional as F 29 | import torch.utils.checkpoint 30 | import transformers 31 | from accelerate import Accelerator 32 | from accelerate.logging import get_logger 33 | from accelerate.utils import ProjectConfiguration, set_seed 34 | from huggingface_hub import create_repo, upload_folder 35 | from packaging import version 36 | from PIL import Image 37 | from torch.utils.data import Dataset 38 | from torchvision import transforms 39 | from tqdm.auto import tqdm 40 | from transformers import AutoTokenizer, PretrainedConfig 41 | from PIL import Image 42 | 43 | import diffusers 44 | from diffusers import ( 45 | AutoencoderKL, 46 | DDPMScheduler, 47 | DiffusionPipeline, 48 | DPMSolverMultistepScheduler, 49 | UNet2DConditionModel, 50 | ) 51 | from diffusers.optimization import get_scheduler 52 | from diffusers.utils import check_min_version, is_wandb_available 53 | from diffusers.utils.import_utils import is_xformers_available 54 | from chexpert import ChexPert 55 | from retinopathy import Retinopathy 56 | from skin import ISIC 57 | 58 | 59 | if is_wandb_available(): 60 | import wandb 61 | 62 | # Will error if the minimal version of diffusers is not installed. Remove at your own risks. 63 | check_min_version("0.15.0.dev0") 64 | 65 | logger = get_logger(__name__) 66 | 67 | 68 | def log_validation(text_encoder, tokenizer, unet, vae, args, accelerator, weight_dtype, epoch): 69 | logger.info( 70 | f"Running validation... \n Generating {args.num_validation_images} images with prompt:" 71 | f" {args.validation_prompt}." 
72 | ) 73 | # create pipeline (note: unet and vae are loaded again in float32) 74 | pipeline = DiffusionPipeline.from_pretrained( 75 | args.pretrained_model_name_or_path, 76 | text_encoder=accelerator.unwrap_model(text_encoder), 77 | tokenizer=tokenizer, 78 | unet=accelerator.unwrap_model(unet), 79 | vae=vae, 80 | revision=args.revision, 81 | torch_dtype=weight_dtype, 82 | ) 83 | pipeline.scheduler = DPMSolverMultistepScheduler.from_config(pipeline.scheduler.config) 84 | pipeline = pipeline.to(accelerator.device) 85 | pipeline.set_progress_bar_config(disable=True) 86 | 87 | # run inference 88 | generator = None if args.seed is None else torch.Generator(device=accelerator.device).manual_seed(args.seed) 89 | images = [] 90 | for _ in range(args.num_validation_images): 91 | with torch.autocast("cuda"): 92 | image = pipeline(args.validation_prompt, num_inference_steps=25, generator=generator).images[0] 93 | images.append(image) 94 | 95 | for tracker in accelerator.trackers: 96 | if tracker.name == "tensorboard": 97 | np_images = np.stack([np.asarray(img) for img in images]) 98 | tracker.writer.add_images("validation", np_images, epoch, dataformats="NHWC") 99 | if tracker.name == "wandb": 100 | tracker.log( 101 | { 102 | "validation": [ 103 | wandb.Image(image, caption=f"{i}: {args.validation_prompt}") for i, image in enumerate(images) 104 | ] 105 | } 106 | ) 107 | 108 | del pipeline 109 | torch.cuda.empty_cache() 110 | 111 | 112 | def import_model_class_from_model_name_or_path(pretrained_model_name_or_path: str, revision: str): 113 | text_encoder_config = PretrainedConfig.from_pretrained( 114 | pretrained_model_name_or_path, 115 | subfolder="text_encoder", 116 | revision=revision, 117 | ) 118 | model_class = text_encoder_config.architectures[0] 119 | 120 | if model_class == "CLIPTextModel": 121 | from transformers import CLIPTextModel 122 | 123 | return CLIPTextModel 124 | elif model_class == "RobertaSeriesModelWithTransformation": 125 | from diffusers.pipelines.alt_diffusion.modeling_roberta_series import RobertaSeriesModelWithTransformation 126 | 127 | return RobertaSeriesModelWithTransformation 128 | else: 129 | raise ValueError(f"{model_class} is not supported.") 130 | 131 | 132 | def parse_args(input_args=None): 133 | parser = argparse.ArgumentParser(description="Simple example of a training script.") 134 | parser.add_argument( 135 | "--pretrained_model_name_or_path", 136 | type=str, 137 | default=None, 138 | required=True, 139 | help="Path to pretrained model or model identifier from huggingface.co/models.", 140 | ) 141 | parser.add_argument( 142 | "--revision", 143 | type=str, 144 | default=None, 145 | required=False, 146 | help=( 147 | "Revision of pretrained model identifier from huggingface.co/models. Trainable model components should be" 148 | " float32 precision." 
149 | ), 150 | ) 151 | parser.add_argument( 152 | "--tokenizer_name", 153 | type=str, 154 | default=None, 155 | help="Pretrained tokenizer name or path if not the same as model_name", 156 | ) 157 | parser.add_argument( 158 | "--instance_data_dir", 159 | type=str, 160 | default=None, 161 | required=True, 162 | help="A folder containing the training data of instance images.", 163 | ) 164 | parser.add_argument( 165 | "--class_data_dir", 166 | type=str, 167 | default=None, 168 | required=False, 169 | help="A folder containing the training data of class images.", 170 | ) 171 | parser.add_argument( 172 | "--instance_prompt", 173 | type=str, 174 | default=None, 175 | required=True, 176 | help="The prompt with identifier specifying the instance", 177 | ) 178 | parser.add_argument( 179 | "--class_prompt", 180 | type=str, 181 | default=None, 182 | help="The prompt to specify images in the same class as provided instance images.", 183 | ) 184 | parser.add_argument( 185 | "--with_prior_preservation", 186 | default=False, 187 | action="store_true", 188 | help="Flag to add prior preservation loss.", 189 | ) 190 | parser.add_argument("--prior_loss_weight", type=float, default=1.0, help="The weight of prior preservation loss.") 191 | parser.add_argument( 192 | "--num_class_images", 193 | type=int, 194 | default=100, 195 | help=( 196 | "Minimal class images for prior preservation loss. If there are not enough images already present in" 197 | " class_data_dir, additional images will be sampled with class_prompt." 198 | ), 199 | ) 200 | parser.add_argument( 201 | "--output_dir", 202 | type=str, 203 | default="text-inversion-model", 204 | help="The output directory where the model predictions and checkpoints will be written.", 205 | ) 206 | parser.add_argument("--seed", type=int, default=None, help="A seed for reproducible training.") 207 | parser.add_argument( 208 | "--resolution", 209 | type=int, 210 | default=512, 211 | help=( 212 | "The resolution for input images, all the images in the train/validation dataset will be resized to this" 213 | " resolution" 214 | ), 215 | ) 216 | parser.add_argument( 217 | "--center_crop", 218 | default=False, 219 | action="store_true", 220 | help=( 221 | "Whether to center crop the input images to the resolution. If not set, the images will be randomly" 222 | " cropped. The images will be resized to the resolution first before cropping." 223 | ), 224 | ) 225 | parser.add_argument( 226 | "--train_text_encoder", 227 | action="store_true", 228 | help="Whether to train the text encoder. If set, the text encoder should be float32 precision.", 229 | ) 230 | parser.add_argument( 231 | "--train_batch_size", type=int, default=4, help="Batch size (per device) for the training dataloader." 232 | ) 233 | parser.add_argument( 234 | "--sample_batch_size", type=int, default=4, help="Batch size (per device) for sampling images." 235 | ) 236 | parser.add_argument("--num_train_epochs", type=int, default=1) 237 | parser.add_argument( 238 | "--max_train_steps", 239 | type=int, 240 | default=None, 241 | help="Total number of training steps to perform. If provided, overrides num_train_epochs.", 242 | ) 243 | parser.add_argument( 244 | "--checkpointing_steps", 245 | type=int, 246 | default=500, 247 | help=( 248 | "Save a checkpoint of the training state every X updates. Checkpoints can be used for resuming training via `--resume_from_checkpoint`. " 249 | "In the case that the checkpoint is better than the final trained model, the checkpoint can also be used for inference." 
250 | "Using a checkpoint for inference requires separate loading of the original pipeline and the individual checkpointed model components." 251 | "See https://huggingface.co/docs/diffusers/main/en/training/dreambooth#performing-inference-using-a-saved-checkpoint for step by step" 252 | "instructions." 253 | ), 254 | ) 255 | parser.add_argument( 256 | "--checkpoints_total_limit", 257 | type=int, 258 | default=None, 259 | help=( 260 | "Max number of checkpoints to store. Passed as `total_limit` to the `Accelerator` `ProjectConfiguration`." 261 | " See Accelerator::save_state https://huggingface.co/docs/accelerate/package_reference/accelerator#accelerate.Accelerator.save_state" 262 | " for more details" 263 | ), 264 | ) 265 | parser.add_argument( 266 | "--resume_from_checkpoint", 267 | type=str, 268 | default=None, 269 | help=( 270 | "Whether training should be resumed from a previous checkpoint. Use a path saved by" 271 | ' `--checkpointing_steps`, or `"latest"` to automatically select the last available checkpoint.' 272 | ), 273 | ) 274 | parser.add_argument( 275 | "--gradient_accumulation_steps", 276 | type=int, 277 | default=1, 278 | help="Number of updates steps to accumulate before performing a backward/update pass.", 279 | ) 280 | parser.add_argument( 281 | "--gradient_checkpointing", 282 | action="store_true", 283 | help="Whether or not to use gradient checkpointing to save memory at the expense of slower backward pass.", 284 | ) 285 | parser.add_argument( 286 | "--learning_rate", 287 | type=float, 288 | default=5e-6, 289 | help="Initial learning rate (after the potential warmup period) to use.", 290 | ) 291 | parser.add_argument( 292 | "--scale_lr", 293 | action="store_true", 294 | default=False, 295 | help="Scale the learning rate by the number of GPUs, gradient accumulation steps, and batch size.", 296 | ) 297 | parser.add_argument( 298 | "--lr_scheduler", 299 | type=str, 300 | default="constant", 301 | help=( 302 | 'The scheduler type to use. Choose between ["linear", "cosine", "cosine_with_restarts", "polynomial",' 303 | ' "constant", "constant_with_warmup"]' 304 | ), 305 | ) 306 | parser.add_argument( 307 | "--lr_warmup_steps", type=int, default=500, help="Number of steps for the warmup in the lr scheduler." 308 | ) 309 | parser.add_argument( 310 | "--lr_num_cycles", 311 | type=int, 312 | default=1, 313 | help="Number of hard resets of the lr in cosine_with_restarts scheduler.", 314 | ) 315 | parser.add_argument("--lr_power", type=float, default=1.0, help="Power factor of the polynomial scheduler.") 316 | parser.add_argument( 317 | "--use_8bit_adam", action="store_true", help="Whether or not to use 8-bit Adam from bitsandbytes." 318 | ) 319 | parser.add_argument( 320 | "--dataloader_num_workers", 321 | type=int, 322 | default=0, 323 | help=( 324 | "Number of subprocesses to use for data loading. 0 means that the data will be loaded in the main process." 
325 | ), 326 | ) 327 | parser.add_argument("--adam_beta1", type=float, default=0.9, help="The beta1 parameter for the Adam optimizer.") 328 | parser.add_argument("--adam_beta2", type=float, default=0.999, help="The beta2 parameter for the Adam optimizer.") 329 | parser.add_argument("--adam_weight_decay", type=float, default=1e-2, help="Weight decay to use.") 330 | parser.add_argument("--adam_epsilon", type=float, default=1e-08, help="Epsilon value for the Adam optimizer") 331 | parser.add_argument("--max_grad_norm", default=1.0, type=float, help="Max gradient norm.") 332 | parser.add_argument("--push_to_hub", action="store_true", help="Whether or not to push the model to the Hub.") 333 | parser.add_argument("--hub_token", type=str, default=None, help="The token to use to push to the Model Hub.") 334 | parser.add_argument( 335 | "--hub_model_id", 336 | type=str, 337 | default=None, 338 | help="The name of the repository to keep in sync with the local `output_dir`.", 339 | ) 340 | parser.add_argument( 341 | "--logging_dir", 342 | type=str, 343 | default="logs", 344 | help=( 345 | "[TensorBoard](https://www.tensorflow.org/tensorboard) log directory. Will default to" 346 | " *output_dir/runs/**CURRENT_DATETIME_HOSTNAME***." 347 | ), 348 | ) 349 | parser.add_argument( 350 | "--allow_tf32", 351 | action="store_true", 352 | help=( 353 | "Whether or not to allow TF32 on Ampere GPUs. Can be used to speed up training. For more information, see" 354 | " https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices" 355 | ), 356 | ) 357 | parser.add_argument( 358 | "--report_to", 359 | type=str, 360 | default="tensorboard", 361 | help=( 362 | 'The integration to report the results and logs to. Supported platforms are `"tensorboard"`' 363 | ' (default), `"wandb"` and `"comet_ml"`. Use `"all"` to report to all integrations.' 364 | ), 365 | ) 366 | parser.add_argument( 367 | "--validation_prompt", 368 | type=str, 369 | default=None, 370 | help="A prompt that is used during validation to verify that the model is learning.", 371 | ) 372 | parser.add_argument( 373 | "--num_validation_images", 374 | type=int, 375 | default=4, 376 | help="Number of images that should be generated during validation with `validation_prompt`.", 377 | ) 378 | parser.add_argument( 379 | "--validation_steps", 380 | type=int, 381 | default=100, 382 | help=( 383 | "Run validation every X steps. Validation consists of running the prompt" 384 | " `args.validation_prompt` multiple times: `args.num_validation_images`" 385 | " and logging the images." 386 | ), 387 | ) 388 | parser.add_argument( 389 | "--mixed_precision", 390 | type=str, 391 | default=None, 392 | choices=["no", "fp16", "bf16"], 393 | help=( 394 | "Whether to use mixed precision. Choose between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >=" 395 | " 1.10.and an Nvidia Ampere GPU. Default to the value of accelerate config of the current system or the" 396 | " flag passed with the `accelerate.launch` command. Use this argument to override the accelerate config." 397 | ), 398 | ) 399 | parser.add_argument( 400 | "--prior_generation_precision", 401 | type=str, 402 | default=None, 403 | choices=["no", "fp32", "fp16", "bf16"], 404 | help=( 405 | "Choose prior generation precision between fp32, fp16 and bf16 (bfloat16). Bf16 requires PyTorch >=" 406 | " 1.10.and an Nvidia Ampere GPU. Default to fp16 if a GPU is available else fp32." 
407 | ), 408 | ) 409 | parser.add_argument("--local_rank", type=int, default=-1, help="For distributed training: local_rank") 410 | parser.add_argument( 411 | "--enable_xformers_memory_efficient_attention", action="store_true", help="Whether or not to use xformers." 412 | ) 413 | parser.add_argument( 414 | "--set_grads_to_none", 415 | action="store_true", 416 | help=( 417 | "Save more memory by setting grads to None instead of zero. Be aware that this changes certain" 418 | " behaviors, so disable this argument if it causes any problems. More info:" 419 | " https://pytorch.org/docs/stable/generated/torch.optim.Optimizer.zero_grad.html" 420 | ), 421 | ) 422 | 423 | parser.add_argument( 424 | "--offset_noise", 425 | action="store_true", 426 | default=False, 427 | help=( 428 | "Fine-tuning against a modified noise." 429 | " See: https://www.crosslabs.org//blog/diffusion-with-offset-noise for more information." 430 | ), 431 | ) 432 | parser.add_argument( 433 | "--instance_dataset", 434 | default="chexpert", 435 | choices=["chexpert", "retinopathy", "isic"], 436 | help="The instance dataset to use." 437 | ) 438 | 439 | if input_args is not None: 440 | args = parser.parse_args(input_args) 441 | else: 442 | args = parser.parse_args() 443 | 444 | env_local_rank = int(os.environ.get("LOCAL_RANK", -1)) 445 | if env_local_rank != -1 and env_local_rank != args.local_rank: 446 | args.local_rank = env_local_rank 447 | 448 | if args.with_prior_preservation: 449 | if args.class_data_dir is None: 450 | raise ValueError("You must specify a data directory for class images.") 451 | if args.class_prompt is None: 452 | raise ValueError("You must specify a prompt for class images.") 453 | else: 454 | # logger is not available yet 455 | if args.class_data_dir is not None: 456 | warnings.warn("You need not use --class_data_dir without --with_prior_preservation.") 457 | if args.class_prompt is not None: 458 | warnings.warn("You need not use --class_prompt without --with_prior_preservation.") 459 | 460 | return args 461 | 462 | 463 | class DreamBoothDataset(Dataset): 464 | """ 465 | A dataset to prepare the instance and class images with the prompts for fine-tuning the model. 466 | It pre-processes the images and tokenizes the prompts.
467 | """ 468 | 469 | def __init__( 470 | self, 471 | instance_dataset, 472 | instance_prompt, 473 | tokenizer, 474 | class_data_root=None, 475 | class_prompt=None, 476 | class_num=None, 477 | size=512, 478 | center_crop=False, 479 | ): 480 | self.size = size 481 | self.center_crop = center_crop 482 | self.tokenizer = tokenizer 483 | 484 | self.instance_dataset = instance_dataset 485 | self.instance_prompt = instance_prompt 486 | 487 | if class_data_root is not None: 488 | self.class_data_root = Path(class_data_root) 489 | self.class_data_root.mkdir(parents=True, exist_ok=True) 490 | self.class_images_path = list(self.class_data_root.iterdir()) 491 | if class_num is not None: 492 | self.num_class_images = min(len(self.class_images_path), class_num) 493 | else: 494 | self.num_class_images = len(self.class_images_path) 495 | self.class_prompt = class_prompt 496 | else: 497 | self.class_data_root = None 498 | 499 | self.image_transforms = transforms.Compose( 500 | [ 501 | transforms.Resize(size, interpolation=transforms.InterpolationMode.BILINEAR), 502 | transforms.CenterCrop(size) if center_crop else transforms.RandomCrop(size), 503 | transforms.ToTensor(), 504 | transforms.Normalize([0.5], [0.5]), 505 | ] 506 | ) 507 | self.instance_dataset.transform = self.image_transforms 508 | 509 | def __len__(self): 510 | return self.instance_dataset.__len__() 511 | 512 | def __getitem__(self, index): 513 | example = {} 514 | instance_image, label, meta_data = self.instance_dataset.__getitem__(index) 515 | example["instance_images"] = instance_image 516 | example["instance_prompt_ids"] = self.tokenizer( 517 | self.instance_prompt + "," + meta_data["pathologies"], 518 | truncation=True, 519 | padding="max_length", 520 | max_length=self.tokenizer.model_max_length, 521 | return_tensors="pt", 522 | ).input_ids 523 | 524 | if self.class_data_root: 525 | class_image = Image.open(self.class_images_path[index % self.num_class_images]) 526 | if not class_image.mode == "RGB": 527 | class_image = class_image.convert("RGB") 528 | example["class_images"] = self.image_transforms(class_image) 529 | example["class_prompt_ids"] = self.tokenizer( 530 | self.class_prompt, 531 | truncation=True, 532 | padding="max_length", 533 | max_length=self.tokenizer.model_max_length, 534 | return_tensors="pt", 535 | ).input_ids 536 | 537 | return example 538 | 539 | 540 | def collate_fn(examples, with_prior_preservation=False): 541 | input_ids = [example["instance_prompt_ids"] for example in examples] 542 | pixel_values = [example["instance_images"] for example in examples] 543 | 544 | # Concat class and instance examples for prior preservation. 545 | # We do this to avoid doing two forward passes. 546 | if with_prior_preservation: 547 | input_ids += [example["class_prompt_ids"] for example in examples] 548 | pixel_values += [example["class_images"] for example in examples] 549 | 550 | pixel_values = torch.stack(pixel_values) 551 | pixel_values = pixel_values.to(memory_format=torch.contiguous_format).float() 552 | 553 | input_ids = torch.cat(input_ids, dim=0) 554 | 555 | batch = { 556 | "input_ids": input_ids, 557 | "pixel_values": pixel_values, 558 | } 559 | return batch 560 | 561 | 562 | class PromptDataset(Dataset): 563 | "A simple dataset to prepare the prompts to generate class images on multiple GPUs." 
564 | 565 | def __init__(self, prompt, num_samples): 566 | self.prompt = prompt 567 | self.num_samples = num_samples 568 | 569 | def __len__(self): 570 | return self.num_samples 571 | 572 | def __getitem__(self, index): 573 | example = {} 574 | example["prompt"] = self.prompt 575 | example["index"] = index 576 | return example 577 | 578 | 579 | def main(args): 580 | logging_dir = Path(args.output_dir, args.logging_dir) 581 | 582 | accelerator_project_config = ProjectConfiguration(total_limit=args.checkpoints_total_limit) 583 | 584 | accelerator = Accelerator( 585 | gradient_accumulation_steps=args.gradient_accumulation_steps, 586 | mixed_precision=args.mixed_precision, 587 | log_with=args.report_to, 588 | logging_dir=logging_dir, 589 | project_config=accelerator_project_config, 590 | ) 591 | 592 | if args.report_to == "wandb": 593 | if not is_wandb_available(): 594 | raise ImportError("Make sure to install wandb if you want to use it for logging during training.") 595 | 596 | # Currently, it's not possible to do gradient accumulation when training two models with accelerate.accumulate 597 | # This will be enabled soon in accelerate. For now, we don't allow gradient accumulation when training two models. 598 | # TODO (patil-suraj): Remove this check when gradient accumulation with two models is enabled in accelerate. 599 | if args.train_text_encoder and args.gradient_accumulation_steps > 1 and accelerator.num_processes > 1: 600 | raise ValueError( 601 | "Gradient accumulation is not supported when training the text encoder in distributed training. " 602 | "Please set gradient_accumulation_steps to 1. This feature will be supported in the future." 603 | ) 604 | 605 | # Make one log on every process with the configuration for debugging. 606 | logging.basicConfig( 607 | format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", 608 | datefmt="%m/%d/%Y %H:%M:%S", 609 | level=logging.INFO, 610 | ) 611 | logger.info(accelerator.state, main_process_only=False) 612 | if accelerator.is_local_main_process: 613 | transformers.utils.logging.set_verbosity_warning() 614 | diffusers.utils.logging.set_verbosity_info() 615 | else: 616 | transformers.utils.logging.set_verbosity_error() 617 | diffusers.utils.logging.set_verbosity_error() 618 | 619 | # If passed along, set the training seed now. 620 | if args.seed is not None: 621 | set_seed(args.seed) 622 | 623 | # Generate class images if prior preservation is enabled. 
624 | if args.with_prior_preservation: 625 | class_images_dir = Path(args.class_data_dir) 626 | if not class_images_dir.exists(): 627 | class_images_dir.mkdir(parents=True) 628 | cur_class_images = len(list(class_images_dir.iterdir())) 629 | 630 | if cur_class_images < args.num_class_images: 631 | torch_dtype = torch.float16 if accelerator.device.type == "cuda" else torch.float32 632 | if args.prior_generation_precision == "fp32": 633 | torch_dtype = torch.float32 634 | elif args.prior_generation_precision == "fp16": 635 | torch_dtype = torch.float16 636 | elif args.prior_generation_precision == "bf16": 637 | torch_dtype = torch.bfloat16 638 | pipeline = DiffusionPipeline.from_pretrained( 639 | args.pretrained_model_name_or_path, 640 | torch_dtype=torch_dtype, 641 | safety_checker=None, 642 | revision=args.revision, 643 | ) 644 | pipeline.set_progress_bar_config(disable=True) 645 | 646 | num_new_images = args.num_class_images - cur_class_images 647 | logger.info(f"Number of class images to sample: {num_new_images}.") 648 | 649 | sample_dataset = PromptDataset(args.class_prompt, num_new_images) 650 | sample_dataloader = torch.utils.data.DataLoader(sample_dataset, batch_size=args.sample_batch_size) 651 | 652 | sample_dataloader = accelerator.prepare(sample_dataloader) 653 | pipeline.to(accelerator.device) 654 | 655 | for example in tqdm( 656 | sample_dataloader, desc="Generating class images", disable=not accelerator.is_local_main_process 657 | ): 658 | images = pipeline(example["prompt"], height=args.resolution, width=args.resolution).images 659 | 660 | for i, image in enumerate(images): 661 | hash_image = hashlib.sha1(image.tobytes()).hexdigest() 662 | image_filename = class_images_dir / f"{example['index'][i] + cur_class_images}-{hash_image}.jpg" 663 | image.save(image_filename) 664 | 665 | del pipeline 666 | if torch.cuda.is_available(): 667 | torch.cuda.empty_cache() 668 | 669 | # Handle the repository creation 670 | if accelerator.is_main_process: 671 | if args.output_dir is not None: 672 | os.makedirs(args.output_dir, exist_ok=True) 673 | 674 | if args.push_to_hub: 675 | repo_id = create_repo( 676 | repo_id=args.hub_model_id or Path(args.output_dir).name, exist_ok=True, token=args.hub_token 677 | ).repo_id 678 | 679 | # Load the tokenizer 680 | if args.tokenizer_name: 681 | tokenizer = AutoTokenizer.from_pretrained(args.tokenizer_name, revision=args.revision, use_fast=False) 682 | elif args.pretrained_model_name_or_path: 683 | tokenizer = AutoTokenizer.from_pretrained( 684 | args.pretrained_model_name_or_path, 685 | subfolder="tokenizer", 686 | revision=args.revision, 687 | use_fast=False, 688 | ) 689 | 690 | # import correct text encoder class 691 | text_encoder_cls = import_model_class_from_model_name_or_path(args.pretrained_model_name_or_path, args.revision) 692 | 693 | # Load scheduler and models 694 | noise_scheduler = DDPMScheduler.from_pretrained(args.pretrained_model_name_or_path, subfolder="scheduler") 695 | text_encoder = text_encoder_cls.from_pretrained( 696 | args.pretrained_model_name_or_path, subfolder="text_encoder", revision=args.revision 697 | ) 698 | vae = AutoencoderKL.from_pretrained(args.pretrained_model_name_or_path, subfolder="vae", revision=args.revision) 699 | unet = UNet2DConditionModel.from_pretrained( 700 | args.pretrained_model_name_or_path, subfolder="unet", revision=args.revision 701 | ) 702 | 703 | # `accelerate` 0.16.0 will have better support for customized saving 704 | if version.parse(accelerate.__version__) >= version.parse("0.16.0"): 705 | # 
create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format 706 | def save_model_hook(models, weights, output_dir): 707 | for model in models: 708 | sub_dir = "unet" if type(model) == type(unet) else "text_encoder" 709 | model.save_pretrained(os.path.join(output_dir, sub_dir)) 710 | 711 | # make sure to pop weight so that corresponding model is not saved again 712 | weights.pop() 713 | 714 | def load_model_hook(models, input_dir): 715 | while len(models) > 0: 716 | # pop models so that they are not loaded again 717 | model = models.pop() 718 | 719 | if type(model) == type(text_encoder): 720 | # load transformers style into model 721 | load_model = text_encoder_cls.from_pretrained(input_dir, subfolder="text_encoder") 722 | model.config = load_model.config 723 | else: 724 | # load diffusers style into model 725 | load_model = UNet2DConditionModel.from_pretrained(input_dir, subfolder="unet") 726 | model.register_to_config(**load_model.config) 727 | 728 | model.load_state_dict(load_model.state_dict()) 729 | del load_model 730 | 731 | accelerator.register_save_state_pre_hook(save_model_hook) 732 | accelerator.register_load_state_pre_hook(load_model_hook) 733 | 734 | vae.requires_grad_(False) 735 | if not args.train_text_encoder: 736 | text_encoder.requires_grad_(False) 737 | 738 | if args.enable_xformers_memory_efficient_attention: 739 | if is_xformers_available(): 740 | import xformers 741 | 742 | xformers_version = version.parse(xformers.__version__) 743 | if xformers_version == version.parse("0.0.16"): 744 | logger.warn( 745 | "xFormers 0.0.16 cannot be used for training in some GPUs. If you observe problems during training, please update xFormers to at least 0.0.17. See https://huggingface.co/docs/diffusers/main/en/optimization/xformers for more details." 746 | ) 747 | unet.enable_xformers_memory_efficient_attention() 748 | else: 749 | raise ValueError("xformers is not available. Make sure it is installed correctly") 750 | 751 | if args.gradient_checkpointing: 752 | unet.enable_gradient_checkpointing() 753 | if args.train_text_encoder: 754 | text_encoder.gradient_checkpointing_enable() 755 | 756 | # Check that all trainable models are in full precision 757 | low_precision_error_string = ( 758 | "Please make sure to always have all model weights in full float32 precision when starting training - even if" 759 | " doing mixed precision training. copy of the weights should still be float32." 760 | ) 761 | 762 | if accelerator.unwrap_model(unet).dtype != torch.float32: 763 | raise ValueError( 764 | f"Unet loaded as datatype {accelerator.unwrap_model(unet).dtype}. {low_precision_error_string}" 765 | ) 766 | 767 | if args.train_text_encoder and accelerator.unwrap_model(text_encoder).dtype != torch.float32: 768 | raise ValueError( 769 | f"Text encoder loaded as datatype {accelerator.unwrap_model(text_encoder).dtype}." 
770 | f" {low_precision_error_string}" 771 | ) 772 | 773 | # Enable TF32 for faster training on Ampere GPUs, 774 | # cf https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices 775 | if args.allow_tf32: 776 | torch.backends.cuda.matmul.allow_tf32 = True 777 | 778 | if args.scale_lr: 779 | args.learning_rate = ( 780 | args.learning_rate * args.gradient_accumulation_steps * args.train_batch_size * accelerator.num_processes 781 | ) 782 | 783 | # Use 8-bit Adam for lower memory usage or to fine-tune the model in 16GB GPUs 784 | if args.use_8bit_adam: 785 | try: 786 | import bitsandbytes as bnb 787 | except ImportError: 788 | raise ImportError( 789 | "To use 8-bit Adam, please install the bitsandbytes library: `pip install bitsandbytes`." 790 | ) 791 | 792 | optimizer_class = bnb.optim.AdamW8bit 793 | else: 794 | optimizer_class = torch.optim.AdamW 795 | 796 | # Optimizer creation 797 | params_to_optimize = ( 798 | itertools.chain(unet.parameters(), text_encoder.parameters()) if args.train_text_encoder else unet.parameters() 799 | ) 800 | optimizer = optimizer_class( 801 | params_to_optimize, 802 | lr=args.learning_rate, 803 | betas=(args.adam_beta1, args.adam_beta2), 804 | weight_decay=args.adam_weight_decay, 805 | eps=args.adam_epsilon, 806 | ) 807 | 808 | if args.instance_dataset == "chexpert": 809 | instance_dataset = ChexPert(args.instance_data_dir) 810 | elif args.instance_dataset == "retinopathy": 811 | instance_dataset = Retinopathy(args.instance_data_dir) 812 | elif args.instance_dataset == "isic": 813 | instance_dataset = ISIC(args.instance_data_dir) 814 | 815 | # Dataset and DataLoaders creation: 816 | train_dataset = DreamBoothDataset( 817 | instance_dataset=instance_dataset, 818 | instance_prompt=args.instance_prompt, 819 | class_data_root=args.class_data_dir if args.with_prior_preservation else None, 820 | class_prompt=args.class_prompt, 821 | class_num=args.num_class_images, 822 | tokenizer=tokenizer, 823 | size=args.resolution, 824 | center_crop=args.center_crop, 825 | ) 826 | 827 | train_dataloader = torch.utils.data.DataLoader( 828 | train_dataset, 829 | batch_size=args.train_batch_size, 830 | shuffle=True, 831 | collate_fn=lambda examples: collate_fn(examples, args.with_prior_preservation), 832 | num_workers=args.dataloader_num_workers, 833 | ) 834 | 835 | # Scheduler and math around the number of training steps. 836 | overrode_max_train_steps = False 837 | num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps) 838 | if args.max_train_steps is None: 839 | args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch 840 | overrode_max_train_steps = True 841 | 842 | lr_scheduler = get_scheduler( 843 | args.lr_scheduler, 844 | optimizer=optimizer, 845 | num_warmup_steps=args.lr_warmup_steps * args.gradient_accumulation_steps, 846 | num_training_steps=args.max_train_steps * args.gradient_accumulation_steps, 847 | num_cycles=args.lr_num_cycles, 848 | power=args.lr_power, 849 | ) 850 | 851 | # Prepare everything with our `accelerator`. 
852 | if args.train_text_encoder: 853 | unet, text_encoder, optimizer, train_dataloader, lr_scheduler = accelerator.prepare( 854 | unet, text_encoder, optimizer, train_dataloader, lr_scheduler 855 | ) 856 | else: 857 | unet, optimizer, train_dataloader, lr_scheduler = accelerator.prepare( 858 | unet, optimizer, train_dataloader, lr_scheduler 859 | ) 860 | 861 | # For mixed precision training we cast the text_encoder and vae weights to half-precision 862 | # as these models are only used for inference, keeping weights in full precision is not required. 863 | weight_dtype = torch.float32 864 | if accelerator.mixed_precision == "fp16": 865 | weight_dtype = torch.float16 866 | elif accelerator.mixed_precision == "bf16": 867 | weight_dtype = torch.bfloat16 868 | 869 | # Move vae and text_encoder to device and cast to weight_dtype 870 | vae.to(accelerator.device, dtype=weight_dtype) 871 | if not args.train_text_encoder: 872 | text_encoder.to(accelerator.device, dtype=weight_dtype) 873 | 874 | # We need to recalculate our total training steps as the size of the training dataloader may have changed. 875 | num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps) 876 | if overrode_max_train_steps: 877 | args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch 878 | # Afterwards we recalculate our number of training epochs 879 | args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch) 880 | 881 | # We need to initialize the trackers we use, and also store our configuration. 882 | # The trackers initializes automatically on the main process. 883 | if accelerator.is_main_process: 884 | accelerator.init_trackers("dreambooth", config=vars(args)) 885 | 886 | # Train! 887 | total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps 888 | 889 | logger.info("***** Running training *****") 890 | logger.info(f" Num examples = {len(train_dataset)}") 891 | logger.info(f" Num batches each epoch = {len(train_dataloader)}") 892 | logger.info(f" Num Epochs = {args.num_train_epochs}") 893 | logger.info(f" Instantaneous batch size per device = {args.train_batch_size}") 894 | logger.info(f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}") 895 | logger.info(f" Gradient Accumulation steps = {args.gradient_accumulation_steps}") 896 | logger.info(f" Total optimization steps = {args.max_train_steps}") 897 | global_step = 0 898 | first_epoch = 0 899 | 900 | # Potentially load in the weights and states from a previous save 901 | if args.resume_from_checkpoint: 902 | if args.resume_from_checkpoint != "latest": 903 | path = os.path.basename(args.resume_from_checkpoint) 904 | else: 905 | # Get the mos recent checkpoint 906 | dirs = os.listdir(args.output_dir) 907 | dirs = [d for d in dirs if d.startswith("checkpoint")] 908 | dirs = sorted(dirs, key=lambda x: int(x.split("-")[1])) 909 | path = dirs[-1] if len(dirs) > 0 else None 910 | 911 | if path is None: 912 | accelerator.print( 913 | f"Checkpoint '{args.resume_from_checkpoint}' does not exist. Starting a new training run." 
914 | ) 915 | args.resume_from_checkpoint = None 916 | else: 917 | accelerator.print(f"Resuming from checkpoint {path}") 918 | accelerator.load_state(os.path.join(args.output_dir, path)) 919 | global_step = int(path.split("-")[1]) 920 | 921 | resume_global_step = global_step * args.gradient_accumulation_steps 922 | first_epoch = global_step // num_update_steps_per_epoch 923 | resume_step = resume_global_step % (num_update_steps_per_epoch * args.gradient_accumulation_steps) 924 | 925 | # Only show the progress bar once on each machine. 926 | progress_bar = tqdm(range(global_step, args.max_train_steps), disable=not accelerator.is_local_main_process) 927 | progress_bar.set_description("Steps") 928 | 929 | for epoch in range(first_epoch, args.num_train_epochs): 930 | unet.train() 931 | if args.train_text_encoder: 932 | text_encoder.train() 933 | for step, batch in enumerate(train_dataloader): 934 | # Skip steps until we reach the resumed step 935 | if args.resume_from_checkpoint and epoch == first_epoch and step < resume_step: 936 | if step % args.gradient_accumulation_steps == 0: 937 | progress_bar.update(1) 938 | continue 939 | 940 | with accelerator.accumulate(unet): 941 | # Convert images to latent space 942 | latents = vae.encode(batch["pixel_values"].to(dtype=weight_dtype)).latent_dist.sample() 943 | latents = latents * vae.config.scaling_factor 944 | 945 | # Sample noise that we'll add to the latents 946 | if args.offset_noise: 947 | noise = torch.randn_like(latents) + 0.1 * torch.randn( 948 | latents.shape[0], latents.shape[1], 1, 1, device=latents.device 949 | ) 950 | else: 951 | noise = torch.randn_like(latents) 952 | bsz = latents.shape[0] 953 | # Sample a random timestep for each image 954 | timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, (bsz,), device=latents.device) 955 | timesteps = timesteps.long() 956 | 957 | # Add noise to the latents according to the noise magnitude at each timestep 958 | # (this is the forward diffusion process) 959 | noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps) 960 | 961 | # Get the text embedding for conditioning 962 | encoder_hidden_states = text_encoder(batch["input_ids"])[0] 963 | 964 | # Predict the noise residual 965 | model_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample 966 | 967 | # Get the target for loss depending on the prediction type 968 | if noise_scheduler.config.prediction_type == "epsilon": 969 | target = noise 970 | elif noise_scheduler.config.prediction_type == "v_prediction": 971 | target = noise_scheduler.get_velocity(latents, noise, timesteps) 972 | else: 973 | raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}") 974 | 975 | if args.with_prior_preservation: 976 | # Chunk the noise and model_pred into two parts and compute the loss on each part separately. 977 | model_pred, model_pred_prior = torch.chunk(model_pred, 2, dim=0) 978 | target, target_prior = torch.chunk(target, 2, dim=0) 979 | 980 | # Compute instance loss 981 | loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean") 982 | 983 | # Compute prior loss 984 | prior_loss = F.mse_loss(model_pred_prior.float(), target_prior.float(), reduction="mean") 985 | 986 | # Add the prior loss to the instance loss. 
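# Total objective: loss = MSE(model_pred, target) + prior_loss_weight * MSE(model_pred_prior, target_prior)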
987 | loss = loss + args.prior_loss_weight * prior_loss 988 | else: 989 | loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean") 990 | 991 | accelerator.backward(loss) 992 | if accelerator.sync_gradients: 993 | params_to_clip = ( 994 | itertools.chain(unet.parameters(), text_encoder.parameters()) 995 | if args.train_text_encoder 996 | else unet.parameters() 997 | ) 998 | accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm) 999 | optimizer.step() 1000 | lr_scheduler.step() 1001 | optimizer.zero_grad(set_to_none=args.set_grads_to_none) 1002 | 1003 | # Checks if the accelerator has performed an optimization step behind the scenes 1004 | if accelerator.sync_gradients: 1005 | progress_bar.update(1) 1006 | global_step += 1 1007 | 1008 | if accelerator.is_main_process: 1009 | if global_step % args.checkpointing_steps == 0: 1010 | save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}") 1011 | accelerator.save_state(save_path) 1012 | logger.info(f"Saved state to {save_path}") 1013 | 1014 | if args.validation_prompt is not None and global_step % args.validation_steps == 0: 1015 | log_validation(text_encoder, tokenizer, unet, vae, args, accelerator, weight_dtype, epoch) 1016 | 1017 | logs = {"loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0]} 1018 | progress_bar.set_postfix(**logs) 1019 | accelerator.log(logs, step=global_step) 1020 | 1021 | if global_step >= args.max_train_steps: 1022 | break 1023 | 1024 | # Create the pipeline using using the trained modules and save it. 1025 | accelerator.wait_for_everyone() 1026 | if accelerator.is_main_process: 1027 | pipeline = DiffusionPipeline.from_pretrained( 1028 | args.pretrained_model_name_or_path, 1029 | unet=accelerator.unwrap_model(unet), 1030 | text_encoder=accelerator.unwrap_model(text_encoder), 1031 | revision=args.revision, 1032 | ) 1033 | pipeline.save_pretrained(args.output_dir) 1034 | 1035 | if args.push_to_hub: 1036 | upload_folder( 1037 | repo_id=repo_id, 1038 | folder_path=args.output_dir, 1039 | commit_message="End of training", 1040 | ignore_patterns=["step_*", "epoch_*"], 1041 | ) 1042 | 1043 | accelerator.end_training() 1044 | 1045 | 1046 | if __name__ == "__main__": 1047 | args = parse_args() 1048 | main(args) --------------------------------------------------------------------------------
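The two inference-time entry points above chain together: `run_pie.py` writes a simulated sequence `0.png` … `N.png` into `--output_dir` (plus an `output.gif` animation in the working directory), and `plot_confidence.py` then scores that folder with a trained confidence classifier. The sketch below is illustrative only: the prompt, the confidence-model checkpoint path, and the example-input paths are placeholders rather than values shipped with the repository (see `checkpoints/checkpoints.txt` and `metric/confidence_model/` for obtaining an actual classifier checkpoint).

```bash
# Illustrative sketch only; the prompt and all paths are placeholders.
python run_pie.py \
  --pretrained_model_name_or_path "runwayml/stable-diffusion-v1-5" \
  --image_path ./assets/example_inputs/health.jpg \
  --mask_path ./assets/example_inputs/mask.png \
  --prompt "A chest X-ray image, Effusion" \
  --step 10 \
  --strength 0.5 \
  --guidance_scale 7.5 \
  --seed 42 \
  --output_dir ./simulation

# Plot the classifier confidence over the generated sequence 0.png ... 10.png
python plot_confidence.py \
  --model_path ./checkpoints/confidence_model.pth \
  --image_dir ./simulation \
  --num 10 \
  --plot_path ./confidence.png
```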