├── .dockerignore ├── .gitignore ├── LICENSE ├── README.md ├── cog.yaml ├── dataset_and_utils.py ├── example_datasets ├── README.md ├── kiriko.png ├── kiriko │ ├── 0.src.jpg │ ├── 1.src.jpg │ ├── 10.src.jpg │ ├── 11.src.jpg │ ├── 12.src.jpg │ ├── 2.src.jpg │ ├── 3.src.jpg │ ├── 4.src.jpg │ ├── 5.src.jpg │ ├── 6.src.jpg │ ├── 7.src.jpg │ ├── 8.src.jpg │ └── 9.src.jpg ├── monster.png ├── monster │ ├── caption.csv │ ├── monstertoy (1).jpg │ ├── monstertoy (2).jpg │ ├── monstertoy (3).jpg │ ├── monstertoy (4).jpg │ └── monstertoy (5).jpg ├── monster_uni.png ├── zeke.zip ├── zeke │ ├── 0.src.jpg │ ├── 1.src.jpg │ ├── 2.src.jpg │ ├── 3.src.jpg │ ├── 4.src.jpg │ └── 5.src.jpg ├── zeke2 │ ├── 00.src.jpg │ ├── 01.src.jpg │ ├── 02.src.jpg │ ├── 03.src.jpg │ ├── 04.src.jpg │ ├── 05.src.jpg │ ├── 06.src.jpg │ ├── 07.src.jpg │ ├── 08.src.jpg │ ├── 09.src.jpg │ ├── 10.src.jpg │ ├── 12.src.jpg │ ├── 13.src.jpg │ ├── 14.src.jpg │ ├── 15.src.jpg │ ├── 16.src.jpg │ ├── 17.src.jpg │ ├── 18.src.jpg │ ├── 19.src.jpg │ ├── 20.src.jpg │ └── README.md └── zeke_unicorn.png ├── feature-extractor └── preprocessor_config.json ├── no_init.py ├── predict.py ├── preprocess.py ├── requirements_test.txt ├── samples.py ├── script ├── download_preprocessing_weights.py └── download_weights.py ├── tests ├── assets │ └── out.png ├── test_predict.py ├── test_remote_train.py └── test_utils.py ├── train.py ├── trainer_pti.py └── weights.py /.dockerignore: -------------------------------------------------------------------------------- 1 | sdxl-cache/ 2 | refiner-cache/ 3 | safety-cache/ 4 | trained-model/ 5 | *.png 6 | cache/ 7 | checkpoint/ 8 | training_out/ 9 | dreambooth/ 10 | lora/ 11 | ttemp/ 12 | .git/ 13 | cog_class_data/ 14 | dataset/ 15 | training_data/ 16 | temp/ 17 | temp_in/ 18 | cog_instance_data/ 19 | example_datasets/ 20 | trained_model.tar 21 | zeke_data.tar 22 | data.tar 23 | zeke.zip 24 | sketch-mountains-input.jpeg 25 | training_out* 26 | weights 27 | inference_* 28 | trained-model 29 | *.zip 30 | tmp/ 31 | blip-cache/ 32 | clipseg-cache/ 33 | swin2sr-cache/ 34 | weights-cache/ 35 | tests/ -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | 2 | refiner-cache 3 | sdxl-cache 4 | safety-cache 5 | trained-model 6 | temp 7 | temp_in 8 | cache 9 | .cog 10 | __pycache__ 11 | wandb 12 | ft* 13 | *.ipynb 14 | dataset 15 | training_data 16 | training_out 17 | output* 18 | training_out* 19 | trained_model.tar 20 | checkpoint* 21 | weights 22 | __*.zip 23 | **-cache -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | 2 | Apache License 3 | Version 2.0, January 2004 4 | http://www.apache.org/licenses/ 5 | 6 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 7 | 8 | 1. Definitions. 9 | 10 | "License" shall mean the terms and conditions for use, reproduction, 11 | and distribution as defined by Sections 1 through 9 of this document. 12 | 13 | "Licensor" shall mean the copyright owner or entity authorized by 14 | the copyright owner that is granting the License. 15 | 16 | "Legal Entity" shall mean the union of the acting entity and all 17 | other entities that control, are controlled by, or are under common 18 | control with that entity. 
For the purposes of this definition, 19 | "control" means (i) the power, direct or indirect, to cause the 20 | direction or management of such entity, whether by contract or 21 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 22 | outstanding shares, or (iii) beneficial ownership of such entity. 23 | 24 | "You" (or "Your") shall mean an individual or Legal Entity 25 | exercising permissions granted by this License. 26 | 27 | "Source" form shall mean the preferred form for making modifications, 28 | including but not limited to software source code, documentation 29 | source, and configuration files. 30 | 31 | "Object" form shall mean any form resulting from mechanical 32 | transformation or translation of a Source form, including but 33 | not limited to compiled object code, generated documentation, 34 | and conversions to other media types. 35 | 36 | "Work" shall mean the work of authorship, whether in Source or 37 | Object form, made available under the License, as indicated by a 38 | copyright notice that is included in or attached to the work 39 | (an example is provided in the Appendix below). 40 | 41 | "Derivative Works" shall mean any work, whether in Source or Object 42 | form, that is based on (or derived from) the Work and for which the 43 | editorial revisions, annotations, elaborations, or other modifications 44 | represent, as a whole, an original work of authorship. For the purposes 45 | of this License, Derivative Works shall not include works that remain 46 | separable from, or merely link (or bind by name) to the interfaces of, 47 | the Work and Derivative Works thereof. 48 | 49 | "Contribution" shall mean any work of authorship, including 50 | the original version of the Work and any modifications or additions 51 | to that Work or Derivative Works thereof, that is intentionally 52 | submitted to Licensor for inclusion in the Work by the copyright owner 53 | or by an individual or Legal Entity authorized to submit on behalf of 54 | the copyright owner. For the purposes of this definition, "submitted" 55 | means any form of electronic, verbal, or written communication sent 56 | to the Licensor or its representatives, including but not limited to 57 | communication on electronic mailing lists, source code control systems, 58 | and issue tracking systems that are managed by, or on behalf of, the 59 | Licensor for the purpose of discussing and improving the Work, but 60 | excluding communication that is conspicuously marked or otherwise 61 | designated in writing by the copyright owner as "Not a Contribution." 62 | 63 | "Contributor" shall mean Licensor and any individual or Legal Entity 64 | on behalf of whom a Contribution has been received by Licensor and 65 | subsequently incorporated within the Work. 66 | 67 | 2. Grant of Copyright License. Subject to the terms and conditions of 68 | this License, each Contributor hereby grants to You a perpetual, 69 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 70 | copyright license to reproduce, prepare Derivative Works of, 71 | publicly display, publicly perform, sublicense, and distribute the 72 | Work and such Derivative Works in Source or Object form. 73 | 74 | 3. Grant of Patent License. 
Subject to the terms and conditions of 75 | this License, each Contributor hereby grants to You a perpetual, 76 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 77 | (except as stated in this section) patent license to make, have made, 78 | use, offer to sell, sell, import, and otherwise transfer the Work, 79 | where such license applies only to those patent claims licensable 80 | by such Contributor that are necessarily infringed by their 81 | Contribution(s) alone or by combination of their Contribution(s) 82 | with the Work to which such Contribution(s) was submitted. If You 83 | institute patent litigation against any entity (including a 84 | cross-claim or counterclaim in a lawsuit) alleging that the Work 85 | or a Contribution incorporated within the Work constitutes direct 86 | or contributory patent infringement, then any patent licenses 87 | granted to You under this License for that Work shall terminate 88 | as of the date such litigation is filed. 89 | 90 | 4. Redistribution. You may reproduce and distribute copies of the 91 | Work or Derivative Works thereof in any medium, with or without 92 | modifications, and in Source or Object form, provided that You 93 | meet the following conditions: 94 | 95 | (a) You must give any other recipients of the Work or 96 | Derivative Works a copy of this License; and 97 | 98 | (b) You must cause any modified files to carry prominent notices 99 | stating that You changed the files; and 100 | 101 | (c) You must retain, in the Source form of any Derivative Works 102 | that You distribute, all copyright, patent, trademark, and 103 | attribution notices from the Source form of the Work, 104 | excluding those notices that do not pertain to any part of 105 | the Derivative Works; and 106 | 107 | (d) If the Work includes a "NOTICE" text file as part of its 108 | distribution, then any Derivative Works that You distribute must 109 | include a readable copy of the attribution notices contained 110 | within such NOTICE file, excluding those notices that do not 111 | pertain to any part of the Derivative Works, in at least one 112 | of the following places: within a NOTICE text file distributed 113 | as part of the Derivative Works; within the Source form or 114 | documentation, if provided along with the Derivative Works; or, 115 | within a display generated by the Derivative Works, if and 116 | wherever such third-party notices normally appear. The contents 117 | of the NOTICE file are for informational purposes only and 118 | do not modify the License. You may add Your own attribution 119 | notices within Derivative Works that You distribute, alongside 120 | or as an addendum to the NOTICE text from the Work, provided 121 | that such additional attribution notices cannot be construed 122 | as modifying the License. 123 | 124 | You may add Your own copyright statement to Your modifications and 125 | may provide additional or different license terms and conditions 126 | for use, reproduction, or distribution of Your modifications, or 127 | for any such Derivative Works as a whole, provided Your use, 128 | reproduction, and distribution of the Work otherwise complies with 129 | the conditions stated in this License. 130 | 131 | 5. Submission of Contributions. Unless You explicitly state otherwise, 132 | any Contribution intentionally submitted for inclusion in the Work 133 | by You to the Licensor shall be under the terms and conditions of 134 | this License, without any additional terms or conditions. 
135 | Notwithstanding the above, nothing herein shall supersede or modify 136 | the terms of any separate license agreement you may have executed 137 | with Licensor regarding such Contributions. 138 | 139 | 6. Trademarks. This License does not grant permission to use the trade 140 | names, trademarks, service marks, or product names of the Licensor, 141 | except as required for reasonable and customary use in describing the 142 | origin of the Work and reproducing the content of the NOTICE file. 143 | 144 | 7. Disclaimer of Warranty. Unless required by applicable law or 145 | agreed to in writing, Licensor provides the Work (and each 146 | Contributor provides its Contributions) on an "AS IS" BASIS, 147 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 148 | implied, including, without limitation, any warranties or conditions 149 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 150 | PARTICULAR PURPOSE. You are solely responsible for determining the 151 | appropriateness of using or redistributing the Work and assume any 152 | risks associated with Your exercise of permissions under this License. 153 | 154 | 8. Limitation of Liability. In no event and under no legal theory, 155 | whether in tort (including negligence), contract, or otherwise, 156 | unless required by applicable law (such as deliberate and grossly 157 | negligent acts) or agreed to in writing, shall any Contributor be 158 | liable to You for damages, including any direct, indirect, special, 159 | incidental, or consequential damages of any character arising as a 160 | result of this License or out of the use or inability to use the 161 | Work (including but not limited to damages for loss of goodwill, 162 | work stoppage, computer failure or malfunction, or any and all 163 | other commercial damages or losses), even if such Contributor 164 | has been advised of the possibility of such damages. 165 | 166 | 9. Accepting Warranty or Additional Liability. While redistributing 167 | the Work or Derivative Works thereof, You may choose to offer, 168 | and charge a fee for, acceptance of support, warranty, indemnity, 169 | or other liability obligations and/or rights consistent with this 170 | License. However, in accepting such obligations, You may act only 171 | on Your own behalf and on Your sole responsibility, not on behalf 172 | of any other Contributor, and only if You agree to indemnify, 173 | defend, and hold each Contributor harmless for any liability 174 | incurred by, or claims asserted against, such Contributor by reason 175 | of your accepting any such warranty or additional liability. 176 | 177 | END OF TERMS AND CONDITIONS 178 | 179 | APPENDIX: How to apply the Apache License to your work. 180 | 181 | To apply the Apache License to your work, attach the following 182 | boilerplate notice, with the fields enclosed by brackets "[]" 183 | replaced with your own identifying information. (Don't include 184 | the brackets!) The text should be enclosed in the appropriate 185 | comment syntax for the file format. We also recommend that a 186 | file or class name and description of purpose be included on the 187 | same "printed page" as the copyright notice for easier 188 | identification within third-party archives. 189 | 190 | Copyright 2023, Replicate, Inc. 191 | 192 | Licensed under the Apache License, Version 2.0 (the "License"); 193 | you may not use this file except in compliance with the License. 
194 | You may obtain a copy of the License at 195 | 196 | http://www.apache.org/licenses/LICENSE-2.0 197 | 198 | Unless required by applicable law or agreed to in writing, software 199 | distributed under the License is distributed on an "AS IS" BASIS, 200 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 201 | See the License for the specific language governing permissions and 202 | limitations under the License. 203 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Cog-SDXL 2 | 3 | [![Replicate demo and cloud API](https://replicate.com/stability-ai/sdxl/badge)](https://replicate.com/stability-ai/sdxl) 4 | 5 | This is an implementation of Stability AI's [SDXL](https://github.com/Stability-AI/generative-models) as a [Cog](https://github.com/replicate/cog) model. 6 | 7 | ## Development 8 | 9 | Follow the [model pushing guide](https://replicate.com/docs/guides/push-a-model) to push your own fork of SDXL to [Replicate](https://replicate.com). 10 | 11 | ## Basic Usage 12 | 13 | For prediction, training, and running a local server (respectively): 14 | 15 | ```bash 16 | cog predict -i prompt="a photo of TOK" 17 | ``` 18 | 19 | ```bash 20 | cog train -i input_images=@example_datasets/__data.zip -i use_face_detection_instead=True 21 | ``` 22 | 23 | ```bash 24 | cog run -p 5000 python -m cog.server.http 25 | ``` 26 | 27 | ## Update notes 28 | 29 | **2023-08-17** 30 | * The ROI problem is fixed. 31 | * The BLIP `caption_prefix` no longer interferes with the BLIP captioner. 32 | 33 | 34 | **2023-08-12** 35 | * Input types are inferred from input name extensions, or from the `input_images_filetype` argument 36 | * Preprocessing is now done in fp16, and if no mask is found, the model will use the whole image 37 | 38 | **2023-08-11** 39 | * Defaults to 768x768 resolution training 40 | * Rank is now an argument, defaulting to 32 41 | * Now uses Swin2SR `caidas/swin2SR-realworld-sr-x4-64-bsrgan-psnr` by default, and will upscale + downscale to 768x768 42 | -------------------------------------------------------------------------------- /cog.yaml: -------------------------------------------------------------------------------- 1 | # Configuration for Cog ⚙️ 2 | # Reference: https://github.com/replicate/cog/blob/main/docs/yaml.md 3 | 4 | build: 5 | gpu: true 6 | cuda: "11.8" 7 | python_version: "3.9" 8 | system_packages: 9 | - "libgl1-mesa-glx" 10 | - "ffmpeg" 11 | - "libsm6" 12 | - "libxext6" 13 | - "wget" 14 | python_packages: 15 | - "diffusers<=0.25" 16 | - "torch==2.0.1" 17 | - "transformers==4.31.0" 18 | - "invisible-watermark==0.2.0" 19 | - "accelerate==0.21.0" 20 | - "pandas==2.0.3" 21 | - "torchvision==0.15.2" 22 | - "numpy==1.25.1" 23 | - "pandas==2.0.3" 24 | - "fire==0.5.0" 25 | - "opencv-python>=4.1.0.25" 26 | - "mediapipe==0.10.2" 27 | 28 | run: 29 | - curl -o /usr/local/bin/pget -L "https://github.com/replicate/pget/releases/latest/download/pget_$(uname -s)_$(uname -m)" && chmod +x /usr/local/bin/pget 30 | - wget http://thegiflibrary.tumblr.com/post/11565547760 -O face_landmarker_v2_with_blendshapes.task -q https://storage.googleapis.com/mediapipe-models/face_landmarker/face_landmarker/float16/1/face_landmarker.task 31 | 32 | predict: "predict.py:Predictor" 33 | train: "train.py:train" 34 | -------------------------------------------------------------------------------- /dataset_and_utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | from typing import
Dict, List, Optional, Tuple 3 | 4 | import numpy as np 5 | import pandas as pd 6 | import PIL 7 | import torch 8 | import torch.utils.checkpoint 9 | from diffusers import AutoencoderKL, DDPMScheduler, UNet2DConditionModel 10 | from PIL import Image 11 | from safetensors import safe_open 12 | from safetensors.torch import save_file 13 | from torch.utils.data import Dataset 14 | from transformers import AutoTokenizer, PretrainedConfig 15 | 16 | 17 | def prepare_image( 18 | pil_image: PIL.Image.Image, w: int = 512, h: int = 512 19 | ) -> torch.Tensor: 20 | pil_image = pil_image.resize((w, h), resample=Image.BICUBIC, reducing_gap=1) 21 | arr = np.array(pil_image.convert("RGB")) 22 | arr = arr.astype(np.float32) / 127.5 - 1 23 | arr = np.transpose(arr, [2, 0, 1]) 24 | image = torch.from_numpy(arr).unsqueeze(0) 25 | return image 26 | 27 | 28 | def prepare_mask( 29 | pil_image: PIL.Image.Image, w: int = 512, h: int = 512 30 | ) -> torch.Tensor: 31 | pil_image = pil_image.resize((w, h), resample=Image.BICUBIC, reducing_gap=1) 32 | arr = np.array(pil_image.convert("L")) 33 | arr = arr.astype(np.float32) / 255.0 34 | arr = np.expand_dims(arr, 0) 35 | image = torch.from_numpy(arr).unsqueeze(0) 36 | return image 37 | 38 | 39 | class PreprocessedDataset(Dataset): 40 | def __init__( 41 | self, 42 | csv_path: str, 43 | tokenizer_1, 44 | tokenizer_2, 45 | vae_encoder, 46 | text_encoder_1=None, 47 | text_encoder_2=None, 48 | do_cache: bool = False, 49 | size: int = 512, 50 | text_dropout: float = 0.0, 51 | scale_vae_latents: bool = True, 52 | substitute_caption_map: Dict[str, str] = {}, 53 | ): 54 | super().__init__() 55 | 56 | self.data = pd.read_csv(csv_path) 57 | self.csv_path = csv_path 58 | 59 | self.caption = self.data["caption"] 60 | # make it lowercase 61 | self.caption = self.caption.str.lower() 62 | for key, value in substitute_caption_map.items(): 63 | self.caption = self.caption.str.replace(key.lower(), value) 64 | 65 | self.image_path = self.data["image_path"] 66 | 67 | if "mask_path" not in self.data.columns: 68 | self.mask_path = None 69 | else: 70 | self.mask_path = self.data["mask_path"] 71 | 72 | if text_encoder_1 is None: 73 | self.return_text_embeddings = False 74 | else: 75 | self.text_encoder_1 = text_encoder_1 76 | self.text_encoder_2 = text_encoder_2 77 | self.return_text_embeddings = True 78 | assert ( 79 | NotImplementedError 80 | ), "Preprocessing Text Encoder is not implemented yet" 81 | 82 | self.tokenizer_1 = tokenizer_1 83 | self.tokenizer_2 = tokenizer_2 84 | 85 | self.vae_encoder = vae_encoder 86 | self.scale_vae_latents = scale_vae_latents 87 | self.text_dropout = text_dropout 88 | 89 | self.size = size 90 | 91 | if do_cache: 92 | self.vae_latents = [] 93 | self.tokens_tuple = [] 94 | self.masks = [] 95 | 96 | self.do_cache = True 97 | 98 | print("Captions to train on: ") 99 | for idx in range(len(self.data)): 100 | token, vae_latent, mask = self._process(idx) 101 | self.vae_latents.append(vae_latent) 102 | self.tokens_tuple.append(token) 103 | self.masks.append(mask) 104 | 105 | del self.vae_encoder 106 | 107 | else: 108 | self.do_cache = False 109 | 110 | @torch.no_grad() 111 | def _process( 112 | self, idx: int 113 | ) -> Tuple[Tuple[torch.Tensor, torch.Tensor], torch.Tensor, torch.Tensor]: 114 | image_path = self.image_path[idx] 115 | image_path = os.path.join(os.path.dirname(self.csv_path), image_path) 116 | 117 | image = PIL.Image.open(image_path).convert("RGB") 118 | image = prepare_image(image, self.size, self.size).to( 119 | dtype=self.vae_encoder.dtype, 
device=self.vae_encoder.device 120 | ) 121 | 122 | caption = self.caption[idx] 123 | 124 | print(caption) 125 | 126 | # tokenizer_1 127 | ti1 = self.tokenizer_1( 128 | caption, 129 | padding="max_length", 130 | max_length=77, 131 | truncation=True, 132 | add_special_tokens=True, 133 | return_tensors="pt", 134 | ).input_ids 135 | 136 | ti2 = self.tokenizer_2( 137 | caption, 138 | padding="max_length", 139 | max_length=77, 140 | truncation=True, 141 | add_special_tokens=True, 142 | return_tensors="pt", 143 | ).input_ids 144 | 145 | vae_latent = self.vae_encoder.encode(image).latent_dist.sample() 146 | 147 | if self.scale_vae_latents: 148 | vae_latent = vae_latent * self.vae_encoder.config.scaling_factor 149 | 150 | if self.mask_path is None: 151 | mask = torch.ones_like( 152 | vae_latent, dtype=self.vae_encoder.dtype, device=self.vae_encoder.device 153 | ) 154 | 155 | else: 156 | mask_path = self.mask_path[idx] 157 | mask_path = os.path.join(os.path.dirname(self.csv_path), mask_path) 158 | 159 | mask = PIL.Image.open(mask_path) 160 | mask = prepare_mask(mask, self.size, self.size).to( 161 | dtype=self.vae_encoder.dtype, device=self.vae_encoder.device 162 | ) 163 | 164 | mask = torch.nn.functional.interpolate( 165 | mask, size=(vae_latent.shape[-2], vae_latent.shape[-1]), mode="nearest" 166 | ) 167 | mask = mask.repeat(1, vae_latent.shape[1], 1, 1) 168 | 169 | assert len(mask.shape) == 4 and len(vae_latent.shape) == 4 170 | 171 | return (ti1.squeeze(), ti2.squeeze()), vae_latent.squeeze(), mask.squeeze() 172 | 173 | def __len__(self) -> int: 174 | return len(self.data) 175 | 176 | def atidx( 177 | self, idx: int 178 | ) -> Tuple[Tuple[torch.Tensor, torch.Tensor], torch.Tensor, torch.Tensor]: 179 | if self.do_cache: 180 | return self.tokens_tuple[idx], self.vae_latents[idx], self.masks[idx] 181 | else: 182 | return self._process(idx) 183 | 184 | def __getitem__( 185 | self, idx: int 186 | ) -> Tuple[Tuple[torch.Tensor, torch.Tensor], torch.Tensor, torch.Tensor]: 187 | token, vae_latent, mask = self.atidx(idx) 188 | return token, vae_latent, mask 189 | 190 | 191 | def import_model_class_from_model_name_or_path( 192 | pretrained_model_name_or_path: str, revision: str, subfolder: str = "text_encoder" 193 | ): 194 | text_encoder_config = PretrainedConfig.from_pretrained( 195 | pretrained_model_name_or_path, subfolder=subfolder, revision=revision 196 | ) 197 | model_class = text_encoder_config.architectures[0] 198 | 199 | if model_class == "CLIPTextModel": 200 | from transformers import CLIPTextModel 201 | 202 | return CLIPTextModel 203 | elif model_class == "CLIPTextModelWithProjection": 204 | from transformers import CLIPTextModelWithProjection 205 | 206 | return CLIPTextModelWithProjection 207 | else: 208 | raise ValueError(f"{model_class} is not supported.") 209 | 210 | 211 | def load_models(pretrained_model_name_or_path, revision, device, weight_dtype): 212 | tokenizer_one = AutoTokenizer.from_pretrained( 213 | pretrained_model_name_or_path, 214 | subfolder="tokenizer", 215 | revision=revision, 216 | use_fast=False, 217 | ) 218 | tokenizer_two = AutoTokenizer.from_pretrained( 219 | pretrained_model_name_or_path, 220 | subfolder="tokenizer_2", 221 | revision=revision, 222 | use_fast=False, 223 | ) 224 | 225 | # Load scheduler and models 226 | noise_scheduler = DDPMScheduler.from_pretrained( 227 | pretrained_model_name_or_path, subfolder="scheduler" 228 | ) 229 | # import correct text encoder classes 230 | text_encoder_cls_one = import_model_class_from_model_name_or_path( 231 | 
pretrained_model_name_or_path, revision 232 | ) 233 | text_encoder_cls_two = import_model_class_from_model_name_or_path( 234 | pretrained_model_name_or_path, revision, subfolder="text_encoder_2" 235 | ) 236 | text_encoder_one = text_encoder_cls_one.from_pretrained( 237 | pretrained_model_name_or_path, subfolder="text_encoder", revision=revision 238 | ) 239 | text_encoder_two = text_encoder_cls_two.from_pretrained( 240 | pretrained_model_name_or_path, subfolder="text_encoder_2", revision=revision 241 | ) 242 | 243 | vae = AutoencoderKL.from_pretrained( 244 | pretrained_model_name_or_path, subfolder="vae", revision=revision 245 | ) 246 | unet = UNet2DConditionModel.from_pretrained( 247 | pretrained_model_name_or_path, subfolder="unet", revision=revision 248 | ) 249 | 250 | vae.requires_grad_(False) 251 | text_encoder_one.requires_grad_(False) 252 | text_encoder_two.requires_grad_(False) 253 | 254 | unet.to(device, dtype=weight_dtype) 255 | vae.to(device, dtype=torch.float32) 256 | text_encoder_one.to(device, dtype=weight_dtype) 257 | text_encoder_two.to(device, dtype=weight_dtype) 258 | 259 | return ( 260 | tokenizer_one, 261 | tokenizer_two, 262 | noise_scheduler, 263 | text_encoder_one, 264 | text_encoder_two, 265 | vae, 266 | unet, 267 | ) 268 | 269 | 270 | def unet_attn_processors_state_dict(unet) -> Dict[str, torch.tensor]: 271 | """ 272 | Returns: 273 | a state dict containing just the attention processor parameters. 274 | """ 275 | attn_processors = unet.attn_processors 276 | 277 | attn_processors_state_dict = {} 278 | 279 | for attn_processor_key, attn_processor in attn_processors.items(): 280 | for parameter_key, parameter in attn_processor.state_dict().items(): 281 | attn_processors_state_dict[ 282 | f"{attn_processor_key}.{parameter_key}" 283 | ] = parameter 284 | 285 | return attn_processors_state_dict 286 | 287 | 288 | class TokenEmbeddingsHandler: 289 | def __init__(self, text_encoders, tokenizers): 290 | self.text_encoders = text_encoders 291 | self.tokenizers = tokenizers 292 | 293 | self.train_ids: Optional[torch.Tensor] = None 294 | self.inserting_toks: Optional[List[str]] = None 295 | self.embeddings_settings = {} 296 | 297 | def initialize_new_tokens(self, inserting_toks: List[str]): 298 | idx = 0 299 | for tokenizer, text_encoder in zip(self.tokenizers, self.text_encoders): 300 | assert isinstance( 301 | inserting_toks, list 302 | ), "inserting_toks should be a list of strings." 303 | assert all( 304 | isinstance(tok, str) for tok in inserting_toks 305 | ), "All elements in inserting_toks should be strings." 
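        # The steps below register the new placeholder tokens as additional special tokens,
        # resize the token-embedding matrix to make room for them, and initialize the fresh
        # rows with Gaussian noise scaled to the std of the existing token embeddings.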
306 | 307 | self.inserting_toks = inserting_toks 308 | special_tokens_dict = {"additional_special_tokens": self.inserting_toks} 309 | tokenizer.add_special_tokens(special_tokens_dict) 310 | text_encoder.resize_token_embeddings(len(tokenizer)) 311 | 312 | self.train_ids = tokenizer.convert_tokens_to_ids(self.inserting_toks) 313 | 314 | # random initialization of new tokens 315 | 316 | std_token_embedding = ( 317 | text_encoder.text_model.embeddings.token_embedding.weight.data.std() 318 | ) 319 | 320 | print(f"{idx} text encoder's std_token_embedding: {std_token_embedding}") 321 | 322 | text_encoder.text_model.embeddings.token_embedding.weight.data[ 323 | self.train_ids 324 | ] = ( 325 | torch.randn( 326 | len(self.train_ids), text_encoder.text_model.config.hidden_size 327 | ) 328 | .to(device=self.device) 329 | .to(dtype=self.dtype) 330 | * std_token_embedding 331 | ) 332 | self.embeddings_settings[ 333 | f"original_embeddings_{idx}" 334 | ] = text_encoder.text_model.embeddings.token_embedding.weight.data.clone() 335 | self.embeddings_settings[f"std_token_embedding_{idx}"] = std_token_embedding 336 | 337 | inu = torch.ones((len(tokenizer),), dtype=torch.bool) 338 | inu[self.train_ids] = False 339 | 340 | self.embeddings_settings[f"index_no_updates_{idx}"] = inu 341 | 342 | print(self.embeddings_settings[f"index_no_updates_{idx}"].shape) 343 | 344 | idx += 1 345 | 346 | def save_embeddings(self, file_path: str): 347 | assert ( 348 | self.train_ids is not None 349 | ), "Initialize new tokens before saving embeddings." 350 | tensors = {} 351 | for idx, text_encoder in enumerate(self.text_encoders): 352 | assert text_encoder.text_model.embeddings.token_embedding.weight.data.shape[ 353 | 0 354 | ] == len(self.tokenizers[0]), "Tokenizers should be the same." 355 | new_token_embeddings = ( 356 | text_encoder.text_model.embeddings.token_embedding.weight.data[ 357 | self.train_ids 358 | ] 359 | ) 360 | tensors[f"text_encoders_{idx}"] = new_token_embeddings 361 | 362 | save_file(tensors, file_path) 363 | 364 | @property 365 | def dtype(self): 366 | return self.text_encoders[0].dtype 367 | 368 | @property 369 | def device(self): 370 | return self.text_encoders[0].device 371 | 372 | def _load_embeddings(self, loaded_embeddings, tokenizer, text_encoder): 373 | # Assuming new tokens are of the format <s{i}>, e.g. <s0>, <s1>, ... 374 | self.inserting_toks = [f"<s{i}>" for i in range(loaded_embeddings.shape[0])] 375 | special_tokens_dict = {"additional_special_tokens": self.inserting_toks} 376 | tokenizer.add_special_tokens(special_tokens_dict) 377 | text_encoder.resize_token_embeddings(len(tokenizer)) 378 | 379 | self.train_ids = tokenizer.convert_tokens_to_ids(self.inserting_toks) 380 | assert self.train_ids is not None, "New tokens could not be converted to IDs."
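        # The loaded rows are written into the (resized) embedding matrix at the ids of the
        # newly added tokens, cast to the text encoder's device and dtype.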
381 | text_encoder.text_model.embeddings.token_embedding.weight.data[ 382 | self.train_ids 383 | ] = loaded_embeddings.to(device=self.device).to(dtype=self.dtype) 384 | 385 | @torch.no_grad() 386 | def retract_embeddings(self): 387 | for idx, text_encoder in enumerate(self.text_encoders): 388 | index_no_updates = self.embeddings_settings[f"index_no_updates_{idx}"] 389 | text_encoder.text_model.embeddings.token_embedding.weight.data[ 390 | index_no_updates 391 | ] = ( 392 | self.embeddings_settings[f"original_embeddings_{idx}"][index_no_updates] 393 | .to(device=text_encoder.device) 394 | .to(dtype=text_encoder.dtype) 395 | ) 396 | 397 | # for the parts that were updated, we need to normalize them 398 | # to have the same std as before 399 | std_token_embedding = self.embeddings_settings[f"std_token_embedding_{idx}"] 400 | 401 | index_updates = ~index_no_updates 402 | new_embeddings = ( 403 | text_encoder.text_model.embeddings.token_embedding.weight.data[ 404 | index_updates 405 | ] 406 | ) 407 | off_ratio = std_token_embedding / new_embeddings.std() 408 | 409 | new_embeddings = new_embeddings * (off_ratio**0.1) 410 | text_encoder.text_model.embeddings.token_embedding.weight.data[ 411 | index_updates 412 | ] = new_embeddings 413 | 414 | def load_embeddings(self, file_path: str): 415 | with safe_open(file_path, framework="pt", device=self.device.type) as f: 416 | for idx in range(len(self.text_encoders)): 417 | text_encoder = self.text_encoders[idx] 418 | tokenizer = self.tokenizers[idx] 419 | 420 | loaded_embeddings = f.get_tensor(f"text_encoders_{idx}") 421 | self._load_embeddings(loaded_embeddings, tokenizer, text_encoder) 422 | -------------------------------------------------------------------------------- /example_datasets/README.md: -------------------------------------------------------------------------------- 1 | ## Example Datasets 2 | 3 | This folder contains three example datasets that were used to tune SDXL using the Replicate API, along with (at the top level) example outputs generated from those datasets. 
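As an illustrative sketch (not part of the original repository), a fine-tune on one of these datasets could be started through the Replicate Python client; the SDXL version hash, the training-data URL, and the destination model below are placeholders, not real values.

```python
# Hypothetical example: start an SDXL fine-tune on the zeke dataset via the Replicate API.
# The version hash, zip URL, and destination model are placeholders to replace with your own.
import replicate

training = replicate.trainings.create(
    version="stability-ai/sdxl:<version-hash>",          # placeholder SDXL version
    input={
        "input_images": "https://example.com/zeke.zip",  # publicly reachable copy of the dataset zip
        "use_face_detection_instead": True,
    },
    destination="your-username/sdxl-zeke",               # placeholder destination model
)
print(training.status)
```

The same dataset can also be used locally with `cog train -i input_images=@example_datasets/zeke.zip`, as in the top-level README.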
-------------------------------------------------------------------------------- /example_datasets/kiriko.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/replicate/cog-sdxl/c79cb3c9c3c2968ee53ada00c7fb02aa3b9fc58c/example_datasets/kiriko.png -------------------------------------------------------------------------------- /example_datasets/kiriko/0.src.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/replicate/cog-sdxl/c79cb3c9c3c2968ee53ada00c7fb02aa3b9fc58c/example_datasets/kiriko/0.src.jpg -------------------------------------------------------------------------------- /example_datasets/kiriko/1.src.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/replicate/cog-sdxl/c79cb3c9c3c2968ee53ada00c7fb02aa3b9fc58c/example_datasets/kiriko/1.src.jpg -------------------------------------------------------------------------------- /example_datasets/kiriko/10.src.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/replicate/cog-sdxl/c79cb3c9c3c2968ee53ada00c7fb02aa3b9fc58c/example_datasets/kiriko/10.src.jpg -------------------------------------------------------------------------------- /example_datasets/kiriko/11.src.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/replicate/cog-sdxl/c79cb3c9c3c2968ee53ada00c7fb02aa3b9fc58c/example_datasets/kiriko/11.src.jpg -------------------------------------------------------------------------------- /example_datasets/kiriko/12.src.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/replicate/cog-sdxl/c79cb3c9c3c2968ee53ada00c7fb02aa3b9fc58c/example_datasets/kiriko/12.src.jpg -------------------------------------------------------------------------------- /example_datasets/kiriko/2.src.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/replicate/cog-sdxl/c79cb3c9c3c2968ee53ada00c7fb02aa3b9fc58c/example_datasets/kiriko/2.src.jpg -------------------------------------------------------------------------------- /example_datasets/kiriko/3.src.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/replicate/cog-sdxl/c79cb3c9c3c2968ee53ada00c7fb02aa3b9fc58c/example_datasets/kiriko/3.src.jpg -------------------------------------------------------------------------------- /example_datasets/kiriko/4.src.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/replicate/cog-sdxl/c79cb3c9c3c2968ee53ada00c7fb02aa3b9fc58c/example_datasets/kiriko/4.src.jpg -------------------------------------------------------------------------------- /example_datasets/kiriko/5.src.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/replicate/cog-sdxl/c79cb3c9c3c2968ee53ada00c7fb02aa3b9fc58c/example_datasets/kiriko/5.src.jpg -------------------------------------------------------------------------------- /example_datasets/kiriko/6.src.jpg: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/replicate/cog-sdxl/c79cb3c9c3c2968ee53ada00c7fb02aa3b9fc58c/example_datasets/kiriko/6.src.jpg -------------------------------------------------------------------------------- /example_datasets/kiriko/7.src.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/replicate/cog-sdxl/c79cb3c9c3c2968ee53ada00c7fb02aa3b9fc58c/example_datasets/kiriko/7.src.jpg -------------------------------------------------------------------------------- /example_datasets/kiriko/8.src.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/replicate/cog-sdxl/c79cb3c9c3c2968ee53ada00c7fb02aa3b9fc58c/example_datasets/kiriko/8.src.jpg -------------------------------------------------------------------------------- /example_datasets/kiriko/9.src.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/replicate/cog-sdxl/c79cb3c9c3c2968ee53ada00c7fb02aa3b9fc58c/example_datasets/kiriko/9.src.jpg -------------------------------------------------------------------------------- /example_datasets/monster.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/replicate/cog-sdxl/c79cb3c9c3c2968ee53ada00c7fb02aa3b9fc58c/example_datasets/monster.png -------------------------------------------------------------------------------- /example_datasets/monster/caption.csv: -------------------------------------------------------------------------------- 1 | caption,image_file 2 | a TOK on a windowsill,monstertoy (1).jpg 3 | a photo of smiling TOK in an office,monstertoy (2).jpg 4 | a photo of TOK sitting by a window,monstertoy (3).jpg 5 | a photo of TOK on a car,monstertoy (4).jpg 6 | a photo of TOK smiling on the ground,monstertoy (5).jpg -------------------------------------------------------------------------------- /example_datasets/monster/monstertoy (1).jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/replicate/cog-sdxl/c79cb3c9c3c2968ee53ada00c7fb02aa3b9fc58c/example_datasets/monster/monstertoy (1).jpg -------------------------------------------------------------------------------- /example_datasets/monster/monstertoy (2).jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/replicate/cog-sdxl/c79cb3c9c3c2968ee53ada00c7fb02aa3b9fc58c/example_datasets/monster/monstertoy (2).jpg -------------------------------------------------------------------------------- /example_datasets/monster/monstertoy (3).jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/replicate/cog-sdxl/c79cb3c9c3c2968ee53ada00c7fb02aa3b9fc58c/example_datasets/monster/monstertoy (3).jpg -------------------------------------------------------------------------------- /example_datasets/monster/monstertoy (4).jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/replicate/cog-sdxl/c79cb3c9c3c2968ee53ada00c7fb02aa3b9fc58c/example_datasets/monster/monstertoy (4).jpg -------------------------------------------------------------------------------- /example_datasets/monster/monstertoy (5).jpg: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/replicate/cog-sdxl/c79cb3c9c3c2968ee53ada00c7fb02aa3b9fc58c/example_datasets/monster/monstertoy (5).jpg -------------------------------------------------------------------------------- /example_datasets/monster_uni.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/replicate/cog-sdxl/c79cb3c9c3c2968ee53ada00c7fb02aa3b9fc58c/example_datasets/monster_uni.png -------------------------------------------------------------------------------- /example_datasets/zeke.zip: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/replicate/cog-sdxl/c79cb3c9c3c2968ee53ada00c7fb02aa3b9fc58c/example_datasets/zeke.zip -------------------------------------------------------------------------------- /example_datasets/zeke/0.src.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/replicate/cog-sdxl/c79cb3c9c3c2968ee53ada00c7fb02aa3b9fc58c/example_datasets/zeke/0.src.jpg -------------------------------------------------------------------------------- /example_datasets/zeke/1.src.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/replicate/cog-sdxl/c79cb3c9c3c2968ee53ada00c7fb02aa3b9fc58c/example_datasets/zeke/1.src.jpg -------------------------------------------------------------------------------- /example_datasets/zeke/2.src.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/replicate/cog-sdxl/c79cb3c9c3c2968ee53ada00c7fb02aa3b9fc58c/example_datasets/zeke/2.src.jpg -------------------------------------------------------------------------------- /example_datasets/zeke/3.src.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/replicate/cog-sdxl/c79cb3c9c3c2968ee53ada00c7fb02aa3b9fc58c/example_datasets/zeke/3.src.jpg -------------------------------------------------------------------------------- /example_datasets/zeke/4.src.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/replicate/cog-sdxl/c79cb3c9c3c2968ee53ada00c7fb02aa3b9fc58c/example_datasets/zeke/4.src.jpg -------------------------------------------------------------------------------- /example_datasets/zeke/5.src.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/replicate/cog-sdxl/c79cb3c9c3c2968ee53ada00c7fb02aa3b9fc58c/example_datasets/zeke/5.src.jpg -------------------------------------------------------------------------------- /example_datasets/zeke2/00.src.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/replicate/cog-sdxl/c79cb3c9c3c2968ee53ada00c7fb02aa3b9fc58c/example_datasets/zeke2/00.src.jpg -------------------------------------------------------------------------------- /example_datasets/zeke2/01.src.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/replicate/cog-sdxl/c79cb3c9c3c2968ee53ada00c7fb02aa3b9fc58c/example_datasets/zeke2/01.src.jpg -------------------------------------------------------------------------------- /example_datasets/zeke2/02.src.jpg: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/replicate/cog-sdxl/c79cb3c9c3c2968ee53ada00c7fb02aa3b9fc58c/example_datasets/zeke2/02.src.jpg -------------------------------------------------------------------------------- /example_datasets/zeke2/03.src.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/replicate/cog-sdxl/c79cb3c9c3c2968ee53ada00c7fb02aa3b9fc58c/example_datasets/zeke2/03.src.jpg -------------------------------------------------------------------------------- /example_datasets/zeke2/04.src.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/replicate/cog-sdxl/c79cb3c9c3c2968ee53ada00c7fb02aa3b9fc58c/example_datasets/zeke2/04.src.jpg -------------------------------------------------------------------------------- /example_datasets/zeke2/05.src.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/replicate/cog-sdxl/c79cb3c9c3c2968ee53ada00c7fb02aa3b9fc58c/example_datasets/zeke2/05.src.jpg -------------------------------------------------------------------------------- /example_datasets/zeke2/06.src.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/replicate/cog-sdxl/c79cb3c9c3c2968ee53ada00c7fb02aa3b9fc58c/example_datasets/zeke2/06.src.jpg -------------------------------------------------------------------------------- /example_datasets/zeke2/07.src.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/replicate/cog-sdxl/c79cb3c9c3c2968ee53ada00c7fb02aa3b9fc58c/example_datasets/zeke2/07.src.jpg -------------------------------------------------------------------------------- /example_datasets/zeke2/08.src.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/replicate/cog-sdxl/c79cb3c9c3c2968ee53ada00c7fb02aa3b9fc58c/example_datasets/zeke2/08.src.jpg -------------------------------------------------------------------------------- /example_datasets/zeke2/09.src.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/replicate/cog-sdxl/c79cb3c9c3c2968ee53ada00c7fb02aa3b9fc58c/example_datasets/zeke2/09.src.jpg -------------------------------------------------------------------------------- /example_datasets/zeke2/10.src.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/replicate/cog-sdxl/c79cb3c9c3c2968ee53ada00c7fb02aa3b9fc58c/example_datasets/zeke2/10.src.jpg -------------------------------------------------------------------------------- /example_datasets/zeke2/12.src.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/replicate/cog-sdxl/c79cb3c9c3c2968ee53ada00c7fb02aa3b9fc58c/example_datasets/zeke2/12.src.jpg -------------------------------------------------------------------------------- /example_datasets/zeke2/13.src.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/replicate/cog-sdxl/c79cb3c9c3c2968ee53ada00c7fb02aa3b9fc58c/example_datasets/zeke2/13.src.jpg -------------------------------------------------------------------------------- /example_datasets/zeke2/14.src.jpg: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/replicate/cog-sdxl/c79cb3c9c3c2968ee53ada00c7fb02aa3b9fc58c/example_datasets/zeke2/14.src.jpg -------------------------------------------------------------------------------- /example_datasets/zeke2/15.src.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/replicate/cog-sdxl/c79cb3c9c3c2968ee53ada00c7fb02aa3b9fc58c/example_datasets/zeke2/15.src.jpg -------------------------------------------------------------------------------- /example_datasets/zeke2/16.src.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/replicate/cog-sdxl/c79cb3c9c3c2968ee53ada00c7fb02aa3b9fc58c/example_datasets/zeke2/16.src.jpg -------------------------------------------------------------------------------- /example_datasets/zeke2/17.src.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/replicate/cog-sdxl/c79cb3c9c3c2968ee53ada00c7fb02aa3b9fc58c/example_datasets/zeke2/17.src.jpg -------------------------------------------------------------------------------- /example_datasets/zeke2/18.src.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/replicate/cog-sdxl/c79cb3c9c3c2968ee53ada00c7fb02aa3b9fc58c/example_datasets/zeke2/18.src.jpg -------------------------------------------------------------------------------- /example_datasets/zeke2/19.src.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/replicate/cog-sdxl/c79cb3c9c3c2968ee53ada00c7fb02aa3b9fc58c/example_datasets/zeke2/19.src.jpg -------------------------------------------------------------------------------- /example_datasets/zeke2/20.src.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/replicate/cog-sdxl/c79cb3c9c3c2968ee53ada00c7fb02aa3b9fc58c/example_datasets/zeke2/20.src.jpg -------------------------------------------------------------------------------- /example_datasets/zeke2/README.md: -------------------------------------------------------------------------------- 1 | # Zeke 2 | 3 | This is a collection of images of [Zeke Sikelianos](https://github.com/zeke) which can be used to fine-tune image generation models. 4 | 5 | These images are provided for **non-commercial research purposes** in the field of AI image generation. They are intended to contribute to the development and improvement of AI models, particularly in areas such as facial recognition, emotion detection, and diverse representation in AI-generated content. 6 | 7 | ## Usage Guidelines 8 | 9 | - **Non-Commercial Use Only**: These images are strictly for non-commercial research purposes. Any commercial use is expressly prohibited. 10 | - **Ethical Considerations**: Users are expected to employ these images in a manner that respects human dignity and privacy. The images should not be used in ways that could be considered defamatory, offensive, or harmful. 11 | - **Accurate Representation**: When using these images to train AI models, please strive for accurate representation. Do not use the images in ways that misrepresent my identity, ethnicity, or personal characteristics. 12 | - **Attribution**: While not required, attribution is appreciated. 
13 | - **No Redistribution**: The images in this collection should not be redistributed outside of your research project or shared publicly without explicit permission. 14 | - **Reporting**: I welcome feedback on how these images are being used in research. Please consider sharing your findings or insights derived from using this dataset. 15 | 16 | ## Content 17 | 18 | - All images are JPG 19 | - Dimensions are irregular 20 | - Images are generally under 1MB 21 | - I'm wearing a hat in some of the images. Hope that doesn't mess up your outputs! 22 | 23 | --- 24 | 25 | ![zeke images](https://github.com/user-attachments/assets/48d158e5-801b-4767-b546-0a43cfbf0994) -------------------------------------------------------------------------------- /example_datasets/zeke_unicorn.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/replicate/cog-sdxl/c79cb3c9c3c2968ee53ada00c7fb02aa3b9fc58c/example_datasets/zeke_unicorn.png -------------------------------------------------------------------------------- /feature-extractor/preprocessor_config.json: -------------------------------------------------------------------------------- 1 | { 2 | "crop_size": 224, 3 | "do_center_crop": true, 4 | "do_convert_rgb": true, 5 | "do_normalize": true, 6 | "do_resize": true, 7 | "feature_extractor_type": "CLIPFeatureExtractor", 8 | "image_mean": [ 9 | 0.48145466, 10 | 0.4578275, 11 | 0.40821073 12 | ], 13 | "image_std": [ 14 | 0.26862954, 15 | 0.26130258, 16 | 0.27577711 17 | ], 18 | "resample": 3, 19 | "size": 224 20 | } -------------------------------------------------------------------------------- /no_init.py: -------------------------------------------------------------------------------- 1 | import contextlib 2 | import contextvars 3 | import threading 4 | from typing import ( 5 | Callable, 6 | ContextManager, 7 | NamedTuple, 8 | Optional, 9 | TypeVar, 10 | Union, 11 | ) 12 | 13 | import torch 14 | 15 | __all__ = ["no_init_or_tensor"] 16 | 17 | 18 | Model = TypeVar("Model") 19 | 20 | 21 | def no_init_or_tensor( 22 | loading_code: Optional[Callable[..., Model]] = None 23 | ) -> Union[Model, ContextManager]: 24 | """ 25 | Suppress the initialization of weights while loading a model. 26 | 27 | Can either directly be passed a callable containing model-loading code, 28 | which will be evaluated with weight initialization suppressed, 29 | or used as a context manager around arbitrary model-loading code. 30 | 31 | Args: 32 | loading_code: Either a callable to evaluate 33 | with model weight initialization suppressed, 34 | or None (the default) to use as a context manager. 35 | 36 | Returns: 37 | The return value of `loading_code`, if `loading_code` is callable. 38 | 39 | Otherwise, if `loading_code` is None, returns a context manager 40 | to be used in a `with`-statement. 
41 | 42 | Examples: 43 | As a context manager:: 44 | 45 | from transformers import AutoConfig, AutoModelForCausalLM 46 | config = AutoConfig("EleutherAI/gpt-j-6B") 47 | with no_init_or_tensor(): 48 | model = AutoModelForCausalLM.from_config(config) 49 | 50 | Or, directly passing a callable:: 51 | 52 | from transformers import AutoConfig, AutoModelForCausalLM 53 | config = AutoConfig("EleutherAI/gpt-j-6B") 54 | model = no_init_or_tensor(lambda: AutoModelForCausalLM.from_config(config)) 55 | """ 56 | if loading_code is None: 57 | return _NoInitOrTensorImpl.context_manager() 58 | elif callable(loading_code): 59 | with _NoInitOrTensorImpl.context_manager(): 60 | return loading_code() 61 | else: 62 | raise TypeError( 63 | "no_init_or_tensor() expected a callable to evaluate," 64 | " or None if being used as a context manager;" 65 | f' got an object of type "{type(loading_code).__name__}" instead.' 66 | ) 67 | 68 | 69 | class _NoInitOrTensorImpl: 70 | # Implementation of the thread-safe, async-safe, re-entrant context manager 71 | # version of no_init_or_tensor(). 72 | # This class essentially acts as a namespace. 73 | # It is not instantiable, because modifications to torch functions 74 | # inherently affect the global scope, and thus there is no worthwhile data 75 | # to store in the class instance scope. 76 | _MODULES = (torch.nn.Linear, torch.nn.Embedding, torch.nn.LayerNorm) 77 | _MODULE_ORIGINALS = tuple((m, m.reset_parameters) for m in _MODULES) 78 | _ORIGINAL_EMPTY = torch.empty 79 | 80 | is_active = contextvars.ContextVar("_NoInitOrTensorImpl.is_active", default=False) 81 | _count_active: int = 0 82 | _count_active_lock = threading.Lock() 83 | 84 | @classmethod 85 | @contextlib.contextmanager 86 | def context_manager(cls): 87 | if cls.is_active.get(): 88 | yield 89 | return 90 | 91 | with cls._count_active_lock: 92 | cls._count_active += 1 93 | if cls._count_active == 1: 94 | for mod in cls._MODULES: 95 | mod.reset_parameters = cls._disable(mod.reset_parameters) 96 | # When torch.empty is called, make it map to meta device by replacing 97 | # the device in kwargs. 
98 | torch.empty = cls._ORIGINAL_EMPTY 99 | reset_token = cls.is_active.set(True) 100 | 101 | try: 102 | yield 103 | finally: 104 | cls.is_active.reset(reset_token) 105 | with cls._count_active_lock: 106 | cls._count_active -= 1 107 | if cls._count_active == 0: 108 | torch.empty = cls._ORIGINAL_EMPTY 109 | for mod, original in cls._MODULE_ORIGINALS: 110 | mod.reset_parameters = original 111 | 112 | @staticmethod 113 | def _disable(func): 114 | def wrapper(*args, **kwargs): 115 | # Behaves as normal except in an active context 116 | if not _NoInitOrTensorImpl.is_active.get(): 117 | return func(*args, **kwargs) 118 | 119 | return wrapper 120 | 121 | __init__ = None 122 | -------------------------------------------------------------------------------- /predict.py: -------------------------------------------------------------------------------- 1 | import hashlib 2 | import json 3 | import os 4 | import shutil 5 | import subprocess 6 | import time 7 | from typing import Any, Callable, Dict, List, Optional, Tuple, Union 8 | from weights import WeightsDownloadCache 9 | 10 | import numpy as np 11 | import torch 12 | from cog import BasePredictor, Input, Path 13 | from diffusers import ( 14 | DDIMScheduler, 15 | DiffusionPipeline, 16 | DPMSolverMultistepScheduler, 17 | EulerAncestralDiscreteScheduler, 18 | EulerDiscreteScheduler, 19 | HeunDiscreteScheduler, 20 | PNDMScheduler, 21 | StableDiffusionXLImg2ImgPipeline, 22 | StableDiffusionXLInpaintPipeline, 23 | ) 24 | from diffusers.models.attention_processor import LoRAAttnProcessor2_0 25 | from diffusers.pipelines.stable_diffusion.safety_checker import ( 26 | StableDiffusionSafetyChecker, 27 | ) 28 | from diffusers.utils import load_image 29 | from safetensors import safe_open 30 | from safetensors.torch import load_file 31 | from transformers import CLIPImageProcessor 32 | 33 | from dataset_and_utils import TokenEmbeddingsHandler 34 | 35 | SDXL_MODEL_CACHE = "./sdxl-cache" 36 | REFINER_MODEL_CACHE = "./refiner-cache" 37 | SAFETY_CACHE = "./safety-cache" 38 | FEATURE_EXTRACTOR = "./feature-extractor" 39 | SDXL_URL = "https://weights.replicate.delivery/default/sdxl/sdxl-vae-upcast-fix.tar" 40 | REFINER_URL = ( 41 | "https://weights.replicate.delivery/default/sdxl/refiner-no-vae-no-encoder-1.0.tar" 42 | ) 43 | SAFETY_URL = "https://weights.replicate.delivery/default/sdxl/safety-1.0.tar" 44 | 45 | 46 | class KarrasDPM: 47 | def from_config(config): 48 | return DPMSolverMultistepScheduler.from_config(config, use_karras_sigmas=True) 49 | 50 | 51 | SCHEDULERS = { 52 | "DDIM": DDIMScheduler, 53 | "DPMSolverMultistep": DPMSolverMultistepScheduler, 54 | "HeunDiscrete": HeunDiscreteScheduler, 55 | "KarrasDPM": KarrasDPM, 56 | "K_EULER_ANCESTRAL": EulerAncestralDiscreteScheduler, 57 | "K_EULER": EulerDiscreteScheduler, 58 | "PNDM": PNDMScheduler, 59 | } 60 | 61 | 62 | def download_weights(url, dest): 63 | start = time.time() 64 | print("downloading url: ", url) 65 | print("downloading to: ", dest) 66 | subprocess.check_call(["pget", "-x", url, dest], close_fds=False) 67 | print("downloading took: ", time.time() - start) 68 | 69 | 70 | class Predictor(BasePredictor): 71 | def load_trained_weights(self, weights, pipe): 72 | from no_init import no_init_or_tensor 73 | 74 | # weights can be a URLPath, which behaves in unexpected ways 75 | weights = str(weights) 76 | if self.tuned_weights == weights: 77 | print("skipping loading .. 
weights already loaded") 78 | return 79 | 80 | # predictions can be cancelled while in this function, which 81 | # interrupts this finishing. To protect against odd states we 82 | # set tuned_weights to a value that lets the next prediction 83 | # know if it should try to load weights or if loading completed 84 | self.tuned_weights = 'loading' 85 | 86 | local_weights_cache = self.weights_cache.ensure(weights) 87 | 88 | # load UNET 89 | print("Loading fine-tuned model") 90 | self.is_lora = False 91 | 92 | maybe_unet_path = os.path.join(local_weights_cache, "unet.safetensors") 93 | if not os.path.exists(maybe_unet_path): 94 | print("Does not have Unet. assume we are using LoRA") 95 | self.is_lora = True 96 | 97 | if not self.is_lora: 98 | print("Loading Unet") 99 | 100 | new_unet_params = load_file( 101 | os.path.join(local_weights_cache, "unet.safetensors") 102 | ) 103 | # this should return _IncompatibleKeys(missing_keys=[...], unexpected_keys=[]) 104 | pipe.unet.load_state_dict(new_unet_params, strict=False) 105 | 106 | else: 107 | print("Loading Unet LoRA") 108 | 109 | unet = pipe.unet 110 | 111 | tensors = load_file(os.path.join(local_weights_cache, "lora.safetensors")) 112 | 113 | unet_lora_attn_procs = {} 114 | name_rank_map = {} 115 | for tk, tv in tensors.items(): 116 | # up is N, d 117 | tensors[tk] = tv.half() 118 | if tk.endswith("up.weight"): 119 | proc_name = ".".join(tk.split(".")[:-3]) 120 | r = tv.shape[1] 121 | name_rank_map[proc_name] = r 122 | 123 | for name, attn_processor in unet.attn_processors.items(): 124 | cross_attention_dim = ( 125 | None 126 | if name.endswith("attn1.processor") 127 | else unet.config.cross_attention_dim 128 | ) 129 | if name.startswith("mid_block"): 130 | hidden_size = unet.config.block_out_channels[-1] 131 | elif name.startswith("up_blocks"): 132 | block_id = int(name[len("up_blocks.")]) 133 | hidden_size = list(reversed(unet.config.block_out_channels))[ 134 | block_id 135 | ] 136 | elif name.startswith("down_blocks"): 137 | block_id = int(name[len("down_blocks.")]) 138 | hidden_size = unet.config.block_out_channels[block_id] 139 | with no_init_or_tensor(): 140 | module = LoRAAttnProcessor2_0( 141 | hidden_size=hidden_size, 142 | cross_attention_dim=cross_attention_dim, 143 | rank=name_rank_map[name], 144 | ).half() 145 | unet_lora_attn_procs[name] = module.to("cuda", non_blocking=True) 146 | 147 | unet.set_attn_processor(unet_lora_attn_procs) 148 | unet.load_state_dict(tensors, strict=False) 149 | 150 | # load text 151 | handler = TokenEmbeddingsHandler( 152 | [pipe.text_encoder, pipe.text_encoder_2], [pipe.tokenizer, pipe.tokenizer_2] 153 | ) 154 | handler.load_embeddings(os.path.join(local_weights_cache, "embeddings.pti")) 155 | 156 | # load params 157 | with open(os.path.join(local_weights_cache, "special_params.json"), "r") as f: 158 | params = json.load(f) 159 | 160 | self.token_map = params 161 | self.tuned_weights = weights 162 | self.tuned_model = True 163 | 164 | def unload_trained_weights(self, pipe: DiffusionPipeline): 165 | print("unloading loras") 166 | 167 | def _recursive_unset_lora(module: torch.nn.Module): 168 | if hasattr(module, "lora_layer"): 169 | module.lora_layer = None 170 | 171 | for _, child in module.named_children(): 172 | _recursive_unset_lora(child) 173 | 174 | _recursive_unset_lora(pipe.unet) 175 | self.tuned_weights = None 176 | self.tuned_model = False 177 | 178 | def setup(self, weights: Optional[Path] = None): 179 | """Load the model into memory to make running multiple predictions efficient""" 180 | 181 | 
start = time.time() 182 | self.tuned_model = False 183 | self.tuned_weights = None 184 | if str(weights) == "weights": 185 | weights = None 186 | 187 | self.weights_cache = WeightsDownloadCache() 188 | 189 | print("Loading safety checker...") 190 | if not os.path.exists(SAFETY_CACHE): 191 | download_weights(SAFETY_URL, SAFETY_CACHE) 192 | self.safety_checker = StableDiffusionSafetyChecker.from_pretrained( 193 | SAFETY_CACHE, torch_dtype=torch.float16 194 | ).to("cuda") 195 | self.feature_extractor = CLIPImageProcessor.from_pretrained(FEATURE_EXTRACTOR) 196 | 197 | if not os.path.exists(SDXL_MODEL_CACHE): 198 | download_weights(SDXL_URL, SDXL_MODEL_CACHE) 199 | 200 | print("Loading sdxl txt2img pipeline...") 201 | self.txt2img_pipe = DiffusionPipeline.from_pretrained( 202 | SDXL_MODEL_CACHE, 203 | torch_dtype=torch.float16, 204 | use_safetensors=True, 205 | variant="fp16", 206 | ) 207 | self.is_lora = False 208 | if weights or os.path.exists("./trained-model"): 209 | self.load_trained_weights(weights, self.txt2img_pipe) 210 | 211 | self.txt2img_pipe.to("cuda") 212 | 213 | print("Loading SDXL img2img pipeline...") 214 | self.img2img_pipe = StableDiffusionXLImg2ImgPipeline( 215 | vae=self.txt2img_pipe.vae, 216 | text_encoder=self.txt2img_pipe.text_encoder, 217 | text_encoder_2=self.txt2img_pipe.text_encoder_2, 218 | tokenizer=self.txt2img_pipe.tokenizer, 219 | tokenizer_2=self.txt2img_pipe.tokenizer_2, 220 | unet=self.txt2img_pipe.unet, 221 | scheduler=self.txt2img_pipe.scheduler, 222 | ) 223 | self.img2img_pipe.to("cuda") 224 | 225 | print("Loading SDXL inpaint pipeline...") 226 | self.inpaint_pipe = StableDiffusionXLInpaintPipeline( 227 | vae=self.txt2img_pipe.vae, 228 | text_encoder=self.txt2img_pipe.text_encoder, 229 | text_encoder_2=self.txt2img_pipe.text_encoder_2, 230 | tokenizer=self.txt2img_pipe.tokenizer, 231 | tokenizer_2=self.txt2img_pipe.tokenizer_2, 232 | unet=self.txt2img_pipe.unet, 233 | scheduler=self.txt2img_pipe.scheduler, 234 | ) 235 | self.inpaint_pipe.to("cuda") 236 | 237 | print("Loading SDXL refiner pipeline...") 238 | # FIXME(ja): should the vae/text_encoder_2 be loaded from SDXL always? 239 | # - in the case of fine-tuned SDXL should we still? 240 | # FIXME(ja): if the answer to above is use VAE/Text_Encoder_2 from fine-tune 241 | # what does this imply about lora + refiner? 
does the refiner need to know about 242 | 243 | if not os.path.exists(REFINER_MODEL_CACHE): 244 | download_weights(REFINER_URL, REFINER_MODEL_CACHE) 245 | 246 | print("Loading refiner pipeline...") 247 | self.refiner = DiffusionPipeline.from_pretrained( 248 | REFINER_MODEL_CACHE, 249 | text_encoder_2=self.txt2img_pipe.text_encoder_2, 250 | vae=self.txt2img_pipe.vae, 251 | torch_dtype=torch.float16, 252 | use_safetensors=True, 253 | variant="fp16", 254 | ) 255 | self.refiner.to("cuda") 256 | print("setup took: ", time.time() - start) 257 | # self.txt2img_pipe.__class__.encode_prompt = new_encode_prompt 258 | 259 | def load_image(self, path): 260 | shutil.copyfile(path, "/tmp/image.png") 261 | return load_image("/tmp/image.png").convert("RGB") 262 | 263 | def run_safety_checker(self, image): 264 | safety_checker_input = self.feature_extractor(image, return_tensors="pt").to( 265 | "cuda" 266 | ) 267 | np_image = [np.array(val) for val in image] 268 | image, has_nsfw_concept = self.safety_checker( 269 | images=np_image, 270 | clip_input=safety_checker_input.pixel_values.to(torch.float16), 271 | ) 272 | return image, has_nsfw_concept 273 | 274 | @torch.inference_mode() 275 | def predict( 276 | self, 277 | prompt: str = Input( 278 | description="Input prompt", 279 | default="An astronaut riding a rainbow unicorn", 280 | ), 281 | negative_prompt: str = Input( 282 | description="Input Negative Prompt", 283 | default="", 284 | ), 285 | image: Path = Input( 286 | description="Input image for img2img or inpaint mode", 287 | default=None, 288 | ), 289 | mask: Path = Input( 290 | description="Input mask for inpaint mode. Black areas will be preserved, white areas will be inpainted.", 291 | default=None, 292 | ), 293 | width: int = Input( 294 | description="Width of output image", 295 | default=1024, 296 | ), 297 | height: int = Input( 298 | description="Height of output image", 299 | default=1024, 300 | ), 301 | num_outputs: int = Input( 302 | description="Number of images to output.", 303 | ge=1, 304 | le=4, 305 | default=1, 306 | ), 307 | scheduler: str = Input( 308 | description="scheduler", 309 | choices=SCHEDULERS.keys(), 310 | default="K_EULER", 311 | ), 312 | num_inference_steps: int = Input( 313 | description="Number of denoising steps", ge=1, le=500, default=50 314 | ), 315 | guidance_scale: float = Input( 316 | description="Scale for classifier-free guidance", ge=1, le=50, default=7.5 317 | ), 318 | prompt_strength: float = Input( 319 | description="Prompt strength when using img2img / inpaint. 1.0 corresponds to full destruction of information in image", 320 | ge=0.0, 321 | le=1.0, 322 | default=0.8, 323 | ), 324 | seed: int = Input( 325 | description="Random seed. Leave blank to randomize the seed", default=None 326 | ), 327 | refine: str = Input( 328 | description="Which refine style to use", 329 | choices=["no_refiner", "expert_ensemble_refiner", "base_image_refiner"], 330 | default="no_refiner", 331 | ), 332 | high_noise_frac: float = Input( 333 | description="For expert_ensemble_refiner, the fraction of noise to use", 334 | default=0.8, 335 | le=1.0, 336 | ge=0.0, 337 | ), 338 | refine_steps: int = Input( 339 | description="For base_image_refiner, the number of steps to refine, defaults to num_inference_steps", 340 | default=None, 341 | ), 342 | apply_watermark: bool = Input( 343 | description="Applies a watermark to enable determining if an image is generated in downstream applications. 
If you have other provisions for generating or deploying images safely, you can use this to disable watermarking.", 344 | default=True, 345 | ), 346 | lora_scale: float = Input( 347 | description="LoRA additive scale. Only applicable on trained models.", 348 | ge=0.0, 349 | le=1.0, 350 | default=0.6, 351 | ), 352 | replicate_weights: str = Input( 353 | description="Replicate LoRA weights to use. Leave blank to use the default weights.", 354 | default=None, 355 | ), 356 | disable_safety_checker: bool = Input( 357 | description="Disable safety checker for generated images. This feature is only available through the API. See [https://replicate.com/docs/how-does-replicate-work#safety](https://replicate.com/docs/how-does-replicate-work#safety)", 358 | default=False, 359 | ), 360 | ) -> List[Path]: 361 | """Run a single prediction on the model.""" 362 | if seed is None: 363 | seed = int.from_bytes(os.urandom(2), "big") 364 | print(f"Using seed: {seed}") 365 | 366 | if replicate_weights: 367 | self.load_trained_weights(replicate_weights, self.txt2img_pipe) 368 | elif self.tuned_model: 369 | self.unload_trained_weights(self.txt2img_pipe) 370 | 371 | # OOMs can leave vae in bad state 372 | if self.txt2img_pipe.vae.dtype == torch.float32: 373 | self.txt2img_pipe.vae.to(dtype=torch.float16) 374 | 375 | sdxl_kwargs = {} 376 | if self.tuned_model: 377 | # consistency with fine-tuning API 378 | for k, v in self.token_map.items(): 379 | prompt = prompt.replace(k, v) 380 | print(f"Prompt: {prompt}") 381 | if image and mask: 382 | print("inpainting mode") 383 | sdxl_kwargs["image"] = self.load_image(image) 384 | sdxl_kwargs["mask_image"] = self.load_image(mask) 385 | sdxl_kwargs["strength"] = prompt_strength 386 | sdxl_kwargs["width"] = width 387 | sdxl_kwargs["height"] = height 388 | pipe = self.inpaint_pipe 389 | elif image: 390 | print("img2img mode") 391 | sdxl_kwargs["image"] = self.load_image(image) 392 | sdxl_kwargs["strength"] = prompt_strength 393 | pipe = self.img2img_pipe 394 | else: 395 | print("txt2img mode") 396 | sdxl_kwargs["width"] = width 397 | sdxl_kwargs["height"] = height 398 | pipe = self.txt2img_pipe 399 | 400 | if refine == "expert_ensemble_refiner": 401 | sdxl_kwargs["output_type"] = "latent" 402 | sdxl_kwargs["denoising_end"] = high_noise_frac 403 | elif refine == "base_image_refiner": 404 | sdxl_kwargs["output_type"] = "latent" 405 | 406 | if not apply_watermark: 407 | # toggles watermark for this prediction 408 | watermark_cache = pipe.watermark 409 | pipe.watermark = None 410 | self.refiner.watermark = None 411 | 412 | pipe.scheduler = SCHEDULERS[scheduler].from_config(pipe.scheduler.config) 413 | generator = torch.Generator("cuda").manual_seed(seed) 414 | 415 | common_args = { 416 | "prompt": [prompt] * num_outputs, 417 | "negative_prompt": [negative_prompt] * num_outputs, 418 | "guidance_scale": guidance_scale, 419 | "generator": generator, 420 | "num_inference_steps": num_inference_steps, 421 | } 422 | 423 | if self.is_lora: 424 | sdxl_kwargs["cross_attention_kwargs"] = {"scale": lora_scale} 425 | 426 | output = pipe(**common_args, **sdxl_kwargs) 427 | 428 | if refine in ["expert_ensemble_refiner", "base_image_refiner"]: 429 | refiner_kwargs = { 430 | "image": output.images, 431 | } 432 | 433 | if refine == "expert_ensemble_refiner": 434 | refiner_kwargs["denoising_start"] = high_noise_frac 435 | if refine == "base_image_refiner" and refine_steps: 436 | common_args["num_inference_steps"] = refine_steps 437 | 438 | output = self.refiner(**common_args, **refiner_kwargs) 439 | 
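# Refiner hand-off recap: with expert_ensemble_refiner the base pipeline stops at
# denoising_end=high_noise_frac and returns latents, and the refiner resumes from
# denoising_start=high_noise_frac; with base_image_refiner the base pipeline runs its
# full schedule (still emitting latents) and the refiner then runs for refine_steps
# (or num_inference_steps when refine_steps is not set).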
440 | if not apply_watermark: 441 | pipe.watermark = watermark_cache 442 | self.refiner.watermark = watermark_cache 443 | 444 | if not disable_safety_checker: 445 | _, has_nsfw_content = self.run_safety_checker(output.images) 446 | 447 | output_paths = [] 448 | for i, image in enumerate(output.images): 449 | if not disable_safety_checker: 450 | if has_nsfw_content[i]: 451 | print(f"NSFW content detected in image {i}") 452 | continue 453 | output_path = f"/tmp/out-{i}.png" 454 | image.save(output_path) 455 | output_paths.append(Path(output_path)) 456 | 457 | if len(output_paths) == 0: 458 | raise Exception( 459 | f"NSFW content detected. Try running it again, or try a different prompt." 460 | ) 461 | 462 | return output_paths 463 | -------------------------------------------------------------------------------- /preprocess.py: -------------------------------------------------------------------------------- 1 | # Have SwinIR upsample 2 | # Have BLIP auto caption 3 | # Have CLIPSeg auto mask concept 4 | 5 | import gc 6 | import fnmatch 7 | import mimetypes 8 | import os 9 | import re 10 | import shutil 11 | import tarfile 12 | from pathlib import Path 13 | from typing import List, Literal, Optional, Tuple, Union 14 | from zipfile import ZipFile 15 | 16 | import cv2 17 | import mediapipe as mp 18 | import numpy as np 19 | import pandas as pd 20 | import torch 21 | from PIL import Image, ImageFilter 22 | from tqdm import tqdm 23 | from transformers import ( 24 | BlipForConditionalGeneration, 25 | BlipProcessor, 26 | CLIPSegForImageSegmentation, 27 | CLIPSegProcessor, 28 | Swin2SRForImageSuperResolution, 29 | Swin2SRImageProcessor, 30 | ) 31 | 32 | from predict import download_weights 33 | 34 | # model is fixed to Salesforce/blip-image-captioning-large 35 | BLIP_URL = "https://weights.replicate.delivery/default/blip_large/blip_large.tar" 36 | BLIP_PROCESSOR_URL = ( 37 | "https://weights.replicate.delivery/default/blip_processor/blip_processor.tar" 38 | ) 39 | BLIP_PATH = "./blip-cache" 40 | BLIP_PROCESSOR_PATH = "./blip-proc-cache" 41 | 42 | # model is fixed to CIDAS/clipseg-rd64-refined 43 | CLIPSEG_URL = "https://weights.replicate.delivery/default/clip_seg_rd64_refined/clip_seg_rd64_refined.tar" 44 | CLIPSEG_PROCESSOR = "https://weights.replicate.delivery/default/clip_seg_processor/clip_seg_processor.tar" 45 | CLIPSEG_PATH = "./clipseg-cache" 46 | CLIPSEG_PROCESSOR_PATH = "./clipseg-proc-cache" 47 | 48 | # model is fixed to caidas/swin2SR-realworld-sr-x4-64-bsrgan-psnr 49 | SWIN2SR_URL = "https://weights.replicate.delivery/default/swin2sr_realworld_sr_x4_64_bsrgan_psnr/swin2sr_realworld_sr_x4_64_bsrgan_psnr.tar" 50 | SWIN2SR_PATH = "./swin2sr-cache" 51 | 52 | TEMP_OUT_DIR = "./temp/" 53 | TEMP_IN_DIR = "./temp_in/" 54 | 55 | CSV_MATCH = "caption" 56 | 57 | 58 | def preprocess( 59 | input_images_filetype: str, 60 | input_zip_path: Path, 61 | caption_text: str, 62 | mask_target_prompts: str, 63 | target_size: int, 64 | crop_based_on_salience: bool, 65 | use_face_detection_instead: bool, 66 | temp: float, 67 | substitution_tokens: List[str], 68 | ) -> Path: 69 | # assert str(files).endswith(".zip"), "files must be a zip file" 70 | 71 | # clear TEMP_IN_DIR first. 
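# What follows: image files (and an optional caption CSV whose filename contains
# CSV_MATCH, i.e. "caption") are extracted flat into TEMP_IN_DIR from the uploaded
# zip or tar; the cropped images, masks and a captions.csv are then written to
# TEMP_OUT_DIR, which is what this function returns.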
72 | 73 | for path in [TEMP_OUT_DIR, TEMP_IN_DIR]: 74 | if os.path.exists(path): 75 | shutil.rmtree(path) 76 | os.makedirs(path) 77 | 78 | caption_csv = None 79 | 80 | if input_images_filetype == "zip" or str(input_zip_path).endswith(".zip"): 81 | with ZipFile(str(input_zip_path), "r") as zip_ref: 82 | for zip_info in zip_ref.infolist(): 83 | if zip_info.filename[-1] == "/" or zip_info.filename.startswith( 84 | "__MACOSX" 85 | ): 86 | continue 87 | mt = mimetypes.guess_type(zip_info.filename) 88 | if mt and mt[0] and mt[0].startswith("image/"): 89 | zip_info.filename = os.path.basename(zip_info.filename) 90 | zip_ref.extract(zip_info, TEMP_IN_DIR) 91 | if ( 92 | mt 93 | and mt[0] 94 | and mt[0] == "text/csv" 95 | and CSV_MATCH in zip_info.filename 96 | ): 97 | zip_info.filename = os.path.basename(zip_info.filename) 98 | zip_ref.extract(zip_info, TEMP_IN_DIR) 99 | caption_csv = os.path.join(TEMP_IN_DIR, zip_info.filename) 100 | elif input_images_filetype == "tar" or str(input_zip_path).endswith(".tar"): 101 | assert str(input_zip_path).endswith( 102 | ".tar" 103 | ), "files must be a tar file if not zip" 104 | with tarfile.open(input_zip_path, "r") as tar_ref: 105 | for tar_info in tar_ref: 106 | if tar_info.name[-1] == "/" or tar_info.name.startswith("__MACOSX"): 107 | continue 108 | 109 | mt = mimetypes.guess_type(tar_info.name) 110 | if mt and mt[0] and mt[0].startswith("image/"): 111 | tar_info.name = os.path.basename(tar_info.name) 112 | tar_ref.extract(tar_info, TEMP_IN_DIR) 113 | if mt and mt[0] and mt[0] == "text/csv" and CSV_MATCH in tar_info.name: 114 | tar_info.name = os.path.basename(tar_info.name) 115 | tar_ref.extract(tar_info, TEMP_IN_DIR) 116 | caption_csv = os.path.join(TEMP_IN_DIR, tar_info.name) 117 | else: 118 | assert False, "input_images_filetype must be zip or tar" 119 | 120 | output_dir: str = TEMP_OUT_DIR 121 | 122 | load_and_save_masks_and_captions( 123 | files=TEMP_IN_DIR, 124 | output_dir=output_dir, 125 | caption_text=caption_text, 126 | caption_csv=caption_csv, 127 | mask_target_prompts=mask_target_prompts, 128 | target_size=target_size, 129 | crop_based_on_salience=crop_based_on_salience, 130 | use_face_detection_instead=use_face_detection_instead, 131 | temp=temp, 132 | substitution_tokens=substitution_tokens, 133 | ) 134 | 135 | return Path(TEMP_OUT_DIR) 136 | 137 | 138 | @torch.no_grad() 139 | @torch.cuda.amp.autocast() 140 | def swin_ir_sr( 141 | images: List[Image.Image], 142 | target_size: Optional[Tuple[int, int]] = None, 143 | device=torch.device("cuda:0" if torch.cuda.is_available() else "cpu"), 144 | **kwargs, 145 | ) -> List[Image.Image]: 146 | """ 147 | Upscales images using SwinIR. Returns a list of PIL images. 148 | If the image is already larger than the target size, it will not be upscaled 149 | and will be returned as is. 
150 | 151 | """ 152 | if not os.path.exists(SWIN2SR_PATH): 153 | download_weights(SWIN2SR_URL, SWIN2SR_PATH) 154 | model = Swin2SRForImageSuperResolution.from_pretrained(SWIN2SR_PATH).to(device) 155 | processor = Swin2SRImageProcessor() 156 | 157 | out_images = [] 158 | 159 | for image in tqdm(images): 160 | ori_w, ori_h = image.size 161 | if target_size is not None: 162 | if ori_w >= target_size[0] and ori_h >= target_size[1]: 163 | out_images.append(image) 164 | continue 165 | 166 | inputs = processor(image, return_tensors="pt").to(device) 167 | with torch.no_grad(): 168 | outputs = model(**inputs) 169 | 170 | output = ( 171 | outputs.reconstruction.data.squeeze().float().cpu().clamp_(0, 1).numpy() 172 | ) 173 | output = np.moveaxis(output, source=0, destination=-1) 174 | output = (output * 255.0).round().astype(np.uint8) 175 | output = Image.fromarray(output) 176 | 177 | out_images.append(output) 178 | 179 | return out_images 180 | 181 | 182 | @torch.no_grad() 183 | @torch.cuda.amp.autocast() 184 | def clipseg_mask_generator( 185 | images: List[Image.Image], 186 | target_prompts: Union[List[str], str], 187 | device=torch.device("cuda:0" if torch.cuda.is_available() else "cpu"), 188 | bias: float = 0.01, 189 | temp: float = 1.0, 190 | **kwargs, 191 | ) -> List[Image.Image]: 192 | """ 193 | Returns a greyscale mask for each image, where the mask is the probability of the target prompt being present in the image 194 | """ 195 | 196 | if isinstance(target_prompts, str): 197 | print( 198 | f'Warning: only one target prompt "{target_prompts}" was given, so it will be used for all images' 199 | ) 200 | 201 | target_prompts = [target_prompts] * len(images) 202 | if not os.path.exists(CLIPSEG_PROCESSOR_PATH): 203 | download_weights(CLIPSEG_PROCESSOR, CLIPSEG_PROCESSOR_PATH) 204 | if not os.path.exists(CLIPSEG_PATH): 205 | download_weights(CLIPSEG_URL, CLIPSEG_PATH) 206 | processor = CLIPSegProcessor.from_pretrained(CLIPSEG_PROCESSOR_PATH) 207 | model = CLIPSegForImageSegmentation.from_pretrained(CLIPSEG_PATH).to(device) 208 | 209 | masks = [] 210 | 211 | for image, prompt in tqdm(zip(images, target_prompts)): 212 | original_size = image.size 213 | 214 | inputs = processor( 215 | text=[prompt, ""], 216 | images=[image] * 2, 217 | padding="max_length", 218 | truncation=True, 219 | return_tensors="pt", 220 | ).to(device) 221 | 222 | outputs = model(**inputs) 223 | 224 | logits = outputs.logits 225 | probs = torch.nn.functional.softmax(logits / temp, dim=0)[0] 226 | probs = (probs + bias).clamp_(0, 1) 227 | probs = 255 * probs / probs.max() 228 | 229 | # make mask greyscale 230 | mask = Image.fromarray(probs.cpu().numpy()).convert("L") 231 | 232 | # resize mask to original size 233 | mask = mask.resize(original_size) 234 | 235 | masks.append(mask) 236 | 237 | return masks 238 | 239 | 240 | @torch.no_grad() 241 | def blip_captioning_dataset( 242 | images: List[Image.Image], 243 | text: Optional[str] = None, 244 | device=torch.device("cuda" if torch.cuda.is_available() else "cpu"), 245 | substitution_tokens: Optional[List[str]] = None, 246 | **kwargs, 247 | ) -> List[str]: 248 | """ 249 | Returns a list of captions for the given images 250 | """ 251 | if not os.path.exists(BLIP_PROCESSOR_PATH): 252 | download_weights(BLIP_PROCESSOR_URL, BLIP_PROCESSOR_PATH) 253 | if not os.path.exists(BLIP_PATH): 254 | download_weights(BLIP_URL, BLIP_PATH) 255 | processor = BlipProcessor.from_pretrained(BLIP_PROCESSOR_PATH) 256 | model = BlipForConditionalGeneration.from_pretrained(BLIP_PATH).to(device) 257 | captions 
= [] 258 | text = text.strip() 259 | print(f"Input captioning text: {text}") 260 | for image in tqdm(images): 261 | inputs = processor(image, return_tensors="pt").to("cuda") 262 | out = model.generate( 263 | **inputs, max_length=150, do_sample=True, top_k=50, temperature=0.7 264 | ) 265 | caption = processor.decode(out[0], skip_special_tokens=True) 266 | 267 | # BLIP 2 lowercases all caps tokens. This should properly replace them w/o messing up subwords. I'm sure there's a better way to do this. 268 | for token in substitution_tokens: 269 | print(token) 270 | sub_cap = " " + caption + " " 271 | print(sub_cap) 272 | sub_cap = sub_cap.replace(" " + token.lower() + " ", " " + token + " ") 273 | caption = sub_cap.strip() 274 | 275 | captions.append(text + " " + caption) 276 | print("Generated captions", captions) 277 | return captions 278 | 279 | 280 | def face_mask_google_mediapipe( 281 | images: List[Image.Image], blur_amount: float = 0.0, bias: float = 50.0 282 | ) -> List[Image.Image]: 283 | """ 284 | Returns a list of images with masks on the face parts. 285 | """ 286 | mp_face_detection = mp.solutions.face_detection 287 | mp_face_mesh = mp.solutions.face_mesh 288 | 289 | face_detection = mp_face_detection.FaceDetection( 290 | model_selection=1, min_detection_confidence=0.1 291 | ) 292 | face_mesh = mp_face_mesh.FaceMesh( 293 | static_image_mode=True, max_num_faces=1, min_detection_confidence=0.1 294 | ) 295 | 296 | masks = [] 297 | for image in tqdm(images): 298 | image_np = np.array(image) 299 | 300 | # Perform face detection 301 | results_detection = face_detection.process(image_np) 302 | ih, iw, _ = image_np.shape 303 | if results_detection.detections: 304 | for detection in results_detection.detections: 305 | bboxC = detection.location_data.relative_bounding_box 306 | 307 | bbox = ( 308 | int(bboxC.xmin * iw), 309 | int(bboxC.ymin * ih), 310 | int(bboxC.width * iw), 311 | int(bboxC.height * ih), 312 | ) 313 | 314 | # make sure bbox is within image 315 | bbox = ( 316 | max(0, bbox[0]), 317 | max(0, bbox[1]), 318 | min(iw - bbox[0], bbox[2]), 319 | min(ih - bbox[1], bbox[3]), 320 | ) 321 | 322 | print(bbox) 323 | 324 | # Extract face landmarks 325 | face_landmarks = face_mesh.process( 326 | image_np[bbox[1] : bbox[1] + bbox[3], bbox[0] : bbox[0] + bbox[2]] 327 | ).multi_face_landmarks 328 | 329 | # https://github.com/google/mediapipe/issues/1615 330 | # This was def helpful 331 | indexes = [ 332 | 10, 333 | 338, 334 | 297, 335 | 332, 336 | 284, 337 | 251, 338 | 389, 339 | 356, 340 | 454, 341 | 323, 342 | 361, 343 | 288, 344 | 397, 345 | 365, 346 | 379, 347 | 378, 348 | 400, 349 | 377, 350 | 152, 351 | 148, 352 | 176, 353 | 149, 354 | 150, 355 | 136, 356 | 172, 357 | 58, 358 | 132, 359 | 93, 360 | 234, 361 | 127, 362 | 162, 363 | 21, 364 | 54, 365 | 103, 366 | 67, 367 | 109, 368 | ] 369 | 370 | if face_landmarks: 371 | mask = Image.new("L", (iw, ih), 0) 372 | mask_np = np.array(mask) 373 | 374 | for face_landmark in face_landmarks: 375 | face_landmark = [face_landmark.landmark[idx] for idx in indexes] 376 | landmark_points = [ 377 | (int(l.x * bbox[2]) + bbox[0], int(l.y * bbox[3]) + bbox[1]) 378 | for l in face_landmark 379 | ] 380 | mask_np = cv2.fillPoly( 381 | mask_np, [np.array(landmark_points)], 255 382 | ) 383 | 384 | mask = Image.fromarray(mask_np) 385 | 386 | # Apply blur to the mask 387 | if blur_amount > 0: 388 | mask = mask.filter(ImageFilter.GaussianBlur(blur_amount)) 389 | 390 | # Apply bias to the mask 391 | if bias > 0: 392 | mask = np.array(mask) 393 | mask = mask + 
bias * np.ones(mask.shape, dtype=mask.dtype) 394 | mask = np.clip(mask, 0, 255) 395 | mask = Image.fromarray(mask) 396 | 397 | # Convert mask to 'L' mode (grayscale) before saving 398 | mask = mask.convert("L") 399 | 400 | masks.append(mask) 401 | else: 402 | # If face landmarks are not available, add a black mask of the same size as the image 403 | masks.append(Image.new("L", (iw, ih), 255)) 404 | 405 | else: 406 | print("No face detected, adding full mask") 407 | # If no face is detected, add a white mask of the same size as the image 408 | masks.append(Image.new("L", (iw, ih), 255)) 409 | 410 | return masks 411 | 412 | 413 | def _crop_to_square( 414 | image: Image.Image, com: List[Tuple[int, int]], resize_to: Optional[int] = None 415 | ): 416 | cx, cy = com 417 | width, height = image.size 418 | if width > height: 419 | left_possible = max(cx - height / 2, 0) 420 | left = min(left_possible, width - height) 421 | right = left + height 422 | top = 0 423 | bottom = height 424 | else: 425 | left = 0 426 | right = width 427 | top_possible = max(cy - width / 2, 0) 428 | top = min(top_possible, height - width) 429 | bottom = top + width 430 | 431 | image = image.crop((left, top, right, bottom)) 432 | 433 | if resize_to: 434 | image = image.resize((resize_to, resize_to), Image.Resampling.LANCZOS) 435 | 436 | return image 437 | 438 | 439 | def _center_of_mass(mask: Image.Image): 440 | """ 441 | Returns the center of mass of the mask 442 | """ 443 | x, y = np.meshgrid(np.arange(mask.size[0]), np.arange(mask.size[1])) 444 | mask_np = np.array(mask) + 0.01 445 | x_ = x * mask_np 446 | y_ = y * mask_np 447 | 448 | x = np.sum(x_) / np.sum(mask_np) 449 | y = np.sum(y_) / np.sum(mask_np) 450 | 451 | return x, y 452 | 453 | 454 | def load_and_save_masks_and_captions( 455 | files: Union[str, List[str]], 456 | output_dir: str = TEMP_OUT_DIR, 457 | caption_text: Optional[str] = None, 458 | caption_csv: Optional[str] = None, 459 | mask_target_prompts: Optional[Union[List[str], str]] = None, 460 | target_size: int = 1024, 461 | crop_based_on_salience: bool = True, 462 | use_face_detection_instead: bool = False, 463 | temp: float = 1.0, 464 | n_length: int = -1, 465 | substitution_tokens: Optional[List[str]] = None, 466 | ): 467 | """ 468 | Loads images from the given files, generates masks for them, and saves the masks and captions and upscale images 469 | to output dir. If mask_target_prompts is given, it will generate kinda-segmentation-masks for the prompts and save them as well. 470 | 471 | Example: 472 | >>> x = load_and_save_masks_and_captions( 473 | files="./data/images", 474 | output_dir="./data/masks_and_captions", 475 | caption_text="a photo of", 476 | mask_target_prompts="cat", 477 | target_size=768, 478 | crop_based_on_salience=True, 479 | use_face_detection_instead=False, 480 | temp=1.0, 481 | n_length=-1, 482 | ) 483 | """ 484 | os.makedirs(output_dir, exist_ok=True) 485 | 486 | # load images 487 | if isinstance(files, str): 488 | # check if it is a directory 489 | if os.path.isdir(files): 490 | # get all the .png .jpg in the directory 491 | files = ( 492 | _find_files("*.png", files) 493 | + _find_files("*.jpg", files) 494 | + _find_files("*.jpeg", files) 495 | ) 496 | 497 | if len(files) == 0: 498 | raise Exception( 499 | f"No files found in {files}. Either {files} is not a directory or it does not contain any .png or .jpg/jpeg files." 
500 | ) 501 | if n_length == -1: 502 | n_length = len(files) 503 | files = sorted(files)[:n_length] 504 | print("Image files: ", files) 505 | images = [Image.open(file).convert("RGB") for file in files] 506 | 507 | # captions 508 | if caption_csv: 509 | print(f"Using provided captions") 510 | caption_df = pd.read_csv(caption_csv) 511 | # sort images to be consistent with 'sorted' above 512 | caption_df = caption_df.sort_values("image_file") 513 | captions = caption_df["caption"].values 514 | print("Captions: ", captions) 515 | if len(captions) != len(images): 516 | print("Not the same number of captions as images!") 517 | print(f"Num captions: {len(captions)}, Num images: {len(images)}") 518 | print("Captions: ", captions) 519 | print("Images: ", files) 520 | raise Exception( 521 | "Not the same number of captions as images! Check that all files passed in have a caption in your caption csv, and vice versa" 522 | ) 523 | 524 | else: 525 | print(f"Generating {len(images)} captions...") 526 | captions = blip_captioning_dataset( 527 | images, text=caption_text, substitution_tokens=substitution_tokens 528 | ) 529 | 530 | if mask_target_prompts is None: 531 | mask_target_prompts = "" 532 | temp = 999 533 | 534 | print(f"Generating {len(images)} masks...") 535 | if not use_face_detection_instead: 536 | seg_masks = clipseg_mask_generator( 537 | images=images, target_prompts=mask_target_prompts, temp=temp 538 | ) 539 | else: 540 | seg_masks = face_mask_google_mediapipe(images=images) 541 | 542 | # find the center of mass of the mask 543 | if crop_based_on_salience: 544 | coms = [_center_of_mass(mask) for mask in seg_masks] 545 | else: 546 | coms = [(image.size[0] / 2, image.size[1] / 2) for image in images] 547 | # based on the center of mass, crop the image to a square 548 | images = [ 549 | _crop_to_square(image, com, resize_to=None) for image, com in zip(images, coms) 550 | ] 551 | 552 | print(f"Upscaling {len(images)} images...") 553 | # upscale images anyways 554 | images = swin_ir_sr(images, target_size=(target_size, target_size)) 555 | images = [ 556 | image.resize((target_size, target_size), Image.Resampling.LANCZOS) 557 | for image in images 558 | ] 559 | 560 | seg_masks = [ 561 | _crop_to_square(mask, com, resize_to=target_size) 562 | for mask, com in zip(seg_masks, coms) 563 | ] 564 | 565 | data = [] 566 | 567 | # clean TEMP_OUT_DIR first 568 | if os.path.exists(output_dir): 569 | for file in os.listdir(output_dir): 570 | os.remove(os.path.join(output_dir, file)) 571 | 572 | os.makedirs(output_dir, exist_ok=True) 573 | 574 | # iterate through the images, masks, and captions and add a row to the dataframe for each 575 | for idx, (image, mask, caption) in enumerate(zip(images, seg_masks, captions)): 576 | image_name = f"{idx}.src.png" 577 | mask_file = f"{idx}.mask.png" 578 | 579 | # save the image and mask files 580 | image.save(output_dir + image_name) 581 | mask.save(output_dir + mask_file) 582 | 583 | # add a new row to the dataframe with the file names and caption 584 | data.append( 585 | {"image_path": image_name, "mask_path": mask_file, "caption": caption}, 586 | ) 587 | 588 | df = pd.DataFrame(columns=["image_path", "mask_path", "caption"], data=data) 589 | # save the dataframe to a CSV file 590 | df.to_csv(os.path.join(output_dir, "captions.csv"), index=False) 591 | 592 | 593 | def _find_files(pattern, dir="."): 594 | """Return list of files matching pattern in a given directory, in absolute format. 595 | Unlike glob, this is case-insensitive. 
596 | """ 597 | 598 | rule = re.compile(fnmatch.translate(pattern), re.IGNORECASE) 599 | return [os.path.join(dir, f) for f in os.listdir(dir) if rule.match(f)] 600 | -------------------------------------------------------------------------------- /requirements_test.txt: -------------------------------------------------------------------------------- 1 | numpy 2 | pytest 3 | replicate 4 | requests 5 | Pillow -------------------------------------------------------------------------------- /samples.py: -------------------------------------------------------------------------------- 1 | """ 2 | A handy utility for verifying SDXL image generation locally. 3 | To set up, first run a local cog server using: 4 | cog run -p 5000 python -m cog.server.http 5 | Then, in a separate terminal, generate samples 6 | python samples.py 7 | """ 8 | 9 | 10 | import base64 11 | import os 12 | import sys 13 | 14 | import requests 15 | 16 | 17 | def gen(output_fn, **kwargs): 18 | if os.path.exists(output_fn): 19 | return 20 | 21 | print("Generating", output_fn) 22 | url = "http://localhost:5000/predictions" 23 | response = requests.post(url, json={"input": kwargs}) 24 | data = response.json() 25 | 26 | try: 27 | datauri = data["output"][0] 28 | base64_encoded_data = datauri.split(",")[1] 29 | data = base64.b64decode(base64_encoded_data) 30 | except: 31 | print("Error!") 32 | print("input:", kwargs) 33 | print(data["logs"]) 34 | sys.exit(1) 35 | 36 | with open(output_fn, "wb") as f: 37 | f.write(data) 38 | 39 | 40 | def main(): 41 | SCHEDULERS = [ 42 | "DDIM", 43 | "DPMSolverMultistep", 44 | "HeunDiscrete", 45 | "KarrasDPM", 46 | "K_EULER_ANCESTRAL", 47 | "K_EULER", 48 | "PNDM", 49 | ] 50 | 51 | gen( 52 | f"sample.txt2img.png", 53 | prompt="A studio portrait photo of a cat", 54 | num_inference_steps=25, 55 | guidance_scale=7, 56 | negative_prompt="ugly, soft, blurry, out of focus, low quality, garish, distorted, disfigured", 57 | seed=1000, 58 | width=1024, 59 | height=1024, 60 | ) 61 | 62 | for refiner in ["base_image_refiner", "expert_ensemble_refiner", "no_refiner"]: 63 | gen( 64 | f"sample.img2img.{refiner}.png", 65 | prompt="a photo of an astronaut riding a horse on mars", 66 | image="https://huggingface.co/datasets/patrickvonplaten/images/resolve/main/aa_xl/000000009.png", 67 | prompt_strength=0.8, 68 | num_inference_steps=25, 69 | refine=refiner, 70 | guidance_scale=7, 71 | negative_prompt="ugly, soft, blurry, out of focus, low quality, garish, distorted, disfigured", 72 | seed=42, 73 | ) 74 | 75 | gen( 76 | f"sample.inpaint.{refiner}.png", 77 | prompt="A majestic tiger sitting on a bench", 78 | image="https://raw.githubusercontent.com/CompVis/latent-diffusion/main/data/inpainting_examples/overture-creations-5sI6fQgYIuo.png", 79 | mask="https://raw.githubusercontent.com/CompVis/latent-diffusion/main/data/inpainting_examples/overture-creations-5sI6fQgYIuo_mask.png", 80 | prompt_strength=0.8, 81 | num_inference_steps=25, 82 | refine=refiner, 83 | guidance_scale=7, 84 | negative_prompt="ugly, soft, blurry, out of focus, low quality, garish, distorted, disfigured", 85 | seed=42, 86 | ) 87 | 88 | for split in range(0, 10): 89 | split = split / 10.0 90 | gen( 91 | f"sample.expert_ensemble_refiner.{split}.txt2img.png", 92 | prompt="A studio portrait photo of a cat", 93 | num_inference_steps=25, 94 | guidance_scale=7, 95 | refine="expert_ensemble_refiner", 96 | high_noise_frac=split, 97 | negative_prompt="ugly, soft, blurry, out of focus, low quality, garish, distorted, disfigured", 98 | seed=1000, 99 | 
width=1024, 100 | height=1024, 101 | ) 102 | 103 | gen( 104 | f"sample.refine.txt2img.png", 105 | prompt="A studio portrait photo of a cat", 106 | num_inference_steps=25, 107 | guidance_scale=7, 108 | refine="base_image_refiner", 109 | negative_prompt="ugly, soft, blurry, out of focus, low quality, garish, distorted, disfigured", 110 | seed=1000, 111 | width=1024, 112 | height=1024, 113 | ) 114 | gen( 115 | f"sample.refine.10.txt2img.png", 116 | prompt="A studio portrait photo of a cat", 117 | num_inference_steps=25, 118 | guidance_scale=7, 119 | refine="base_image_refiner", 120 | refine_steps=10, 121 | negative_prompt="ugly, soft, blurry, out of focus, low quality, garish, distorted, disfigured", 122 | seed=1000, 123 | width=1024, 124 | height=1024, 125 | ) 126 | 127 | gen( 128 | "samples.2.txt2img.png", 129 | prompt="A studio portrait photo of a cat", 130 | num_inference_steps=25, 131 | guidance_scale=7, 132 | negative_prompt="ugly, soft, blurry, out of focus, low quality, garish, distorted, disfigured", 133 | scheduler="KarrasDPM", 134 | num_outputs=2, 135 | seed=1000, 136 | width=1024, 137 | height=1024, 138 | ) 139 | 140 | for s in SCHEDULERS: 141 | gen( 142 | f"sample.{s}.txt2img.png", 143 | prompt="A studio portrait photo of a cat", 144 | num_inference_steps=25, 145 | guidance_scale=7, 146 | negative_prompt="ugly, soft, blurry, out of focus, low quality, garish, distorted, disfigured", 147 | scheduler=s, 148 | seed=1000, 149 | width=1024, 150 | height=1024, 151 | ) 152 | 153 | 154 | if __name__ == "__main__": 155 | main() 156 | -------------------------------------------------------------------------------- /script/download_preprocessing_weights.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | import shutil 4 | 5 | from transformers import ( 6 | BlipForConditionalGeneration, 7 | BlipProcessor, 8 | CLIPSegForImageSegmentation, 9 | CLIPSegProcessor, 10 | Swin2SRForImageSuperResolution, 11 | ) 12 | 13 | DEFAULT_BLIP = "Salesforce/blip-image-captioning-large" 14 | DEFAULT_CLIPSEG = "CIDAS/clipseg-rd64-refined" 15 | DEFAULT_SWINIR = "caidas/swin2SR-realworld-sr-x4-64-bsrgan-psnr" 16 | 17 | 18 | def upload(args): 19 | blip_processor = BlipProcessor.from_pretrained(DEFAULT_BLIP) 20 | blip_model = BlipForConditionalGeneration.from_pretrained(DEFAULT_BLIP) 21 | 22 | clip_processor = CLIPSegProcessor.from_pretrained(DEFAULT_CLIPSEG) 23 | clip_model = CLIPSegForImageSegmentation.from_pretrained(DEFAULT_CLIPSEG) 24 | 25 | swin_model = Swin2SRForImageSuperResolution.from_pretrained(DEFAULT_SWINIR) 26 | 27 | temp_models = "tmp/models" 28 | if os.path.exists(temp_models): 29 | shutil.rmtree(temp_models) 30 | os.makedirs(temp_models) 31 | 32 | blip_processor.save_pretrained(os.path.join(temp_models, "blip_processor")) 33 | blip_model.save_pretrained(os.path.join(temp_models, "blip_large")) 34 | clip_processor.save_pretrained(os.path.join(temp_models, "clip_seg_processor")) 35 | clip_model.save_pretrained(os.path.join(temp_models, "clip_seg_rd64_refined")) 36 | swin_model.save_pretrained( 37 | os.path.join(temp_models, "swin2sr_realworld_sr_x4_64_bsrgan_psnr") 38 | ) 39 | 40 | for val in os.listdir(temp_models): 41 | if "tar" not in val: 42 | os.system( 43 | f"sudo tar -cvf {os.path.join(temp_models, val)}.tar -C {os.path.join(temp_models, val)} ." 
44 | ) 45 | os.system( 46 | f"gcloud storage cp -R {os.path.join(temp_models, val)}.tar gs://{args.bucket}/{val}/" 47 | ) 48 | 49 | 50 | if __name__ == "__main__": 51 | parser = argparse.ArgumentParser() 52 | parser.add_argument("--bucket", "-m", type=str) 53 | args = parser.parse_args() 54 | upload(args) 55 | -------------------------------------------------------------------------------- /script/download_weights.py: -------------------------------------------------------------------------------- 1 | # Run this before you deploy it on replicate, because if you don't 2 | # whenever you run the model, it will download the weights from the 3 | # internet, which will take a long time. 4 | 5 | import torch 6 | from diffusers import AutoencoderKL, DiffusionPipeline 7 | from diffusers.pipelines.stable_diffusion.safety_checker import ( 8 | StableDiffusionSafetyChecker, 9 | ) 10 | 11 | # pipe = DiffusionPipeline.from_pretrained( 12 | # "stabilityai/stable-diffusion-xl-base-1.0", 13 | # torch_dtype=torch.float16, 14 | # use_safetensors=True, 15 | # variant="fp16", 16 | # ) 17 | 18 | # pipe.save_pretrained("./cache", safe_serialization=True) 19 | 20 | better_vae = AutoencoderKL.from_pretrained( 21 | "madebyollin/sdxl-vae-fp16-fix", torch_dtype=torch.float16 22 | ) 23 | 24 | pipe = DiffusionPipeline.from_pretrained( 25 | "stabilityai/stable-diffusion-xl-base-1.0", 26 | vae=better_vae, 27 | torch_dtype=torch.float16, 28 | use_safetensors=True, 29 | variant="fp16", 30 | ) 31 | 32 | pipe.save_pretrained("./sdxl-cache", safe_serialization=True) 33 | 34 | pipe = DiffusionPipeline.from_pretrained( 35 | "stabilityai/stable-diffusion-xl-refiner-1.0", 36 | torch_dtype=torch.float16, 37 | use_safetensors=True, 38 | variant="fp16", 39 | ) 40 | 41 | # TODO - we don't need to save all of this and in fact should save just the unet, tokenizer, and config. 
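# A hypothetical, untested sketch of the slimmer save suggested by the TODO above
# (predict.py reuses the base pipeline's text_encoder_2 and VAE when it loads the
# refiner, so only the refiner-specific pieces would need to be cached):
#
#     pipe.unet.save_pretrained("./refiner-cache/unet")
#     pipe.tokenizer_2.save_pretrained("./refiner-cache/tokenizer_2")
#     pipe.scheduler.save_pretrained("./refiner-cache/scheduler")
#     pipe.save_config("./refiner-cache")  # writes model_index.json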
42 | pipe.save_pretrained("./refiner-cache", safe_serialization=True) 43 | 44 | 45 | safety = StableDiffusionSafetyChecker.from_pretrained( 46 | "CompVis/stable-diffusion-safety-checker", 47 | torch_dtype=torch.float16, 48 | ) 49 | 50 | safety.save_pretrained("./safety-cache") 51 | -------------------------------------------------------------------------------- /tests/assets/out.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/replicate/cog-sdxl/c79cb3c9c3c2968ee53ada00c7fb02aa3b9fc58c/tests/assets/out.png -------------------------------------------------------------------------------- /tests/test_predict.py: -------------------------------------------------------------------------------- 1 | import base64 2 | import os 3 | import pickle 4 | import subprocess 5 | import sys 6 | import time 7 | from functools import partial 8 | from io import BytesIO 9 | 10 | import numpy as np 11 | import pytest 12 | import replicate 13 | import requests 14 | from PIL import Image, ImageChops 15 | 16 | ENV = os.getenv('TEST_ENV', 'local') 17 | LOCAL_ENDPOINT = "http://localhost:5000/predictions" 18 | MODEL = os.getenv('STAGING_MODEL', 'no model configured') 19 | 20 | def local_run(model_endpoint: str, model_input: dict): 21 | response = requests.post(model_endpoint, json={"input": model_input}) 22 | data = response.json() 23 | 24 | try: 25 | # TODO: this will break if we test batching 26 | datauri = data["output"][0] 27 | base64_encoded_data = datauri.split(",")[1] 28 | data = base64.b64decode(base64_encoded_data) 29 | return Image.open(BytesIO(data)) 30 | except Exception as e: 31 | print("Error!") 32 | print("input:", model_input) 33 | print(data["logs"]) 34 | raise e 35 | 36 | 37 | def replicate_run(model: str, version: str, model_input: dict): 38 | output = replicate.run( 39 | f"{model}:{version}", 40 | input=model_input) 41 | url = output[0] 42 | 43 | response = requests.get(url) 44 | return Image.open(BytesIO(response.content)) 45 | 46 | 47 | def wait_for_server_to_be_ready(url, timeout=300): 48 | """ 49 | Waits for the server to be ready. 50 | 51 | Args: 52 | - url: The health check URL to poll. 53 | - timeout: Maximum time (in seconds) to wait for the server to be ready. 54 | """ 55 | start_time = time.time() 56 | while True: 57 | try: 58 | response = requests.get(url) 59 | data = response.json() 60 | 61 | if data["status"] == "READY": 62 | return 63 | elif data["status"] == "SETUP_FAILED": 64 | raise RuntimeError( 65 | "Server initialization failed with status: SETUP_FAILED" 66 | ) 67 | 68 | except requests.RequestException: 69 | pass 70 | 71 | if time.time() - start_time > timeout: 72 | raise TimeoutError("Server did not become ready in the expected time.") 73 | 74 | time.sleep(5) # Poll every 5 seconds 75 | 76 | 77 | @pytest.fixture(scope="session") 78 | def inference_func(): 79 | """ 80 | local inference uses http API to hit local server; staging inference uses python API b/c it's cleaner. 
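Set TEST_ENV=staging (and STAGING_MODEL to an owner/model string) to run against a
deployed Replicate model instead of the local cog server.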
81 | """ 82 | if ENV == 'local': 83 | return partial(local_run, LOCAL_ENDPOINT) 84 | elif ENV == 'staging': 85 | model = replicate.models.get(MODEL) 86 | print(f"model,", model) 87 | version = model.versions.list()[0] 88 | return partial(replicate_run, MODEL, version.id) 89 | else: 90 | raise Exception(f"env should be local or staging but was {ENV}") 91 | 92 | 93 | @pytest.fixture(scope="session", autouse=True) 94 | def service(): 95 | """ 96 | Spins up local cog server to hit for tests if running locally, no-op otherwise 97 | """ 98 | if ENV == 'local': 99 | print("building model") 100 | # starts local server if we're running things locally 101 | build_command = 'cog build -t test-model'.split() 102 | subprocess.run(build_command, check=True) 103 | container_name = 'cog-test' 104 | try: 105 | subprocess.check_output(['docker', 'inspect', '--format="{{.State.Running}}"', container_name]) 106 | print(f"Container '{container_name}' is running. Stopping and removing...") 107 | subprocess.check_call(['docker', 'stop', container_name]) 108 | subprocess.check_call(['docker', 'rm', container_name]) 109 | print(f"Container '{container_name}' stopped and removed.") 110 | except subprocess.CalledProcessError: 111 | # Container not found 112 | print(f"Container '{container_name}' not found or not running.") 113 | 114 | run_command = f'docker run -d -p 5000:5000 --gpus all --name {container_name} test-model '.split() 115 | process = subprocess.Popen(run_command, stdout=sys.stdout, stderr=sys.stderr) 116 | 117 | wait_for_server_to_be_ready("http://localhost:5000/health-check") 118 | 119 | yield 120 | process.terminate() 121 | process.wait() 122 | stop_command = "docker stop cog-test".split() 123 | subprocess.run(stop_command) 124 | else: 125 | yield 126 | 127 | 128 | def image_equal_fuzzy(img_expected, img_actual, test_name='default', tol=20): 129 | """ 130 | Assert that average pixel values differ by less than tol across image 131 | Tol determined empirically - holding everything else equal but varying seed 132 | generates images that vary by at least 50 133 | """ 134 | img1 = np.array(img_expected, dtype=np.int32) 135 | img2 = np.array(img_actual, dtype=np.int32) 136 | 137 | mean_delta = np.mean(np.abs(img1 - img2)) 138 | imgs_equal = (mean_delta < tol) 139 | if not imgs_equal: 140 | # save failures for quick inspection 141 | save_dir = f"tmp/{test_name}" 142 | if not os.path.exists(save_dir): 143 | os.makedirs(save_dir) 144 | img_expected.save(os.path.join(save_dir, 'expected.png')) 145 | img_actual.save(os.path.join(save_dir, 'actual.png')) 146 | difference = ImageChops.difference(img_expected, img_actual) 147 | difference.save(os.path.join(save_dir, 'delta.png')) 148 | 149 | return imgs_equal 150 | 151 | 152 | def test_seeded_prediction(inference_func, request): 153 | """ 154 | SDXL w/seed should be deterministic. may need to adjust tolerance for optimized SDXLs 155 | """ 156 | data = { 157 | "prompt": "An astronaut riding a rainbow unicorn, cinematic, dramatic", 158 | "num_inference_steps": 50, 159 | "width": 1024, 160 | "height": 1024, 161 | "scheduler": "DDIM", 162 | "refine": "expert_ensemble_refiner", 163 | "seed": 12103, 164 | } 165 | actual_image = inference_func(data) 166 | expected_image = Image.open("tests/assets/out.png") 167 | assert image_equal_fuzzy(actual_image, expected_image, test_name=request.node.name) 168 | 169 | 170 | def test_lora_load_unload(inference_func, request): 171 | """ 172 | Tests generation with & without loras. 
173 | This is checking for some gnarly state issues (can SDXL load / unload LoRAs), so predictions need to run in series. 174 | """ 175 | SEED = 1234 176 | base_data = { 177 | "prompt": "A photo of a dog on the beach", 178 | "num_inference_steps": 50, 179 | # Add other parameters here 180 | "seed": SEED, 181 | } 182 | base_img_1 = inference_func(base_data) 183 | 184 | lora_a_data = { 185 | "prompt": "A photo of a TOK on the beach", 186 | "num_inference_steps": 50, 187 | # Add other parameters here 188 | "replicate_weights": "https://storage.googleapis.com/dan-scratch-public/sdxl/other_model.tar", 189 | "seed": SEED 190 | } 191 | lora_a_img_1 = inference_func(lora_a_data) 192 | assert not image_equal_fuzzy(lora_a_img_1, base_img_1, test_name=request.node.name) 193 | 194 | lora_a_img_2 = inference_func(lora_a_data) 195 | assert image_equal_fuzzy(lora_a_img_1, lora_a_img_2, test_name=request.node.name) 196 | 197 | lora_b_data = { 198 | "prompt": "A photo of a TOK on the beach", 199 | "num_inference_steps": 50, 200 | "replicate_weights": "https://storage.googleapis.com/dan-scratch-public/sdxl/monstertoy_model.tar", 201 | "seed": SEED, 202 | } 203 | lora_b_img = inference_func(lora_b_data) 204 | assert not image_equal_fuzzy(lora_a_img_1, lora_b_img, test_name=request.node.name) 205 | assert not image_equal_fuzzy(base_img_1, lora_b_img, test_name=request.node.name) 206 | -------------------------------------------------------------------------------- /tests/test_remote_train.py: -------------------------------------------------------------------------------- 1 | import time 2 | import pytest 3 | import replicate 4 | 5 | 6 | @pytest.fixture(scope="module") 7 | def model_name(request): 8 | return "stability-ai/sdxl" 9 | 10 | 11 | @pytest.fixture(scope="module") 12 | def model(model_name): 13 | return replicate.models.get(model_name) 14 | 15 | 16 | @pytest.fixture(scope="module") 17 | def version(model): 18 | versions = model.versions.list() 19 | return versions[0] 20 | 21 | 22 | @pytest.fixture(scope="module") 23 | def training(model_name, version): 24 | training_input = { 25 | "input_images": "https://storage.googleapis.com/replicate-datasets/sdxl-test/monstertoy-captions.tar" 26 | } 27 | print(f"Training on {model_name}:{version.id}") 28 | return replicate.trainings.create( 29 | version=model_name + ":" + version.id, 30 | input=training_input, 31 | destination="replicate-internal/training-scratch", 32 | ) 33 | 34 | 35 | @pytest.fixture(scope="module") 36 | def prediction_tests(): 37 | return [ 38 | { 39 | "prompt": "A photo of TOK at the beach", 40 | "refine": "expert_ensemble_refiner", 41 | }, 42 | ] 43 | 44 | 45 | def test_training(training): 46 | while training.completed_at is None: 47 | time.sleep(60) 48 | training.reload() 49 | assert training.status == "succeeded" 50 | 51 | 52 | @pytest.fixture(scope="module") 53 | def trained_model_and_version(training): 54 | trained_model, trained_version = training.output["version"].split(":") 55 | return trained_model, trained_version 56 | 57 | 58 | def test_post_training_predictions(trained_model_and_version, prediction_tests): 59 | trained_model, trained_version = trained_model_and_version 60 | model = replicate.models.get(trained_model) 61 | version = model.versions.get(trained_version) 62 | predictions = [ 63 | replicate.predictions.create(version=version, input=val) 64 | for val in prediction_tests 65 | ] 66 | 67 | for val in predictions: 68 | val.wait() 69 | assert val.status == "succeeded" 70 | 
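# Optional sketch (untested): fetch the first output URL of each post-training
# prediction and check that it decodes as an image, mirroring the image handling
# in tests/test_predict.py.
from io import BytesIO

import requests
from PIL import Image


def assert_outputs_decode(predictions):
    for val in predictions:
        response = requests.get(val.output[0])
        img = Image.open(BytesIO(response.content))
        assert img.size[0] > 0 and img.size[1] > 0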
-------------------------------------------------------------------------------- /tests/test_utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | import json 3 | import requests 4 | import time 5 | from threading import Thread, Lock 6 | import re 7 | import multiprocessing 8 | import subprocess 9 | 10 | ERROR_PATTERN = re.compile(r"ERROR:") 11 | 12 | 13 | def get_image_name(): 14 | current_dir = os.path.basename(os.getcwd()) 15 | 16 | if "cog" in current_dir: 17 | return current_dir 18 | else: 19 | return f"cog-{current_dir}" 20 | 21 | 22 | def process_log_line(line): 23 | line = line.decode("utf-8").strip() 24 | try: 25 | log_data = json.loads(line) 26 | return json.dumps(log_data, indent=2) 27 | except json.JSONDecodeError: 28 | return line 29 | 30 | 31 | def capture_output(pipe, print_lock, logs=None, error_detected=None): 32 | for line in iter(pipe.readline, b""): 33 | formatted_line = process_log_line(line) 34 | with print_lock: 35 | print(formatted_line) 36 | if logs is not None: 37 | logs.append(formatted_line) 38 | if error_detected is not None: 39 | if ERROR_PATTERN.search(formatted_line): 40 | error_detected[0] = True 41 | 42 | 43 | def wait_for_server_to_be_ready(url, timeout=300): 44 | """ 45 | Waits for the server to be ready. 46 | 47 | Args: 48 | - url: The health check URL to poll. 49 | - timeout: Maximum time (in seconds) to wait for the server to be ready. 50 | """ 51 | start_time = time.time() 52 | while True: 53 | try: 54 | response = requests.get(url) 55 | data = response.json() 56 | 57 | if data["status"] == "READY": 58 | return 59 | elif data["status"] == "SETUP_FAILED": 60 | raise RuntimeError( 61 | "Server initialization failed with status: SETUP_FAILED" 62 | ) 63 | 64 | except requests.RequestException: 65 | pass 66 | 67 | if time.time() - start_time > timeout: 68 | raise TimeoutError("Server did not become ready in the expected time.") 69 | 70 | time.sleep(5) # Poll every 5 seconds 71 | 72 | 73 | def run_training_subprocess(command): 74 | # Start the subprocess with pipes for stdout and stderr 75 | process = subprocess.Popen(command, stdout=subprocess.PIPE, stderr=subprocess.PIPE) 76 | 77 | # Create a lock for printing and a list to accumulate logs 78 | print_lock = multiprocessing.Lock() 79 | logs = multiprocessing.Manager().list() 80 | error_detected = multiprocessing.Manager().list([False]) 81 | 82 | # Start two separate processes to handle stdout and stderr 83 | stdout_processor = multiprocessing.Process( 84 | target=capture_output, args=(process.stdout, print_lock, logs, error_detected) 85 | ) 86 | stderr_processor = multiprocessing.Process( 87 | target=capture_output, args=(process.stderr, print_lock, logs, error_detected) 88 | ) 89 | 90 | # Start the log processors 91 | stdout_processor.start() 92 | stderr_processor.start() 93 | 94 | # Wait for the subprocess to finish 95 | process.wait() 96 | 97 | # Wait for the log processors to finish 98 | stdout_processor.join() 99 | stderr_processor.join() 100 | 101 | # Check if an error pattern was detected 102 | if error_detected[0]: 103 | raise Exception("Error detected in training logs! 
Check logs for details") 104 | 105 | return list(logs) 106 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | import os 2 | import shutil 3 | import tarfile 4 | 5 | from cog import BaseModel, Input, Path 6 | 7 | from predict import SDXL_MODEL_CACHE, SDXL_URL, download_weights 8 | from preprocess import preprocess 9 | from trainer_pti import main 10 | 11 | """ 12 | Wrapper around actual trainer. 13 | """ 14 | OUTPUT_DIR = "training_out" 15 | 16 | 17 | class TrainingOutput(BaseModel): 18 | weights: Path 19 | 20 | 21 | from typing import Tuple 22 | 23 | 24 | def train( 25 | input_images: Path = Input( 26 | description="A .zip or .tar file containing the image files that will be used for fine-tuning" 27 | ), 28 | seed: int = Input( 29 | description="Random seed for reproducible training. Leave empty to use a random seed", 30 | default=None, 31 | ), 32 | resolution: int = Input( 33 | description="Square pixel resolution which your images will be resized to for training", 34 | default=768, 35 | ), 36 | train_batch_size: int = Input( 37 | description="Batch size (per device) for training", 38 | default=4, 39 | ), 40 | num_train_epochs: int = Input( 41 | description="Number of epochs to loop through your training dataset", 42 | default=4000, 43 | ), 44 | max_train_steps: int = Input( 45 | description="Number of individual training steps. Takes precedence over num_train_epochs", 46 | default=1000, 47 | ), 48 | # gradient_accumulation_steps: int = Input( 49 | # description="Number of training steps to accumulate before a backward pass. Effective batch size = gradient_accumulation_steps * batch_size", 50 | # default=1, 51 | # ), # todo. 52 | is_lora: bool = Input( 53 | description="Whether to use LoRA training. If set to False, will use Full fine tuning", 54 | default=True, 55 | ), 56 | unet_learning_rate: float = Input( 57 | description="Learning rate for the U-Net. We recommend this value to be somewhere between `1e-6` to `1e-5`.", 58 | default=1e-6, 59 | ), 60 | ti_lr: float = Input( 61 | description="Scaling of learning rate for training textual inversion embeddings. Don't alter unless you know what you're doing.", 62 | default=3e-4, 63 | ), 64 | lora_lr: float = Input( 65 | description="Scaling of learning rate for training LoRA embeddings. Don't alter unless you know what you're doing.", 66 | default=1e-4, 67 | ), 68 | lora_rank: int = Input( 69 | description="Rank of LoRA embeddings. Don't alter unless you know what you're doing.", 70 | default=32, 71 | ), 72 | lr_scheduler: str = Input( 73 | description="Learning rate scheduler to use for training", 74 | default="constant", 75 | choices=[ 76 | "constant", 77 | "linear", 78 | ], 79 | ), 80 | lr_warmup_steps: int = Input( 81 | description="Number of warmup steps for lr schedulers with warmups.", 82 | default=100, 83 | ), 84 | token_string: str = Input( 85 | description="A unique string that will be trained to refer to the concept in the input images. Can be anything, but TOK works well", 86 | default="TOK", 87 | ), 88 | # token_map: str = Input( 89 | # description="String of token and their impact size specificing tokens used in the dataset. This will be in format of `token1:size1,token2:size2,...`.", 90 | # default="TOK:2", 91 | # ), 92 | caption_prefix: str = Input( 93 | description="Text which will be used as prefix during automatic captioning. Must contain the `token_string`. 
For example, if caption text is 'a photo of TOK', automatic captioning will expand to 'a photo of TOK under a bridge', 'a photo of TOK holding a cup', etc.",
 94 |         default="a photo of TOK, ",
 95 |     ),
 96 |     mask_target_prompts: str = Input(
 97 |         description="Prompt that describes the part of the image that you find important. For example, if you are fine-tuning your pet, `photo of a dog` will be a good prompt. Prompt-based masking is used to focus the fine-tuning process on the important/salient parts of the image",
 98 |         default=None,
 99 |     ),
100 |     crop_based_on_salience: bool = Input(
101 |         description="If you want to crop the image to `target_size` based on the important parts of the image, set this to True. If you want to crop the image based on face detection, set this to False",
102 |         default=True,
103 |     ),
104 |     use_face_detection_instead: bool = Input(
105 |         description="Whether to use face detection instead of CLIPSeg for masking. For face applications, we recommend using this option.",
106 |         default=False,
107 |     ),
108 |     clipseg_temperature: float = Input(
109 |         description="How blurry you want the CLIPSeg mask to be. We recommend this value be somewhere between `0.5` and `1.0`. If you want a sharper (but potentially more error-prone) mask, decrease this value.",
110 |         default=1.0,
111 |     ),
112 |     verbose: bool = Input(description="verbose output", default=True),
113 |     checkpointing_steps: int = Input(
114 |         description="Number of steps between saving checkpoints. Set to a very high number to disable checkpointing; intermediate checkpoints are usually not needed.",
115 |         default=999999,
116 |     ),
117 |     input_images_filetype: str = Input(
118 |         description="Filetype of the input images. Can be either `zip` or `tar`. By default it's `infer`, and the filetype will be inferred from the extension of the input file.",
119 |         default="infer",
120 |         choices=["zip", "tar", "infer"],
121 |     ),
122 | ) -> TrainingOutput:
123 |     # Hard-code token_map for now. Make it configurable once we support multiple concepts or user-uploaded caption CSVs.
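    # token_map is a comma-separated list of "token:count" pairs; each entry maps a user-facing
    # token (e.g. TOK) to `count` freshly inserted placeholder tokens (<s0>, <s1>, ...) whose
    # embeddings are the ones actually trained via textual inversion.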
124 |     token_map = token_string + ":2"
125 | 
126 |     # Process 'token_to_train' and 'input_data_tar_or_zip'
127 |     inserting_list_tokens = token_map.split(",")
128 | 
129 |     token_dict = {}
130 |     running_tok_cnt = 0
131 |     all_token_lists = []
132 |     for token in inserting_list_tokens:
133 |         n_tok = int(token.split(":")[1])
134 | 
135 |         token_dict[token.split(":")[0]] = "".join(
136 |             [f"<s{i + running_tok_cnt}>" for i in range(n_tok)]
137 |         )
138 |         all_token_lists.extend([f"<s{i + running_tok_cnt}>" for i in range(n_tok)])
139 | 
140 |         running_tok_cnt += n_tok
141 | 
142 |     input_dir = preprocess(
143 |         input_images_filetype=input_images_filetype,
144 |         input_zip_path=input_images,
145 |         caption_text=caption_prefix,
146 |         mask_target_prompts=mask_target_prompts,
147 |         target_size=resolution,
148 |         crop_based_on_salience=crop_based_on_salience,
149 |         use_face_detection_instead=use_face_detection_instead,
150 |         temp=clipseg_temperature,
151 |         substitution_tokens=list(token_dict.keys()),
152 |     )
153 | 
154 |     if not os.path.exists(SDXL_MODEL_CACHE):
155 |         download_weights(SDXL_URL, SDXL_MODEL_CACHE)
156 |     if os.path.exists(OUTPUT_DIR):
157 |         shutil.rmtree(OUTPUT_DIR)
158 |     os.makedirs(OUTPUT_DIR)
159 | 
160 |     main(
161 |         pretrained_model_name_or_path=SDXL_MODEL_CACHE,
162 |         instance_data_dir=os.path.join(input_dir, "captions.csv"),
163 |         output_dir=OUTPUT_DIR,
164 |         seed=seed,
165 |         resolution=resolution,
166 |         train_batch_size=train_batch_size,
167 |         num_train_epochs=num_train_epochs,
168 |         max_train_steps=max_train_steps,
169 |         gradient_accumulation_steps=1,
170 |         unet_learning_rate=unet_learning_rate,
171 |         ti_lr=ti_lr,
172 |         lora_lr=lora_lr,
173 |         lr_scheduler=lr_scheduler,
174 |         lr_warmup_steps=lr_warmup_steps,
175 |         token_dict=token_dict,
176 |         inserting_list_tokens=all_token_lists,
177 |         verbose=verbose,
178 |         checkpointing_steps=checkpointing_steps,
179 |         scale_lr=False,
180 |         max_grad_norm=1.0,
181 |         allow_tf32=True,
182 |         mixed_precision="bf16",
183 |         device="cuda:0",
184 |         lora_rank=lora_rank,
185 |         is_lora=is_lora,
186 |     )
187 | 
188 |     directory = Path(OUTPUT_DIR)
189 |     out_path = "trained_model.tar"
190 | 
191 |     with tarfile.open(out_path, "w") as tar:
192 |         for file_path in directory.rglob("*"):
193 |             print(file_path)
194 |             arcname = file_path.relative_to(directory)
195 |             tar.add(file_path, arcname=arcname)
196 | 
197 |     return TrainingOutput(weights=Path(out_path))
198 | 
--------------------------------------------------------------------------------
/trainer_pti.py:
--------------------------------------------------------------------------------
 1 | # Bootstrapped from Hugging Face diffusers code.
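# This module is the pivotal-tuning trainer invoked by train.py: it inserts the new <s*>
# placeholder tokens into both SDXL text encoders, trains their embeddings (optionally
# "pivoting" halfway through training, i.e. dropping them from the optimizer), and
# fine-tunes either LoRA attention processors or a whitelisted subset of U-Net weights.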
 2 | import fnmatch
 3 | import json
 4 | import math
 5 | import os
 6 | import shutil
 7 | from typing import List, Optional
 8 | 
 9 | import numpy as np
10 | import torch
11 | import torch.utils.checkpoint
12 | from diffusers.models.attention_processor import LoRAAttnProcessor, LoRAAttnProcessor2_0
13 | from diffusers.optimization import get_scheduler
14 | from safetensors.torch import save_file
15 | from tqdm.auto import tqdm
16 | 
17 | from dataset_and_utils import (
18 |     PreprocessedDataset,
19 |     TokenEmbeddingsHandler,
20 |     load_models,
21 |     unet_attn_processors_state_dict,
22 | )
23 | 
24 | 
25 | def main(
26 |     pretrained_model_name_or_path: Optional[
27 |         str
28 |     ] = "./cache",  # "stabilityai/stable-diffusion-xl-base-1.0",
29 |     revision: Optional[str] = None,
30 |     instance_data_dir: Optional[str] = "./dataset/zeke/captions.csv",
31 |     output_dir: str = "ft_masked_coke",
32 |     seed: Optional[int] = 42,
33 |     resolution: int = 512,
34 |     crops_coords_top_left_h: int = 0,
35 |     crops_coords_top_left_w: int = 0,
36 |     train_batch_size: int = 1,
37 |     do_cache: bool = True,
38 |     num_train_epochs: int = 600,
39 |     max_train_steps: Optional[int] = None,
40 |     checkpointing_steps: int = 500000,  # default to no checkpoints
41 |     gradient_accumulation_steps: int = 1,  # todo
42 |     unet_learning_rate: float = 1e-5,
43 |     ti_lr: float = 3e-4,
44 |     lora_lr: float = 1e-4,
45 |     pivot_halfway: bool = True,
46 |     scale_lr: bool = False,
47 |     lr_scheduler: str = "constant",
48 |     lr_warmup_steps: int = 500,
49 |     lr_num_cycles: int = 1,
50 |     lr_power: float = 1.0,
51 |     dataloader_num_workers: int = 0,
52 |     max_grad_norm: float = 1.0,  # todo with tests
53 |     allow_tf32: bool = True,
54 |     mixed_precision: Optional[str] = "bf16",
55 |     device: str = "cuda:0",
56 |     token_dict: dict = {"TOKEN": "<s0>"},
57 |     inserting_list_tokens: List[str] = ["<s0>"],
58 |     verbose: bool = True,
59 |     is_lora: bool = True,
60 |     lora_rank: int = 32,
61 | ) -> None:
62 |     if allow_tf32:
63 |         torch.backends.cuda.matmul.allow_tf32 = True
64 |     if not seed:
65 |         seed = np.random.randint(0, 2**32 - 1)
66 |     print("Using seed", seed)
67 |     torch.manual_seed(seed)
68 | 
69 |     weight_dtype = torch.float32
70 |     if mixed_precision == "fp16":
71 |         weight_dtype = torch.float16
72 |     elif mixed_precision == "bf16":
73 |         weight_dtype = torch.bfloat16
74 | 
75 |     if scale_lr:
76 |         unet_learning_rate = (
77 |             unet_learning_rate * gradient_accumulation_steps * train_batch_size
78 |         )
79 | 
80 |     (
81 |         tokenizer_one,
82 |         tokenizer_two,
83 |         noise_scheduler,
84 |         text_encoder_one,
85 |         text_encoder_two,
86 |         vae,
87 |         unet,
88 |     ) = load_models(pretrained_model_name_or_path, revision, device, weight_dtype)
89 | 
90 |     print("# PTI : Loaded models")
91 | 
92 |     # Initialize new tokens for training.
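    # (TokenEmbeddingsHandler, defined in dataset_and_utils.py, is expected to append the <s*>
    #  placeholder tokens to both tokenizers and create the corresponding trainable rows in each
    #  text encoder's token-embedding matrix.)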
93 | 94 | embedding_handler = TokenEmbeddingsHandler( 95 | [text_encoder_one, text_encoder_two], [tokenizer_one, tokenizer_two] 96 | ) 97 | embedding_handler.initialize_new_tokens(inserting_toks=inserting_list_tokens) 98 | 99 | text_encoders = [text_encoder_one, text_encoder_two] 100 | 101 | unet_param_to_optimize = [] 102 | # fine tune only attn weights 103 | 104 | text_encoder_parameters = [] 105 | for text_encoder in text_encoders: 106 | for name, param in text_encoder.named_parameters(): 107 | if "token_embedding" in name: 108 | param.requires_grad = True 109 | print(name) 110 | text_encoder_parameters.append(param) 111 | else: 112 | param.requires_grad = False 113 | 114 | if not is_lora: 115 | WHITELIST_PATTERNS = [ 116 | # "*.attn*.weight", 117 | # "*ff*.weight", 118 | "*" 119 | ] # TODO : make this a parameter 120 | BLACKLIST_PATTERNS = ["*.norm*.weight", "*time*"] 121 | 122 | unet_param_to_optimize_names = [] 123 | for name, param in unet.named_parameters(): 124 | if any( 125 | fnmatch.fnmatch(name, pattern) for pattern in WHITELIST_PATTERNS 126 | ) and not any( 127 | fnmatch.fnmatch(name, pattern) for pattern in BLACKLIST_PATTERNS 128 | ): 129 | param.requires_grad_(True) 130 | unet_param_to_optimize_names.append(name) 131 | print(f"Training: {name}") 132 | else: 133 | param.requires_grad_(False) 134 | 135 | # Optimizer creation 136 | params_to_optimize = [ 137 | { 138 | "params": unet_param_to_optimize, 139 | "lr": unet_learning_rate, 140 | }, 141 | { 142 | "params": text_encoder_parameters, 143 | "lr": ti_lr, 144 | "weight_decay": 1e-3, 145 | }, 146 | ] 147 | 148 | else: 149 | # Do lora-training instead. 150 | unet.requires_grad_(False) 151 | unet_lora_attn_procs = {} 152 | unet_lora_parameters = [] 153 | for name, attn_processor in unet.attn_processors.items(): 154 | cross_attention_dim = ( 155 | None 156 | if name.endswith("attn1.processor") 157 | else unet.config.cross_attention_dim 158 | ) 159 | if name.startswith("mid_block"): 160 | hidden_size = unet.config.block_out_channels[-1] 161 | elif name.startswith("up_blocks"): 162 | block_id = int(name[len("up_blocks.")]) 163 | hidden_size = list(reversed(unet.config.block_out_channels))[block_id] 164 | elif name.startswith("down_blocks"): 165 | block_id = int(name[len("down_blocks.")]) 166 | hidden_size = unet.config.block_out_channels[block_id] 167 | 168 | module = LoRAAttnProcessor2_0( 169 | hidden_size=hidden_size, 170 | cross_attention_dim=cross_attention_dim, 171 | rank=lora_rank, 172 | ) 173 | unet_lora_attn_procs[name] = module 174 | module.to(device) 175 | unet_lora_parameters.extend(module.parameters()) 176 | 177 | unet.set_attn_processor(unet_lora_attn_procs) 178 | 179 | params_to_optimize = [ 180 | { 181 | "params": unet_lora_parameters, 182 | "lr": lora_lr, 183 | }, 184 | { 185 | "params": text_encoder_parameters, 186 | "lr": ti_lr, 187 | "weight_decay": 1e-3, 188 | }, 189 | ] 190 | 191 | optimizer = torch.optim.AdamW( 192 | params_to_optimize, 193 | weight_decay=1e-4, 194 | ) 195 | 196 | print(f"# PTI : Loading dataset, do_cache {do_cache}") 197 | 198 | train_dataset = PreprocessedDataset( 199 | instance_data_dir, 200 | tokenizer_one, 201 | tokenizer_two, 202 | vae.float(), 203 | do_cache=True, 204 | substitute_caption_map=token_dict, 205 | ) 206 | 207 | print("# PTI : Loaded dataset") 208 | 209 | train_dataloader = torch.utils.data.DataLoader( 210 | train_dataset, 211 | batch_size=train_batch_size, 212 | shuffle=True, 213 | num_workers=dataloader_num_workers, 214 | ) 215 | 216 | num_update_steps_per_epoch = 
math.ceil( 217 | len(train_dataloader) / gradient_accumulation_steps 218 | ) 219 | if max_train_steps is None: 220 | max_train_steps = num_train_epochs * num_update_steps_per_epoch 221 | 222 | lr_scheduler = get_scheduler( 223 | lr_scheduler, 224 | optimizer=optimizer, 225 | num_warmup_steps=lr_warmup_steps * gradient_accumulation_steps, 226 | num_training_steps=max_train_steps * gradient_accumulation_steps, 227 | num_cycles=lr_num_cycles, 228 | power=lr_power, 229 | ) 230 | 231 | num_update_steps_per_epoch = math.ceil( 232 | len(train_dataloader) / gradient_accumulation_steps 233 | ) 234 | num_train_epochs = math.ceil(max_train_steps / num_update_steps_per_epoch) 235 | 236 | total_batch_size = train_batch_size * gradient_accumulation_steps 237 | 238 | if verbose: 239 | print(f"# PTI : Running training ") 240 | print(f"# PTI : Num examples = {len(train_dataset)}") 241 | print(f"# PTI : Num batches each epoch = {len(train_dataloader)}") 242 | print(f"# PTI : Num Epochs = {num_train_epochs}") 243 | print(f"# PTI : Instantaneous batch size per device = {train_batch_size}") 244 | print( 245 | f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}" 246 | ) 247 | print(f"# PTI : Gradient Accumulation steps = {gradient_accumulation_steps}") 248 | print(f"# PTI : Total optimization steps = {max_train_steps}") 249 | 250 | global_step = 0 251 | first_epoch = 0 252 | 253 | # Only show the progress bar once on each machine. 254 | progress_bar = tqdm(range(global_step, max_train_steps)) 255 | checkpoint_dir = "checkpoint" 256 | if os.path.exists(checkpoint_dir): 257 | shutil.rmtree(checkpoint_dir) 258 | 259 | os.makedirs(f"{checkpoint_dir}/unet", exist_ok=True) 260 | os.makedirs(f"{checkpoint_dir}/embeddings", exist_ok=True) 261 | 262 | for epoch in range(first_epoch, num_train_epochs): 263 | if pivot_halfway: 264 | if epoch == num_train_epochs // 2: 265 | print("# PTI : Pivot halfway") 266 | # remove text encoder parameters from optimizer 267 | params_to_optimize = params_to_optimize[:1] 268 | optimizer = torch.optim.AdamW( 269 | params_to_optimize, 270 | weight_decay=1e-4, 271 | ) 272 | 273 | unet.train() 274 | for step, batch in enumerate(train_dataloader): 275 | progress_bar.update(1) 276 | progress_bar.set_description(f"# PTI :step: {global_step}, epoch: {epoch}") 277 | global_step += 1 278 | 279 | (tok1, tok2), vae_latent, mask = batch 280 | vae_latent = vae_latent.to(weight_dtype) 281 | 282 | # tokens to text embeds 283 | prompt_embeds_list = [] 284 | for tok, text_encoder in zip((tok1, tok2), text_encoders): 285 | prompt_embeds_out = text_encoder( 286 | tok.to(text_encoder.device), 287 | output_hidden_states=True, 288 | ) 289 | 290 | pooled_prompt_embeds = prompt_embeds_out[0] 291 | prompt_embeds = prompt_embeds_out.hidden_states[-2] 292 | bs_embed, seq_len, _ = prompt_embeds.shape 293 | prompt_embeds = prompt_embeds.view(bs_embed, seq_len, -1) 294 | prompt_embeds_list.append(prompt_embeds) 295 | 296 | prompt_embeds = torch.concat(prompt_embeds_list, dim=-1) 297 | pooled_prompt_embeds = pooled_prompt_embeds.view(bs_embed, -1) 298 | 299 | # Create Spatial-dimensional conditions. 
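            # (SDXL's "time ids" micro-conditioning packs the original image size, the top-left
            #  crop coordinates, and the target size into one tensor; since every example is
            #  resized to the training resolution, these values are constant here.)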
300 | 301 | original_size = (resolution, resolution) 302 | target_size = (resolution, resolution) 303 | crops_coords_top_left = (crops_coords_top_left_h, crops_coords_top_left_w) 304 | add_time_ids = list(original_size + crops_coords_top_left + target_size) 305 | add_time_ids = torch.tensor([add_time_ids]) 306 | 307 | add_time_ids = add_time_ids.to(device, dtype=prompt_embeds.dtype).repeat( 308 | bs_embed, 1 309 | ) 310 | 311 | added_kw = {"text_embeds": pooled_prompt_embeds, "time_ids": add_time_ids} 312 | 313 | # Sample noise that we'll add to the latents 314 | noise = torch.randn_like(vae_latent) 315 | bsz = vae_latent.shape[0] 316 | 317 | timesteps = torch.randint( 318 | 0, 319 | noise_scheduler.config.num_train_timesteps, 320 | (bsz,), 321 | device=vae_latent.device, 322 | ) 323 | timesteps = timesteps.long() 324 | 325 | noisy_model_input = noise_scheduler.add_noise(vae_latent, noise, timesteps) 326 | 327 | # Predict the noise residual 328 | model_pred = unet( 329 | noisy_model_input, 330 | timesteps, 331 | prompt_embeds, 332 | added_cond_kwargs=added_kw, 333 | ).sample 334 | 335 | loss = (model_pred - noise).pow(2) * mask 336 | loss = loss.mean() 337 | 338 | loss.backward() 339 | optimizer.step() 340 | lr_scheduler.step() 341 | optimizer.zero_grad() 342 | 343 | # every step, we reset the embeddings to the original embeddings. 344 | 345 | for idx, text_encoder in enumerate(text_encoders): 346 | embedding_handler.retract_embeddings() 347 | 348 | if global_step % checkpointing_steps == 0: 349 | # save the required params of unet with safetensor 350 | 351 | if not is_lora: 352 | tensors = { 353 | name: param 354 | for name, param in unet.named_parameters() 355 | if name in unet_param_to_optimize_names 356 | } 357 | save_file( 358 | tensors, 359 | f"{checkpoint_dir}/unet/checkpoint-{global_step}.unet.safetensors", 360 | ) 361 | 362 | else: 363 | lora_tensors = unet_attn_processors_state_dict(unet) 364 | 365 | save_file( 366 | lora_tensors, 367 | f"{checkpoint_dir}/unet/checkpoint-{global_step}.lora.safetensors", 368 | ) 369 | 370 | embedding_handler.save_embeddings( 371 | f"{checkpoint_dir}/embeddings/checkpoint-{global_step}.pti", 372 | ) 373 | 374 | # final_save 375 | print("Saving final model for return") 376 | if not is_lora: 377 | tensors = { 378 | name: param 379 | for name, param in unet.named_parameters() 380 | if name in unet_param_to_optimize_names 381 | } 382 | save_file( 383 | tensors, 384 | f"{output_dir}/unet.safetensors", 385 | ) 386 | else: 387 | lora_tensors = unet_attn_processors_state_dict(unet) 388 | save_file( 389 | lora_tensors, 390 | f"{output_dir}/lora.safetensors", 391 | ) 392 | 393 | embedding_handler.save_embeddings( 394 | f"{output_dir}/embeddings.pti", 395 | ) 396 | 397 | to_save = token_dict 398 | with open(f"{output_dir}/special_params.json", "w") as f: 399 | json.dump(to_save, f) 400 | 401 | 402 | if __name__ == "__main__": 403 | main() 404 | -------------------------------------------------------------------------------- /weights.py: -------------------------------------------------------------------------------- 1 | from collections import deque 2 | import hashlib 3 | import os 4 | import shutil 5 | import subprocess 6 | import time 7 | 8 | 9 | class WeightsDownloadCache: 10 | def __init__( 11 | self, min_disk_free: int = 10 * (2**30), base_dir: str = "/src/weights-cache" 12 | ): 13 | """ 14 | WeightsDownloadCache is meant to track and download weights files as fast 15 | as possible, while ensuring there's enough disk space. 
16 | 17 | It tries to keep the most recently used weights files in the cache, so 18 | ensure you call ensure() on the weights each time you use them. 19 | 20 | It will not re-download weights files that are already in the cache. 21 | 22 | :param min_disk_free: Minimum disk space required to start download, in bytes. 23 | :param base_dir: The base directory to store weights files. 24 | """ 25 | self.min_disk_free = min_disk_free 26 | self.base_dir = base_dir 27 | self._hits = 0 28 | self._misses = 0 29 | 30 | # Least Recently Used (LRU) cache for paths 31 | self.lru_paths = deque() 32 | if not os.path.exists(base_dir): 33 | os.makedirs(base_dir) 34 | 35 | def _remove_least_recent(self) -> None: 36 | """ 37 | Remove the least recently used weights file from the cache and disk. 38 | """ 39 | oldest = self.lru_paths.popleft() 40 | self._rm_disk(oldest) 41 | 42 | def cache_info(self) -> str: 43 | """ 44 | Get cache information. 45 | 46 | :return: Cache information. 47 | """ 48 | 49 | return f"CacheInfo(hits={self._hits}, misses={self._misses}, base_dir='{self.base_dir}', currsize={len(self.lru_paths)})" 50 | 51 | def _rm_disk(self, path: str) -> None: 52 | """ 53 | Remove a weights file or directory from disk. 54 | :param path: Path to remove. 55 | """ 56 | if os.path.isfile(path): 57 | os.remove(path) 58 | elif os.path.isdir(path): 59 | shutil.rmtree(path) 60 | 61 | def _has_enough_space(self) -> bool: 62 | """ 63 | Check if there's enough disk space. 64 | 65 | :return: True if there's more than min_disk_free free, False otherwise. 66 | """ 67 | disk_usage = shutil.disk_usage(self.base_dir) 68 | print(f"Free disk space: {disk_usage.free}") 69 | return disk_usage.free >= self.min_disk_free 70 | 71 | def ensure(self, url: str) -> str: 72 | """ 73 | Ensure weights file is in the cache and return its path. 74 | 75 | This also updates the LRU cache to mark the weights as recently used. 76 | 77 | :param url: URL to download weights file from, if not in cache. 78 | :return: Path to weights. 79 | """ 80 | path = self.weights_path(url) 81 | 82 | if path in self.lru_paths: 83 | # here we remove to re-add to the end of the LRU (marking it as recently used) 84 | self._hits += 1 85 | self.lru_paths.remove(path) 86 | else: 87 | self._misses += 1 88 | self.download_weights(url, path) 89 | 90 | self.lru_paths.append(path) # Add file to end of cache 91 | return path 92 | 93 | def weights_path(self, url: str) -> str: 94 | """ 95 | Generate path to store a weights file based hash of the URL. 96 | 97 | :param url: URL to download weights file from. 98 | :return: Path to store weights file. 99 | """ 100 | hashed_url = hashlib.sha256(url.encode()).hexdigest() 101 | short_hash = hashed_url[:16] # Use the first 16 characters of the hash 102 | return os.path.join(self.base_dir, short_hash) 103 | 104 | def download_weights(self, url: str, dest: str) -> None: 105 | """ 106 | Download weights file from a URL, ensuring there's enough disk space. 107 | 108 | :param url: URL to download weights file from. 109 | :param dest: Path to store weights file. 
110 | """ 111 | print("Ensuring enough disk space...") 112 | while not self._has_enough_space() and len(self.lru_paths) > 0: 113 | self._remove_least_recent() 114 | 115 | print(f"Downloading weights: {url}") 116 | 117 | st = time.time() 118 | # maybe retry with the real url if this doesn't work 119 | try: 120 | output = subprocess.check_output(["pget", "-x", url, dest], close_fds=True) 121 | print(output) 122 | except subprocess.CalledProcessError as e: 123 | # If download fails, clean up and re-raise exception 124 | print(e.output) 125 | self._rm_disk(dest) 126 | raise e 127 | print(f"Downloaded weights in {time.time() - st} seconds") 128 | --------------------------------------------------------------------------------