├── .github └── FUNDING.yml ├── .gitignore ├── LICENSE ├── MANIFEST.in ├── README.md ├── bin └── leap_textual_inversion ├── leap_sd ├── __init__.py ├── module.py └── utils.py ├── setup.py └── training ├── README.md ├── dataset_creator ├── bip39.txt └── sd_extractor.py ├── get_extrema.py ├── lora_dataset_creator ├── create_dataset.py ├── lora_words.txt ├── split_data.py └── train_loras.py ├── train.py └── train_lora.py /.github/FUNDING.yml: -------------------------------------------------------------------------------- 1 | # These are supported funding model platforms 2 | 3 | github: # Replace with up to 4 GitHub Sponsors-enabled usernames e.g., [user1, user2] 4 | patreon: emerald_show 5 | open_collective: # Replace with a single Open Collective username 6 | ko_fi: # Replace with a single Ko-fi username 7 | tidelift: # Replace with a single Tidelift platform-name/package-name e.g., npm/babel 8 | community_bridge: # Replace with a single Community Bridge project-name e.g., cloud-foundry 9 | liberapay: # Replace with a single Liberapay username 10 | issuehunt: # Replace with a single IssueHunt username 11 | otechie: # Replace with a single Otechie username 12 | lfx_crowdfunding: # Replace with a single LFX Crowdfunding project-name e.g., cloud-foundry 13 | custom: # Replace with up to 4 custom sponsorship URLs e.g., ['link1', 'link2'] 14 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | __pycache__/ 2 | *.egg-info 3 | training/dataset_creator/sd_extracted 4 | training/lora_dataset_creator/lora_dataset 5 | lightning_logs/ 6 | /LEAP 7 | /wandb 8 | .venv/ 9 | *.ckpt -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2023 Peter Willemsen 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /MANIFEST.in: -------------------------------------------------------------------------------- 1 | include leap_sd/model.ckpt -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # What is LEAP? 
2 | 3 | [Demo video](https://www.youtube.com/watch?v=iv_P6db88ts) 4 | 5 | LEAP is a research project in which input images are converted to a local minimum in latent space. The resulting weights are then fed to Stable Diffusion's Textual Inversion model as a starting point (a minimal code sketch of this idea is at the end of this README). 6 | 7 | The benefits are huge: training easily takes under 5 minutes, with little quality difference from training for hours on the same hardware. 8 | 9 | It is efficient enough to offer in a small-scale Discord bot like Thingy, where our goal is to introduce people to AI without it costing hundreds of dollars per month in GPU rent! 10 | 11 | Love you all! Sorry that this README is a little crunchy; I'm just so excited and jumpy, I never thought it would work. 12 | 13 | [Join my Discord](https://discord.gg/j4wQYhhvVd) to check out Thingy 3! It has a `/train` command that uses LEAP under the hood! 14 | 15 | [Check this Colab](https://colab.research.google.com/drive/1-uBBQpPlt4k5YDNZiN4H4ICWlkVcitfP?usp=sharing) to try it out right away! 16 | 17 | # How to use with Stable Diffusion 18 | 19 | **Note:** The author uses Linux; while Windows should work, the author can't guarantee that these README instructions work there as written. 20 | 21 | - Run the following command: `pip install git+https://github.com/peterwilli/sd-leap-booster.git` 22 | - Download the weights (for example, [Stable Diffusion 2.1 with Textual Inversion](https://github.com/peterwilli/sd-leap-booster/releases/download/sd-2.1-ti/leap_ti_2.0_sd2.1_beta.ckpt)) 23 | - Run `leap_textual_inversion` with the parameters you wish (they are similar to the [official textual inversion script](https://github.com/huggingface/diffusers/blob/main/examples/textual_inversion/textual_inversion.py)) and point it to your LEAP model weights: `--leap_model_path=/path/to/leap_ti_2.0_beta.ckpt` 24 | - An example: `leap_textual_inversion --pretrained_model_name_or_path=stabilityai/stable-diffusion-2-1-base --placeholder_token="" --train_data_dir=path/to/images --learning_rate=0.001 --leap_model_path=/path/to/leap_ti_2.0_beta.ckpt` 25 | 26 | # Train your own model! 27 | 28 | See [training/README.md](training/README.md) for instructions. 29 | 30 | # Support, sponsorship and thanks 31 | 32 | Are you looking to make a positive impact and get some awesome perks in the process? **[Join me on Patreon!](https://www.patreon.com/emerald_show)** For just $3 per month, you can join our Patreon community and help a creative mind in the Netherlands bring their ideas to life. 33 | 34 | Not only will you get the satisfaction of supporting an individual's passions, but you'll also receive a 50% discount on any paid services that result from the projects you sponsor. Plus, as a Patreon member, you'll have exclusive voting rights on new features and the opportunity to shape the direction of future projects. Don't miss out on this chance to make a difference and get some amazing benefits in return. 35 | 36 | One of the things we intend to do is make LEAP work with LoRA! 37 | 38 | **Special thanks to:** 39 | 40 | - [Mahdi Chaker](https://twitter.com/MahdiMC) for the heavy training GPUs! 41 | - LAION/Stability AI for providing training GPUs ~~And hopefully I'll soon feel confident enough to try them.~~ 42 | - Jina.ai for providing the computation power to run my bot! 43 | - You?
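# How the boost works (a sketch)

The snippet below is a minimal sketch of what `leap_textual_inversion` does before the regular Textual Inversion training loop starts: it loads the LEAP booster from its checkpoint, predicts an embedding from the training images, and writes that embedding into a newly added placeholder token. The checkpoint filename, the `<my-concept>` token and the random `images` tensor are placeholders for illustration only; the real image preprocessing (resizing/cropping to 128x128 and normalization) lives in `boost_embed()` inside `bin/leap_textual_inversion`.

```python
import torch
import leap_sd
from transformers import CLIPTextModel, CLIPTokenizer

# Load the LEAP booster (a small CNN, see leap_sd/module.py) from its checkpoint.
leap = leap_sd.LM.load_from_checkpoint("leap_ti_2.0_sd2.1_beta.ckpt")
leap.eval()

# Stand-in for a batch of preprocessed training images:
# shape (batch=1, num_images=4, channels=3, 128, 128), values roughly in [-1, 1].
images = torch.rand(1, 4, 3, 128, 128) * 2 - 1

# One forward pass predicts a single CLIP-sized embedding vector.
with torch.no_grad():
    boosted_embed = leap(images).squeeze()

# Write the predicted embedding into a freshly added placeholder token; regular
# Textual Inversion training then starts from this initialization.
model_name = "stabilityai/stable-diffusion-2-1-base"
tokenizer = CLIPTokenizer.from_pretrained(model_name, subfolder="tokenizer")
text_encoder = CLIPTextModel.from_pretrained(model_name, subfolder="text_encoder")
tokenizer.add_tokens("<my-concept>")
placeholder_token_id = tokenizer.convert_tokens_to_ids("<my-concept>")
text_encoder.resize_token_embeddings(len(tokenizer))
text_encoder.get_input_embeddings().weight.data[placeholder_token_id] = boosted_embed
```

You normally never need to do this by hand: the installed `leap_textual_inversion` script performs this step for you and then runs a short, otherwise standard Textual Inversion training loop.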
44 | -------------------------------------------------------------------------------- /bin/leap_textual_inversion: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | #Bootstrapped from: https://github.com/huggingface/diffusers/blob/main/examples/textual_inversion/textual_inversion.py 3 | import argparse 4 | import logging 5 | import math 6 | import os 7 | import random 8 | from pathlib import Path 9 | from typing import Optional 10 | 11 | import numpy as np 12 | import torch 13 | import torch.nn.functional as F 14 | import torch.utils.checkpoint 15 | from torch.utils.data import Dataset 16 | 17 | import datasets 18 | import diffusers 19 | import PIL 20 | import transformers 21 | from accelerate import Accelerator 22 | from accelerate.logging import get_logger 23 | from accelerate.utils import set_seed 24 | from diffusers import AutoencoderKL, DDPMScheduler, StableDiffusionPipeline, UNet2DConditionModel 25 | from diffusers.optimization import get_scheduler 26 | from diffusers.utils import check_min_version 27 | from diffusers.utils.import_utils import is_xformers_available 28 | from huggingface_hub import HfFolder, Repository, whoami 29 | 30 | # TODO: remove and import from diffusers.utils when the new version of diffusers is released 31 | from packaging import version 32 | from PIL import Image, ImageOps 33 | from torchvision import transforms 34 | import torchvision 35 | from tqdm.auto import tqdm 36 | from transformers import CLIPTextModel, CLIPTokenizer 37 | import leap_sd 38 | from imgaug import augmenters as iaa 39 | 40 | 41 | if version.parse(version.parse(PIL.__version__).base_version) >= version.parse("9.1.0"): 42 | PIL_INTERPOLATION = { 43 | "linear": PIL.Image.Resampling.BILINEAR, 44 | "bilinear": PIL.Image.Resampling.BILINEAR, 45 | "bicubic": PIL.Image.Resampling.BICUBIC, 46 | "lanczos": PIL.Image.Resampling.LANCZOS, 47 | "nearest": PIL.Image.Resampling.NEAREST, 48 | } 49 | else: 50 | PIL_INTERPOLATION = { 51 | "linear": PIL.Image.LINEAR, 52 | "bilinear": PIL.Image.BILINEAR, 53 | "bicubic": PIL.Image.BICUBIC, 54 | "lanczos": PIL.Image.LANCZOS, 55 | "nearest": PIL.Image.NEAREST, 56 | } 57 | # ------------------------------------------------------------------------------ 58 | 59 | 60 | # Will error if the minimal version of diffusers is not installed. Remove at your own risks. 
61 | check_min_version("0.10.0.dev0") 62 | 63 | 64 | logger = get_logger(__name__) 65 | 66 | 67 | def save_progress(text_encoder, placeholder_token_id, accelerator, args, save_path): 68 | logger.info("Saving embeddings") 69 | learned_embeds = accelerator.unwrap_model(text_encoder).get_input_embeddings().weight[placeholder_token_id] 70 | learned_embeds_dict = {args.placeholder_token: learned_embeds.detach().cpu()} 71 | torch.save(learned_embeds_dict, save_path) 72 | 73 | 74 | def parse_args(): 75 | parser = argparse.ArgumentParser(description="Simple example of a training script.") 76 | parser.add_argument( 77 | "--save_steps", 78 | type=int, 79 | default=500, 80 | help="Save learned_embeds.bin every X updates steps.", 81 | ) 82 | parser.add_argument( 83 | "--only_save_embeds", 84 | action="store_true", 85 | default=False, 86 | help="Save only the embeddings for the new concept.", 87 | ) 88 | parser.add_argument( 89 | "--pretrained_model_name_or_path", 90 | type=str, 91 | default=None, 92 | required=True, 93 | help="Path to pretrained model or model identifier from huggingface.co/models.", 94 | ) 95 | parser.add_argument( 96 | "--revision", 97 | type=str, 98 | default=None, 99 | required=False, 100 | help="Revision of pretrained model identifier from huggingface.co/models.", 101 | ) 102 | parser.add_argument( 103 | "--tokenizer_name", 104 | type=str, 105 | default=None, 106 | help="Pretrained tokenizer name or path if not the same as model_name", 107 | ) 108 | parser.add_argument( 109 | "--train_data_dir", type=str, default=None, required=True, help="A folder containing the training data." 110 | ) 111 | parser.add_argument( 112 | "--placeholder_token", 113 | type=str, 114 | default=None, 115 | required=True, 116 | help="A token to use as a placeholder for the concept.", 117 | ) 118 | parser.add_argument( 119 | "--leap_model_path", 120 | type=str, 121 | required=True, 122 | help="The path to the current LEAP model you want to use.", 123 | ) 124 | parser.add_argument("--learnable_property", type=str, default="object", help="Choose between 'object' and 'style'") 125 | parser.add_argument("--repeats", type=int, default=100, help="How many times to repeat the training data.") 126 | parser.add_argument( 127 | "--output_dir", 128 | type=str, 129 | default="text-inversion-model", 130 | help="The output directory where the model predictions and checkpoints will be written.", 131 | ) 132 | parser.add_argument("--seed", type=int, default=None, help="A seed for reproducible training.") 133 | parser.add_argument( 134 | "--resolution", 135 | type=int, 136 | default=512, 137 | help=( 138 | "The resolution for input images, all the images in the train/validation dataset will be resized to this" 139 | " resolution" 140 | ), 141 | ) 142 | parser.add_argument( 143 | "--center_crop", action="store_true", help="Whether to center crop images before resizing to resolution" 144 | ) 145 | parser.add_argument( 146 | "--train_batch_size", type=int, default=1, help="Batch size (per device) for the training dataloader." 147 | ) 148 | parser.add_argument("--num_train_epochs", type=int, default=100) 149 | parser.add_argument( 150 | "--max_train_steps", 151 | type=int, 152 | default=100, 153 | help="Total number of training steps to perform. 
If provided, overrides num_train_epochs.", 154 | ) 155 | parser.add_argument( 156 | "--gradient_accumulation_steps", 157 | type=int, 158 | default=4, 159 | help="Number of updates steps to accumulate before performing a backward/update pass.", 160 | ) 161 | parser.add_argument( 162 | "--gradient_checkpointing", 163 | action="store_true", 164 | help="Whether or not to use gradient checkpointing to save memory at the expense of slower backward pass.", 165 | ) 166 | parser.add_argument( 167 | "--learning_rate", 168 | type=float, 169 | default=2e-3, 170 | help="Initial learning rate (after the potential warmup period) to use.", 171 | ) 172 | parser.add_argument( 173 | "--scale_lr", 174 | action="store_true", 175 | default=True, 176 | help="Scale the learning rate by the number of GPUs, gradient accumulation steps, and batch size.", 177 | ) 178 | parser.add_argument( 179 | "--lr_scheduler", 180 | type=str, 181 | default="constant_with_warmup", 182 | help=( 183 | 'The scheduler type to use. Choose between ["linear", "cosine", "cosine_with_restarts", "polynomial",' 184 | ' "constant", "constant_with_warmup"]' 185 | ), 186 | ) 187 | parser.add_argument( 188 | "--lr_warmup_steps", type=int, default=20, help="Number of steps for the warmup in the lr scheduler." 189 | ) 190 | parser.add_argument("--adam_beta1", type=float, default=0.9, help="The beta1 parameter for the Adam optimizer.") 191 | parser.add_argument("--adam_beta2", type=float, default=0.999, help="The beta2 parameter for the Adam optimizer.") 192 | parser.add_argument("--adam_weight_decay", type=float, default=1e-2, help="Weight decay to use.") 193 | parser.add_argument("--adam_epsilon", type=float, default=1e-08, help="Epsilon value for the Adam optimizer") 194 | parser.add_argument("--push_to_hub", action="store_true", help="Whether or not to push the model to the Hub.") 195 | parser.add_argument("--hub_token", type=str, default=None, help="The token to use to push to the Model Hub.") 196 | parser.add_argument( 197 | "--hub_model_id", 198 | type=str, 199 | default=None, 200 | help="The name of the repository to keep in sync with the local `output_dir`.", 201 | ) 202 | parser.add_argument( 203 | "--logging_dir", 204 | type=str, 205 | default="logs", 206 | help=( 207 | "[TensorBoard](https://www.tensorflow.org/tensorboard) log directory. Will default to" 208 | " *output_dir/runs/**CURRENT_DATETIME_HOSTNAME***." 209 | ), 210 | ) 211 | parser.add_argument( 212 | "--mixed_precision", 213 | type=str, 214 | default="no", 215 | choices=["no", "fp16", "bf16"], 216 | help=( 217 | "Whether to use mixed precision. Choose" 218 | "between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >= 1.10." 219 | "and an Nvidia Ampere GPU." 220 | ), 221 | ) 222 | parser.add_argument( 223 | "--allow_tf32", 224 | action="store_true", 225 | help=( 226 | "Whether or not to allow TF32 on Ampere GPUs. Can be used to speed up training. For more information, see" 227 | " https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices" 228 | ), 229 | ) 230 | parser.add_argument( 231 | "--report_to", 232 | type=str, 233 | default="tensorboard", 234 | help=( 235 | 'The integration to report the results and logs to. Supported platforms are `"tensorboard"`' 236 | ' (default), `"wandb"` and `"comet_ml"`. Use `"all"` to report to all integrations.' 
237 | ), 238 | ) 239 | parser.add_argument("--local_rank", type=int, default=-1, help="For distributed training: local_rank") 240 | parser.add_argument( 241 | "--checkpointing_steps", 242 | type=int, 243 | default=500, 244 | help=( 245 | "Save a checkpoint of the training state every X updates. These checkpoints are only suitable for resuming" 246 | " training using `--resume_from_checkpoint`." 247 | ), 248 | ) 249 | parser.add_argument( 250 | "--resume_from_checkpoint", 251 | type=str, 252 | default=None, 253 | help=( 254 | "Whether training should be resumed from a previous checkpoint. Use a path saved by" 255 | ' `--checkpointing_steps`, or `"latest"` to automatically select the last available checkpoint.' 256 | ), 257 | ) 258 | parser.add_argument( 259 | "--enable_xformers_memory_efficient_attention", action="store_true", help="Whether or not to use xformers." 260 | ) 261 | 262 | args = parser.parse_args() 263 | env_local_rank = int(os.environ.get("LOCAL_RANK", -1)) 264 | if env_local_rank != -1 and env_local_rank != args.local_rank: 265 | args.local_rank = env_local_rank 266 | 267 | if args.train_data_dir is None: 268 | raise ValueError("You must specify a train data directory.") 269 | 270 | return args 271 | 272 | 273 | imagenet_templates_small = [ 274 | "a photo of a {}", 275 | "a rendering of a {}", 276 | "a cropped photo of the {}", 277 | "the photo of a {}", 278 | "a photo of a clean {}", 279 | "a photo of a dirty {}", 280 | "a dark photo of the {}", 281 | "a photo of my {}", 282 | "a photo of the cool {}", 283 | "a close-up photo of a {}", 284 | "a bright photo of the {}", 285 | "a cropped photo of a {}", 286 | "a photo of the {}", 287 | "a good photo of the {}", 288 | "a photo of one {}", 289 | "a close-up photo of the {}", 290 | "a rendition of the {}", 291 | "a photo of the clean {}", 292 | "a rendition of a {}", 293 | "a photo of a nice {}", 294 | "a good photo of a {}", 295 | "a photo of the nice {}", 296 | "a photo of the small {}", 297 | "a photo of the weird {}", 298 | "a photo of the large {}", 299 | "a photo of a cool {}", 300 | "a photo of a small {}", 301 | ] 302 | 303 | imagenet_style_templates_small = [ 304 | "a painting in the style of {}", 305 | "a rendering in the style of {}", 306 | "a cropped painting in the style of {}", 307 | "the painting in the style of {}", 308 | "a clean painting in the style of {}", 309 | "a dirty painting in the style of {}", 310 | "a dark painting in the style of {}", 311 | "a picture in the style of {}", 312 | "a cool painting in the style of {}", 313 | "a close-up painting in the style of {}", 314 | "a bright painting in the style of {}", 315 | "a cropped painting in the style of {}", 316 | "a good painting in the style of {}", 317 | "a close-up painting in the style of {}", 318 | "a rendition in the style of {}", 319 | "a nice painting in the style of {}", 320 | "a small painting in the style of {}", 321 | "a weird painting in the style of {}", 322 | "a large painting in the style of {}", 323 | ] 324 | 325 | 326 | class TextualInversionDataset(Dataset): 327 | def __init__( 328 | self, 329 | data_root, 330 | tokenizer, 331 | learnable_property="object", # [object, style] 332 | size=512, 333 | repeats=100, 334 | interpolation="bicubic", 335 | flip_p=0.5, 336 | set="train", 337 | placeholder_token="*", 338 | center_crop=False, 339 | ): 340 | self.data_root = data_root 341 | self.tokenizer = tokenizer 342 | self.learnable_property = learnable_property 343 | self.size = size 344 | self.placeholder_token = placeholder_token 345 | 
self.center_crop = center_crop 346 | self.flip_p = flip_p 347 | 348 | self.image_paths = [os.path.join(self.data_root, file_path) for file_path in os.listdir(self.data_root)] 349 | 350 | self.num_images = len(self.image_paths) 351 | self._length = self.num_images 352 | 353 | if set == "train": 354 | self._length = self.num_images * repeats 355 | 356 | self.interpolation = { 357 | "linear": PIL_INTERPOLATION["linear"], 358 | "bilinear": PIL_INTERPOLATION["bilinear"], 359 | "bicubic": PIL_INTERPOLATION["bicubic"], 360 | "lanczos": PIL_INTERPOLATION["lanczos"], 361 | }[interpolation] 362 | 363 | self.templates = imagenet_style_templates_small if learnable_property == "style" else imagenet_templates_small 364 | self.flip_transform = transforms.RandomHorizontalFlip(p=self.flip_p) 365 | 366 | def __len__(self): 367 | return self._length 368 | 369 | def __getitem__(self, i): 370 | example = {} 371 | image = Image.open(self.image_paths[i % self.num_images]) 372 | image = ImageOps.exif_transpose(image) 373 | 374 | if not image.mode == "RGB": 375 | image = image.convert("RGB") 376 | 377 | placeholder_string = self.placeholder_token 378 | text = random.choice(self.templates).format(placeholder_string) 379 | 380 | example["input_ids"] = self.tokenizer( 381 | text, 382 | padding="max_length", 383 | truncation=True, 384 | max_length=self.tokenizer.model_max_length, 385 | return_tensors="pt", 386 | ).input_ids[0] 387 | 388 | # default to score-sde preprocessing 389 | img = np.array(image).astype(np.uint8) 390 | 391 | if self.center_crop: 392 | crop = min(img.shape[0], img.shape[1]) 393 | h, w, = ( 394 | img.shape[0], 395 | img.shape[1], 396 | ) 397 | img = img[(h - crop) // 2 : (h + crop) // 2, (w - crop) // 2 : (w + crop) // 2] 398 | 399 | image = Image.fromarray(img) 400 | image = image.resize((self.size, self.size), resample=self.interpolation) 401 | 402 | image = self.flip_transform(image) 403 | image = np.array(image).astype(np.uint8) 404 | image = (image / 127.5 - 1.0).astype(np.float32) 405 | 406 | example["pixel_values"] = torch.from_numpy(image).permute(2, 0, 1) 407 | return example 408 | 409 | 410 | def get_full_repo_name(model_id: str, organization: Optional[str] = None, token: Optional[str] = None): 411 | if token is None: 412 | token = HfFolder.get_token() 413 | if organization is None: 414 | username = whoami(token)["name"] 415 | return f"{username}/{model_id}" 416 | else: 417 | return f"{organization}/{model_id}" 418 | 419 | @torch.no_grad() 420 | def boost_embed(leap, images_folder): 421 | def repeat_array_to_length(arr, length): 422 | while len(arr) < length: 423 | arr = arr * 2 424 | return arr[:length] 425 | 426 | def load_images(images_path): 427 | image_names = os.listdir(images_path) 428 | random.shuffle(image_names) 429 | image_names = repeat_array_to_length(image_names, 4) 430 | images = None 431 | pred_transforms = transforms.Compose( 432 | [ 433 | iaa.Resize({"shorter-side": (128, 256), "longer-side": "keep-aspect-ratio"}).augment_image, 434 | iaa.CropToFixedSize(width=128, height=128).augment_image, 435 | transforms.ToTensor(), 436 | transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)) 437 | ] 438 | ) 439 | for image_name in image_names: 440 | image = Image.open(os.path.join(images_path, image_name)).convert("RGB") 441 | image = ImageOps.exif_transpose(image) 442 | image = pred_transforms(np.array(image)).unsqueeze(0) 443 | if images is None: 444 | images = image 445 | else: 446 | images = torch.cat((images, image), 0) 447 | return images 448 | 449 | images = 
load_images(images_folder) 450 | # Simulate single item batch 451 | images = images.unsqueeze(0) 452 | embed_model = leap(images) 453 | embed_model = embed_model.squeeze() 454 | return embed_model 455 | 456 | def main(): 457 | args = parse_args() 458 | logging_dir = os.path.join(args.output_dir, args.logging_dir) 459 | 460 | accelerator = Accelerator( 461 | gradient_accumulation_steps=args.gradient_accumulation_steps, 462 | mixed_precision=args.mixed_precision, 463 | log_with=args.report_to, 464 | logging_dir=logging_dir, 465 | ) 466 | 467 | # Make one log on every process with the configuration for debugging. 468 | logging.basicConfig( 469 | format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", 470 | datefmt="%m/%d/%Y %H:%M:%S", 471 | level=logging.INFO, 472 | ) 473 | logger.info(accelerator.state, main_process_only=False) 474 | if accelerator.is_local_main_process: 475 | datasets.utils.logging.set_verbosity_warning() 476 | transformers.utils.logging.set_verbosity_warning() 477 | diffusers.utils.logging.set_verbosity_info() 478 | else: 479 | datasets.utils.logging.set_verbosity_error() 480 | transformers.utils.logging.set_verbosity_error() 481 | diffusers.utils.logging.set_verbosity_error() 482 | 483 | # If passed along, set the training seed now. 484 | if args.seed is not None: 485 | set_seed(args.seed) 486 | 487 | # Handle the repository creation 488 | if accelerator.is_main_process: 489 | if args.push_to_hub: 490 | if args.hub_model_id is None: 491 | repo_name = get_full_repo_name(Path(args.output_dir).name, token=args.hub_token) 492 | else: 493 | repo_name = args.hub_model_id 494 | repo = Repository(args.output_dir, clone_from=repo_name) 495 | 496 | with open(os.path.join(args.output_dir, ".gitignore"), "w+") as gitignore: 497 | if "step_*" not in gitignore: 498 | gitignore.write("step_*\n") 499 | if "epoch_*" not in gitignore: 500 | gitignore.write("epoch_*\n") 501 | elif args.output_dir is not None: 502 | os.makedirs(args.output_dir, exist_ok=True) 503 | 504 | # Load tokenizer 505 | if args.tokenizer_name: 506 | tokenizer = CLIPTokenizer.from_pretrained(args.tokenizer_name) 507 | elif args.pretrained_model_name_or_path: 508 | tokenizer = CLIPTokenizer.from_pretrained(args.pretrained_model_name_or_path, subfolder="tokenizer") 509 | 510 | # Load scheduler and models 511 | noise_scheduler = DDPMScheduler.from_pretrained(args.pretrained_model_name_or_path, subfolder="scheduler") 512 | text_encoder = CLIPTextModel.from_pretrained( 513 | args.pretrained_model_name_or_path, subfolder="text_encoder", revision=args.revision 514 | ) 515 | vae = AutoencoderKL.from_pretrained(args.pretrained_model_name_or_path, subfolder="vae", revision=args.revision) 516 | unet = UNet2DConditionModel.from_pretrained( 517 | args.pretrained_model_name_or_path, subfolder="unet", revision=args.revision 518 | ) 519 | 520 | # Add the placeholder token in tokenizer 521 | num_added_tokens = tokenizer.add_tokens(args.placeholder_token) 522 | if num_added_tokens == 0: 523 | raise ValueError( 524 | f"The tokenizer already contains the token {args.placeholder_token}. Please pass a different" 525 | " `placeholder_token` that is not already in the tokenizer." 
526 | ) 527 | 528 | placeholder_token_id = tokenizer.convert_tokens_to_ids(args.placeholder_token) 529 | 530 | # Resize the token embeddings as we are adding new special tokens to the tokenizer 531 | text_encoder.resize_token_embeddings(len(tokenizer)) 532 | 533 | # Initialise the newly added placeholder token with the embeddings of the initializer token 534 | token_embeds = text_encoder.get_input_embeddings().weight.data 535 | 536 | # Loading LEAP from checkpoint 537 | leap = leap_sd.LM.load_from_checkpoint(args.leap_model_path) 538 | leap = leap.to('cpu') 539 | leap.eval() 540 | 541 | boosted_embed = boost_embed(leap, args.train_data_dir) 542 | token_embeds[placeholder_token_id] = boosted_embed 543 | print(f"Successfully boosted embed to {boosted_embed}") 544 | 545 | # Freeze vae and unet 546 | vae.requires_grad_(False) 547 | unet.requires_grad_(False) 548 | # Freeze all parameters except for the token embeddings in text encoder 549 | text_encoder.text_model.encoder.requires_grad_(False) 550 | text_encoder.text_model.final_layer_norm.requires_grad_(False) 551 | text_encoder.text_model.embeddings.position_embedding.requires_grad_(False) 552 | 553 | if args.gradient_checkpointing: 554 | # Keep unet in train mode if we are using gradient checkpointing to save memory. 555 | # The dropout cannot be != 0 so it doesn't matter if we are in eval or train mode. 556 | unet.train() 557 | text_encoder.gradient_checkpointing_enable() 558 | unet.enable_gradient_checkpointing() 559 | 560 | if args.enable_xformers_memory_efficient_attention: 561 | if is_xformers_available(): 562 | unet.enable_xformers_memory_efficient_attention() 563 | else: 564 | raise ValueError("xformers is not available. Make sure it is installed correctly") 565 | 566 | # Enable TF32 for faster training on Ampere GPUs, 567 | # cf https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices 568 | if args.allow_tf32: 569 | torch.backends.cuda.matmul.allow_tf32 = True 570 | 571 | if args.scale_lr: 572 | args.learning_rate = ( 573 | args.learning_rate * args.gradient_accumulation_steps * args.train_batch_size * accelerator.num_processes 574 | ) 575 | 576 | # Initialize the optimizer 577 | optimizer = torch.optim.AdamW( 578 | text_encoder.get_input_embeddings().parameters(), # only optimize the embeddings 579 | lr=args.learning_rate, 580 | betas=(args.adam_beta1, args.adam_beta2), 581 | weight_decay=args.adam_weight_decay, 582 | eps=args.adam_epsilon, 583 | ) 584 | 585 | # Dataset and DataLoaders creation: 586 | train_dataset = TextualInversionDataset( 587 | data_root=args.train_data_dir, 588 | tokenizer=tokenizer, 589 | size=args.resolution, 590 | placeholder_token=args.placeholder_token, 591 | repeats=args.repeats, 592 | learnable_property=args.learnable_property, 593 | center_crop=args.center_crop, 594 | set="train", 595 | ) 596 | train_dataloader = torch.utils.data.DataLoader(train_dataset, batch_size=args.train_batch_size, shuffle=True) 597 | 598 | # Scheduler and math around the number of training steps. 
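    # If args.max_train_steps is None, it is derived from num_train_epochs here and
    # recomputed after accelerator.prepare() below, since the prepared dataloader's length may differ.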
599 | overrode_max_train_steps = False 600 | num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps) 601 | if args.max_train_steps is None: 602 | args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch 603 | overrode_max_train_steps = True 604 | 605 | lr_scheduler = get_scheduler( 606 | args.lr_scheduler, 607 | optimizer=optimizer, 608 | num_warmup_steps=args.lr_warmup_steps * args.gradient_accumulation_steps, 609 | num_training_steps=args.max_train_steps * args.gradient_accumulation_steps, 610 | ) 611 | 612 | # Prepare everything with our `accelerator`. 613 | text_encoder, optimizer, train_dataloader, lr_scheduler = accelerator.prepare( 614 | text_encoder, optimizer, train_dataloader, lr_scheduler 615 | ) 616 | 617 | # For mixed precision training we cast the text_encoder and vae weights to half-precision 618 | # as these models are only used for inference, keeping weights in full precision is not required. 619 | weight_dtype = torch.float32 620 | if accelerator.mixed_precision == "fp16": 621 | weight_dtype = torch.float16 622 | elif accelerator.mixed_precision == "bf16": 623 | weight_dtype = torch.bfloat16 624 | 625 | # Move vae and unet to device and cast to weight_dtype 626 | unet.to(accelerator.device, dtype=weight_dtype) 627 | vae.to(accelerator.device, dtype=weight_dtype) 628 | 629 | # We need to recalculate our total training steps as the size of the training dataloader may have changed. 630 | num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps) 631 | if overrode_max_train_steps: 632 | args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch 633 | # Afterwards we recalculate our number of training epochs 634 | args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch) 635 | 636 | # We need to initialize the trackers we use, and also store our configuration. 637 | # The trackers initializes automatically on the main process. 638 | if accelerator.is_main_process: 639 | accelerator.init_trackers("textual_inversion", config=vars(args)) 640 | 641 | # Train! 642 | total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps 643 | 644 | logger.info("***** Running training *****") 645 | logger.info(f" Num examples = {len(train_dataset)}") 646 | logger.info(f" Num Epochs = {args.num_train_epochs}") 647 | logger.info(f" Instantaneous batch size per device = {args.train_batch_size}") 648 | logger.info(f" Total train batch size (w. 
parallel, distributed & accumulation) = {total_batch_size}") 649 | logger.info(f" Gradient Accumulation steps = {args.gradient_accumulation_steps}") 650 | logger.info(f" Total optimization steps = {args.max_train_steps}") 651 | global_step = 0 652 | first_epoch = 0 653 | 654 | # Potentially load in the weights and states from a previous save 655 | if args.resume_from_checkpoint: 656 | if args.resume_from_checkpoint != "latest": 657 | path = os.path.basename(args.resume_from_checkpoint) 658 | else: 659 | # Get the most recent checkpoint 660 | dirs = os.listdir(args.output_dir) 661 | dirs = [d for d in dirs if d.startswith("checkpoint")] 662 | dirs = sorted(dirs, key=lambda x: int(x.split("-")[1])) 663 | path = dirs[-1] 664 | accelerator.print(f"Resuming from checkpoint {path}") 665 | accelerator.load_state(os.path.join(args.output_dir, path)) 666 | global_step = int(path.split("-")[1]) 667 | 668 | resume_global_step = global_step * args.gradient_accumulation_steps 669 | first_epoch = resume_global_step // num_update_steps_per_epoch 670 | resume_step = resume_global_step % num_update_steps_per_epoch 671 | 672 | # Only show the progress bar once on each machine. 673 | progress_bar = tqdm(range(global_step, args.max_train_steps), disable=not accelerator.is_local_main_process) 674 | progress_bar.set_description("Steps") 675 | 676 | # keep original embeddings as reference 677 | orig_embeds_params = accelerator.unwrap_model(text_encoder).get_input_embeddings().weight.data.clone() 678 | 679 | for epoch in range(first_epoch, args.num_train_epochs): 680 | text_encoder.train() 681 | for step, batch in enumerate(train_dataloader): 682 | # Skip steps until we reach the resumed step 683 | if args.resume_from_checkpoint and epoch == first_epoch and step < resume_step: 684 | if step % args.gradient_accumulation_steps == 0: 685 | progress_bar.update(1) 686 | continue 687 | 688 | with accelerator.accumulate(text_encoder): 689 | # Convert images to latent space 690 | latents = vae.encode(batch["pixel_values"].to(dtype=weight_dtype)).latent_dist.sample().detach() 691 | latents = latents * 0.18215 692 | 693 | # Sample noise that we'll add to the latents 694 | noise = torch.randn_like(latents) 695 | bsz = latents.shape[0] 696 | # Sample a random timestep for each image 697 | timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, (bsz,), device=latents.device) 698 | timesteps = timesteps.long() 699 | 700 | # Add noise to the latents according to the noise magnitude at each timestep 701 | # (this is the forward diffusion process) 702 | noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps) 703 | 704 | # Get the text embedding for conditioning 705 | encoder_hidden_states = text_encoder(batch["input_ids"])[0].to(dtype=weight_dtype) 706 | 707 | # Predict the noise residual 708 | model_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample 709 | 710 | # Get the target for loss depending on the prediction type 711 | if noise_scheduler.config.prediction_type == "epsilon": 712 | target = noise 713 | elif noise_scheduler.config.prediction_type == "v_prediction": 714 | target = noise_scheduler.get_velocity(latents, noise, timesteps) 715 | else: 716 | raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}") 717 | 718 | loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean") 719 | 720 | accelerator.backward(loss) 721 | 722 | optimizer.step() 723 | lr_scheduler.step() 724 | optimizer.zero_grad() 725 | 726 | # Let's make sure we don't 
update any embedding weights besides the newly added token 727 | index_no_updates = torch.arange(len(tokenizer)) != placeholder_token_id 728 | with torch.no_grad(): 729 | accelerator.unwrap_model(text_encoder).get_input_embeddings().weight[ 730 | index_no_updates 731 | ] = orig_embeds_params[index_no_updates] 732 | 733 | # Checks if the accelerator has performed an optimization step behind the scenes 734 | if accelerator.sync_gradients: 735 | progress_bar.update(1) 736 | global_step += 1 737 | if global_step % args.save_steps == 0: 738 | save_path = os.path.join(args.output_dir, f"learned_embeds-steps-{global_step}.bin") 739 | save_progress(text_encoder, placeholder_token_id, accelerator, args, save_path) 740 | 741 | if global_step % args.checkpointing_steps == 0: 742 | if accelerator.is_main_process: 743 | save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}") 744 | accelerator.save_state(save_path) 745 | logger.info(f"Saved state to {save_path}") 746 | 747 | logs = {"loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0]} 748 | progress_bar.set_postfix(**logs) 749 | accelerator.log(logs, step=global_step) 750 | 751 | if global_step >= args.max_train_steps: 752 | break 753 | 754 | # Create the pipeline using using the trained modules and save it. 755 | accelerator.wait_for_everyone() 756 | if accelerator.is_main_process: 757 | if args.push_to_hub and args.only_save_embeds: 758 | logger.warn("Enabling full model saving because --push_to_hub=True was specified.") 759 | save_full_model = True 760 | else: 761 | save_full_model = not args.only_save_embeds 762 | if save_full_model: 763 | pipeline = StableDiffusionPipeline.from_pretrained( 764 | args.pretrained_model_name_or_path, 765 | text_encoder=accelerator.unwrap_model(text_encoder), 766 | vae=vae, 767 | unet=unet, 768 | tokenizer=tokenizer, 769 | ) 770 | pipeline.save_pretrained(args.output_dir) 771 | # Save the newly trained embeddings 772 | save_path = os.path.join(args.output_dir, "learned_embeds.bin") 773 | save_progress(text_encoder, placeholder_token_id, accelerator, args, save_path) 774 | 775 | if args.push_to_hub: 776 | repo.push_to_hub(commit_message="End of training", blocking=False, auto_lfs_prune=True) 777 | 778 | accelerator.end_training() 779 | 780 | 781 | if __name__ == "__main__": 782 | main() 783 | -------------------------------------------------------------------------------- /leap_sd/__init__.py: -------------------------------------------------------------------------------- 1 | from .module import LM -------------------------------------------------------------------------------- /leap_sd/module.py: -------------------------------------------------------------------------------- 1 | from .utils import linear_warmup_cosine_decay 2 | import pytorch_lightning as pl 3 | import torch 4 | import torchvision 5 | import random 6 | from itertools import chain 7 | from torch import nn, einsum 8 | import torch.nn.functional as F 9 | 10 | class LM(pl.LightningModule): 11 | def __init__( 12 | self, 13 | steps, 14 | input_shape, 15 | min_weight = 0, 16 | max_weight = 0, 17 | learning_rate=1e-4, 18 | weight_decay=0.0001, 19 | dropout_p=0.0, 20 | linear_warmup_ratio=0.01, 21 | latent_dim_size=1024, 22 | **_ 23 | ): 24 | super().__init__() 25 | self.save_hyperparameters() 26 | self.min_weight = min_weight 27 | self.max_weight = max_weight 28 | self.latent_dim_size = latent_dim_size 29 | self.learning_rate = learning_rate 30 | self.weight_decay = weight_decay 31 | self.steps = steps 32 | 
self.linear_warmup_ratio = linear_warmup_ratio 33 | self.criterion = torch.nn.L1Loss() 34 | self.init_model(input_shape, dropout_p) 35 | 36 | def init_model(self, input_shape, dropout_p): 37 | feature_layers = [ 38 | nn.Sequential( 39 | nn.Conv2d(3, 32, kernel_size=3, stride=2, padding=1), 40 | nn.LeakyReLU(), 41 | nn.MaxPool2d(kernel_size=2, stride=2), 42 | nn.Dropout(p=dropout_p) 43 | ), 44 | nn.Sequential( 45 | nn.Conv2d(32, 64, kernel_size=3, stride=2, padding=1), 46 | nn.LeakyReLU(), 47 | nn.MaxPool2d(kernel_size=2, stride=2), 48 | nn.Dropout(p=dropout_p) 49 | ), 50 | nn.Sequential( 51 | nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=1), 52 | nn.LeakyReLU(), 53 | nn.MaxPool2d(kernel_size=2, stride=2), 54 | nn.Dropout(p=dropout_p) 55 | ) 56 | ] 57 | self.features = nn.Sequential(*feature_layers) 58 | n_sizes = self._get_conv_output(input_shape) 59 | output_layers = [ 60 | nn.Linear(n_sizes, self.latent_dim_size), 61 | nn.LeakyReLU(), 62 | nn.Dropout(p=dropout_p), 63 | nn.Linear(self.latent_dim_size, self.latent_dim_size) 64 | ] 65 | self.output = nn.Sequential(*output_layers) 66 | self.forget_leveler = nn.Linear(n_sizes, 1) 67 | 68 | # returns the size of the output tensor going into Linear layer from the conv block. 69 | def _get_conv_output(self, shape): 70 | batch_size = 1 71 | input = torch.autograd.Variable(torch.rand(batch_size, *shape)) 72 | 73 | output_feat = self.features(input) 74 | n_size = output_feat.data.view(batch_size, -1).size(1) 75 | return n_size 76 | 77 | # will be used during inference 78 | def forward(self, x): 79 | images_len = x.shape[1] 80 | xf = None 81 | for i in range(images_len): 82 | image_selection = x[:, i, ...] 83 | if xf is None: 84 | xf = self.features(image_selection) 85 | else: 86 | xf += self.features(image_selection) 87 | xf = xf / images_len 88 | xf = xf.view(xf.size(0), -1) 89 | xfo = self.forget_leveler(xf) 90 | xf[xf < xfo] = 0 91 | xf = self.output(xf) 92 | xf = self.denormalize_embed(xf) 93 | return xf 94 | 95 | def configure_optimizers(self): 96 | optimizer = torch.optim.AdamW(self.parameters(), lr=self.learning_rate, betas=(0.9, 0.999), eps=1e-08, weight_decay=0.0) 97 | warmup_steps = int(self.linear_warmup_ratio * self.steps) 98 | scheduler = { 99 | "scheduler": linear_warmup_cosine_decay(optimizer, warmup_steps, self.steps), 100 | "interval": "step", 101 | } 102 | return [optimizer], [scheduler] 103 | 104 | def denormalize_embed(self, embed): 105 | embed = embed * (abs(self.min_weight) + self.max_weight) 106 | embed = embed - abs(self.min_weight) 107 | return embed 108 | 109 | def shot(self, batch, name, image_logging = False): 110 | image_grid, target = batch 111 | pred = self.forward(image_grid) 112 | loss = self.criterion(pred, target) 113 | self.log(f"{name}_loss", loss) 114 | return loss 115 | 116 | def training_step(self, batch, batch_idx): 117 | return self.shot(batch, "train", image_logging = True) 118 | 119 | def validation_step(self, batch, batch_idx): 120 | return self.shot(batch, "val") -------------------------------------------------------------------------------- /leap_sd/utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import math 3 | from einops import rearrange 4 | import random 5 | 6 | def linear_warmup_cosine_decay(optimizer, warmup_steps, total_steps): 7 | """ 8 | Linear warmup for warmup_steps, with cosine annealing to 0 at total_steps 9 | """ 10 | 11 | def fn(step): 12 | if step < warmup_steps: 13 | return float(step) / float(max(1, warmup_steps)) 
14 | 15 | progress = float(step - warmup_steps) / float( 16 | max(1, total_steps - warmup_steps) 17 | ) 18 | return 0.5 * (1.0 + math.cos(math.pi * progress)) 19 | 20 | return torch.optim.lr_scheduler.LambdaLR( 21 | optimizer, fn 22 | ) 23 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | import os 2 | from setuptools import setup 3 | 4 | # Utility function to read the README file. 5 | # Used for the long_description. It's nice, because now 1) we have a top level 6 | # README file and 2) it's easier to type in the README file than to put a raw 7 | # string in below ... 8 | def read(fname): 9 | return open(os.path.join(os.path.dirname(__file__), fname)).read() 10 | 11 | setup( 12 | name = "leap_sd", 13 | version = "0.0.2", 14 | author = "Peter Willemsen", 15 | author_email = "peter@codebuffet.co", 16 | description = "Fast finetuning of Stable Diffusion using LEAP booster model.", 17 | license = "MIT", 18 | keywords = "finetuning training stable-diffusion huggingface", 19 | url = "https://github.com/peterwilli/sd-leap-booster", 20 | packages=['leap_sd'], 21 | long_description=read('README.md'), 22 | scripts=['bin/leap_textual_inversion'], 23 | install_requires=[ 24 | 'numpy', 25 | 'diffusers', 26 | 'transformers', 27 | 'datasets', 28 | 'torchvision', 29 | 'accelerate', 30 | 'pytorch_lightning', 31 | 'tensorboard' 32 | ], 33 | classifiers=[ 34 | "Development Status :: 3 - Alpha", 35 | "Topic :: Utilities", 36 | "License :: OSI Approved :: MIT License", 37 | ], 38 | package_data={'': ['leap_sd/model.ckpt']}, 39 | include_package_data=True 40 | ) 41 | -------------------------------------------------------------------------------- /training/README.md: -------------------------------------------------------------------------------- 1 | # Training LEAP 2 | 3 | If you want to use LEAP on a different Stable Diffusion version, or simply want to learn how it works, you will find the instructions below very useful. 4 | 5 | **Author's notes**: I'm using Linux (NixOS); while this could work on Windows, I'm not familiar with that operating system. Feel free to reach out for help. Commands are `written like this!`. 6 | 7 | ## Dataset creation 8 | 9 | LEAP uses a synthetic dataset: we extract all words from Stable Diffusion and generate "samples" that allow us to associate its images with the weights used to make them. 10 | 11 | - Clone this repository and `cd` to it. 12 | - Install leap_sd: `pip install -e .` 13 | - Run `python training/dataset_creator/sd_extractor.py` 14 | - To use a custom Stable Diffusion model, pass `--pretrained_model_name_or_path`, for example: `--pretrained_model_name_or_path=runwayml/stable-diffusion-v1-5` 15 | 16 | ## Train! 🐲 17 | 18 | - `cd` to the training directory of this repo 19 | - Check the size of your latent space! (for SD < 2.0 it's 768 and for >= 2.0 it's 1024) 20 | - Run the following training command: `python training/train.py --batch_size=10 --gpus=1 --max_epochs=250 --latent_dim_size=1024` 21 | 22 | ### Examples 23 | 24 | - Training for SD 1.5 25 | 26 | ```bash 27 | python training/dataset_creator/sd_extractor.py --pretrained_model_name_or_path=runwayml/stable-diffusion-v1-5 28 | python training/train.py --batch_size=10 --gpus=1 --max_epochs=250 --latent_dim_size=768 29 | ``` 30 | 31 | # Support, sponsorship and thanks 32 | 33 | Are you looking to make a positive impact and get some awesome perks in the process? 
**[Join me on Patreon!](https://www.patreon.com/emerald_show)** For just $3 per month, you can join our Patreon community and help a creative mind in the Netherlands bring their ideas to life. 34 | 35 | Not only will you get the satisfaction of supporting an individual's passions, but you'll also receive a 50% discount on any paid services that result from the projects you sponsor. Plus, as a Patreon member, you'll have exclusive voting rights on new features and the opportunity to shape the direction of future projects. Don't miss out on this chance to make a difference and get some amazing benefits in return. 36 | 37 | One of the things we intend on doing, is trying to make LEAP with Lora! -------------------------------------------------------------------------------- /training/dataset_creator/bip39.txt: -------------------------------------------------------------------------------- 1 | abandon 2 | ability 3 | able 4 | about 5 | above 6 | absent 7 | absorb 8 | abstract 9 | absurd 10 | abuse 11 | access 12 | accident 13 | account 14 | accuse 15 | achieve 16 | acid 17 | acoustic 18 | acquire 19 | across 20 | act 21 | action 22 | actor 23 | actress 24 | actual 25 | adapt 26 | add 27 | addict 28 | address 29 | adjust 30 | admit 31 | adult 32 | advance 33 | advice 34 | aerobic 35 | affair 36 | afford 37 | afraid 38 | again 39 | age 40 | agent 41 | agree 42 | ahead 43 | aim 44 | air 45 | airport 46 | aisle 47 | alarm 48 | album 49 | alcohol 50 | alert 51 | alien 52 | all 53 | alley 54 | allow 55 | almost 56 | alone 57 | alpha 58 | already 59 | also 60 | alter 61 | always 62 | amateur 63 | amazing 64 | among 65 | amount 66 | amused 67 | analyst 68 | anchor 69 | ancient 70 | anger 71 | angle 72 | angry 73 | animal 74 | ankle 75 | announce 76 | annual 77 | another 78 | answer 79 | antenna 80 | antique 81 | anxiety 82 | any 83 | apart 84 | apology 85 | appear 86 | apple 87 | approve 88 | april 89 | arch 90 | arctic 91 | area 92 | arena 93 | argue 94 | arm 95 | armed 96 | armor 97 | army 98 | around 99 | arrange 100 | arrest 101 | arrive 102 | arrow 103 | art 104 | artefact 105 | artist 106 | artwork 107 | ask 108 | aspect 109 | assault 110 | asset 111 | assist 112 | assume 113 | asthma 114 | athlete 115 | atom 116 | attack 117 | attend 118 | attitude 119 | attract 120 | auction 121 | audit 122 | august 123 | aunt 124 | author 125 | auto 126 | autumn 127 | average 128 | avocado 129 | avoid 130 | awake 131 | aware 132 | away 133 | awesome 134 | awful 135 | awkward 136 | axis 137 | baby 138 | bachelor 139 | bacon 140 | badge 141 | bag 142 | balance 143 | balcony 144 | ball 145 | bamboo 146 | banana 147 | banner 148 | bar 149 | barely 150 | bargain 151 | barrel 152 | base 153 | basic 154 | basket 155 | battle 156 | beach 157 | bean 158 | beauty 159 | because 160 | become 161 | beef 162 | before 163 | begin 164 | behave 165 | behind 166 | believe 167 | below 168 | belt 169 | bench 170 | benefit 171 | best 172 | betray 173 | better 174 | between 175 | beyond 176 | bicycle 177 | bid 178 | bike 179 | bind 180 | biology 181 | bird 182 | birth 183 | bitter 184 | black 185 | blade 186 | blame 187 | blanket 188 | blast 189 | bleak 190 | bless 191 | blind 192 | blood 193 | blossom 194 | blouse 195 | blue 196 | blur 197 | blush 198 | board 199 | boat 200 | body 201 | boil 202 | bomb 203 | bone 204 | bonus 205 | book 206 | boost 207 | border 208 | boring 209 | borrow 210 | boss 211 | bottom 212 | bounce 213 | box 214 | boy 215 | bracket 216 | brain 217 | brand 218 | brass 219 | brave 220 | bread 221 | breeze 
222 | brick 223 | bridge 224 | brief 225 | bright 226 | bring 227 | brisk 228 | broccoli 229 | broken 230 | bronze 231 | broom 232 | brother 233 | brown 234 | brush 235 | bubble 236 | buddy 237 | budget 238 | buffalo 239 | build 240 | bulb 241 | bulk 242 | bullet 243 | bundle 244 | bunker 245 | burden 246 | burger 247 | burst 248 | bus 249 | business 250 | busy 251 | butter 252 | buyer 253 | buzz 254 | cabbage 255 | cabin 256 | cable 257 | cactus 258 | cage 259 | cake 260 | call 261 | calm 262 | camera 263 | camp 264 | can 265 | canal 266 | cancel 267 | candy 268 | cannon 269 | canoe 270 | canvas 271 | canyon 272 | capable 273 | capital 274 | captain 275 | car 276 | carbon 277 | card 278 | cargo 279 | carpet 280 | carry 281 | cart 282 | case 283 | cash 284 | casino 285 | castle 286 | casual 287 | cat 288 | catalog 289 | catch 290 | category 291 | cattle 292 | caught 293 | cause 294 | caution 295 | cave 296 | ceiling 297 | celery 298 | cement 299 | census 300 | century 301 | cereal 302 | certain 303 | chair 304 | chalk 305 | champion 306 | change 307 | chaos 308 | chapter 309 | charge 310 | chase 311 | chat 312 | cheap 313 | check 314 | cheese 315 | chef 316 | cherry 317 | chest 318 | chicken 319 | chief 320 | child 321 | chimney 322 | choice 323 | choose 324 | chronic 325 | chuckle 326 | chunk 327 | churn 328 | cigar 329 | cinnamon 330 | circle 331 | citizen 332 | city 333 | civil 334 | claim 335 | clap 336 | clarify 337 | claw 338 | clay 339 | clean 340 | clerk 341 | clever 342 | click 343 | client 344 | cliff 345 | climb 346 | clinic 347 | clip 348 | clock 349 | clog 350 | close 351 | cloth 352 | cloud 353 | clown 354 | club 355 | clump 356 | cluster 357 | clutch 358 | coach 359 | coast 360 | coconut 361 | code 362 | coffee 363 | coil 364 | coin 365 | collect 366 | color 367 | column 368 | combine 369 | come 370 | comfort 371 | comic 372 | common 373 | company 374 | concert 375 | conduct 376 | confirm 377 | congress 378 | connect 379 | consider 380 | control 381 | convince 382 | cook 383 | cool 384 | copper 385 | copy 386 | coral 387 | core 388 | corn 389 | correct 390 | cost 391 | cotton 392 | couch 393 | country 394 | couple 395 | course 396 | cousin 397 | cover 398 | coyote 399 | crack 400 | cradle 401 | craft 402 | cram 403 | crane 404 | crash 405 | crater 406 | crawl 407 | crazy 408 | cream 409 | credit 410 | creek 411 | crew 412 | cricket 413 | crime 414 | crisp 415 | critic 416 | crop 417 | cross 418 | crouch 419 | crowd 420 | crucial 421 | cruel 422 | cruise 423 | crumble 424 | crunch 425 | crush 426 | cry 427 | crystal 428 | cube 429 | culture 430 | cup 431 | cupboard 432 | curious 433 | current 434 | curtain 435 | curve 436 | cushion 437 | custom 438 | cute 439 | cycle 440 | dad 441 | damage 442 | damp 443 | dance 444 | danger 445 | daring 446 | dash 447 | daughter 448 | dawn 449 | day 450 | deal 451 | debate 452 | debris 453 | decade 454 | december 455 | decide 456 | decline 457 | decorate 458 | decrease 459 | deer 460 | defense 461 | define 462 | defy 463 | degree 464 | delay 465 | deliver 466 | demand 467 | demise 468 | denial 469 | dentist 470 | deny 471 | depart 472 | depend 473 | deposit 474 | depth 475 | deputy 476 | derive 477 | describe 478 | desert 479 | design 480 | desk 481 | despair 482 | destroy 483 | detail 484 | detect 485 | develop 486 | device 487 | devote 488 | diagram 489 | dial 490 | diamond 491 | diary 492 | dice 493 | diesel 494 | diet 495 | differ 496 | digital 497 | dignity 498 | dilemma 499 | dinner 500 | dinosaur 501 | direct 502 | dirt 503 | disagree 
504 | discover 505 | disease 506 | dish 507 | dismiss 508 | disorder 509 | display 510 | distance 511 | divert 512 | divide 513 | divorce 514 | dizzy 515 | doctor 516 | document 517 | dog 518 | doll 519 | dolphin 520 | domain 521 | donate 522 | donkey 523 | donor 524 | door 525 | dose 526 | double 527 | dove 528 | draft 529 | dragon 530 | drama 531 | drastic 532 | draw 533 | dream 534 | dress 535 | drift 536 | drill 537 | drink 538 | drip 539 | drive 540 | drop 541 | drum 542 | dry 543 | duck 544 | dumb 545 | dune 546 | during 547 | dust 548 | dutch 549 | duty 550 | dwarf 551 | dynamic 552 | eager 553 | eagle 554 | early 555 | earn 556 | earth 557 | easily 558 | east 559 | easy 560 | echo 561 | ecology 562 | economy 563 | edge 564 | edit 565 | educate 566 | effort 567 | egg 568 | eight 569 | either 570 | elbow 571 | elder 572 | electric 573 | elegant 574 | element 575 | elephant 576 | elevator 577 | elite 578 | else 579 | embark 580 | embody 581 | embrace 582 | emerge 583 | emotion 584 | employ 585 | empower 586 | empty 587 | enable 588 | enact 589 | end 590 | endless 591 | endorse 592 | enemy 593 | energy 594 | enforce 595 | engage 596 | engine 597 | enhance 598 | enjoy 599 | enlist 600 | enough 601 | enrich 602 | enroll 603 | ensure 604 | enter 605 | entire 606 | entry 607 | envelope 608 | episode 609 | equal 610 | equip 611 | era 612 | erase 613 | erode 614 | erosion 615 | error 616 | erupt 617 | escape 618 | essay 619 | essence 620 | estate 621 | eternal 622 | ethics 623 | evidence 624 | evil 625 | evoke 626 | evolve 627 | exact 628 | example 629 | excess 630 | exchange 631 | excite 632 | exclude 633 | excuse 634 | execute 635 | exercise 636 | exhaust 637 | exhibit 638 | exile 639 | exist 640 | exit 641 | exotic 642 | expand 643 | expect 644 | expire 645 | explain 646 | expose 647 | express 648 | extend 649 | extra 650 | eye 651 | eyebrow 652 | fabric 653 | face 654 | faculty 655 | fade 656 | faint 657 | faith 658 | fall 659 | false 660 | fame 661 | family 662 | famous 663 | fan 664 | fancy 665 | fantasy 666 | farm 667 | fashion 668 | fat 669 | fatal 670 | father 671 | fatigue 672 | fault 673 | favorite 674 | feature 675 | february 676 | federal 677 | fee 678 | feed 679 | feel 680 | female 681 | fence 682 | festival 683 | fetch 684 | fever 685 | few 686 | fiber 687 | fiction 688 | field 689 | figure 690 | file 691 | film 692 | filter 693 | final 694 | find 695 | fine 696 | finger 697 | finish 698 | fire 699 | firm 700 | first 701 | fiscal 702 | fish 703 | fit 704 | fitness 705 | fix 706 | flag 707 | flame 708 | flash 709 | flat 710 | flavor 711 | flee 712 | flight 713 | flip 714 | float 715 | flock 716 | floor 717 | flower 718 | fluid 719 | flush 720 | fly 721 | foam 722 | focus 723 | fog 724 | foil 725 | fold 726 | follow 727 | food 728 | foot 729 | force 730 | forest 731 | forget 732 | fork 733 | fortune 734 | forum 735 | forward 736 | fossil 737 | foster 738 | found 739 | fox 740 | fragile 741 | frame 742 | frequent 743 | fresh 744 | friend 745 | fringe 746 | frog 747 | front 748 | frost 749 | frown 750 | frozen 751 | fruit 752 | fuel 753 | fun 754 | funny 755 | furnace 756 | fury 757 | future 758 | gadget 759 | gain 760 | galaxy 761 | gallery 762 | game 763 | gap 764 | garage 765 | garbage 766 | garden 767 | garlic 768 | garment 769 | gas 770 | gasp 771 | gate 772 | gather 773 | gauge 774 | gaze 775 | general 776 | genius 777 | genre 778 | gentle 779 | genuine 780 | gesture 781 | ghost 782 | giant 783 | gift 784 | giggle 785 | ginger 786 | giraffe 787 | girl 788 | give 789 | glad 
790 | glance 791 | glare 792 | glass 793 | glide 794 | glimpse 795 | globe 796 | gloom 797 | glory 798 | glove 799 | glow 800 | glue 801 | goat 802 | goddess 803 | gold 804 | good 805 | goose 806 | gorilla 807 | gospel 808 | gossip 809 | govern 810 | gown 811 | grab 812 | grace 813 | grain 814 | grant 815 | grape 816 | grass 817 | gravity 818 | great 819 | green 820 | grid 821 | grief 822 | grit 823 | grocery 824 | group 825 | grow 826 | grunt 827 | guard 828 | guess 829 | guide 830 | guilt 831 | guitar 832 | gun 833 | gym 834 | habit 835 | hair 836 | half 837 | hammer 838 | hamster 839 | hand 840 | happy 841 | harbor 842 | hard 843 | harsh 844 | harvest 845 | hat 846 | have 847 | hawk 848 | hazard 849 | head 850 | health 851 | heart 852 | heavy 853 | hedgehog 854 | height 855 | hello 856 | helmet 857 | help 858 | hen 859 | hero 860 | hidden 861 | high 862 | hill 863 | hint 864 | hip 865 | hire 866 | history 867 | hobby 868 | hockey 869 | hold 870 | hole 871 | holiday 872 | hollow 873 | home 874 | honey 875 | hood 876 | hope 877 | horn 878 | horror 879 | horse 880 | hospital 881 | host 882 | hotel 883 | hour 884 | hover 885 | hub 886 | huge 887 | human 888 | humble 889 | humor 890 | hundred 891 | hungry 892 | hunt 893 | hurdle 894 | hurry 895 | hurt 896 | husband 897 | hybrid 898 | ice 899 | icon 900 | idea 901 | identify 902 | idle 903 | ignore 904 | ill 905 | illegal 906 | illness 907 | image 908 | imitate 909 | immense 910 | immune 911 | impact 912 | impose 913 | improve 914 | impulse 915 | inch 916 | include 917 | income 918 | increase 919 | index 920 | indicate 921 | indoor 922 | industry 923 | infant 924 | inflict 925 | inform 926 | inhale 927 | inherit 928 | initial 929 | inject 930 | injury 931 | inmate 932 | inner 933 | innocent 934 | input 935 | inquiry 936 | insane 937 | insect 938 | inside 939 | inspire 940 | install 941 | intact 942 | interest 943 | into 944 | invest 945 | invite 946 | involve 947 | iron 948 | island 949 | isolate 950 | issue 951 | item 952 | ivory 953 | jacket 954 | jaguar 955 | jar 956 | jazz 957 | jealous 958 | jeans 959 | jelly 960 | jewel 961 | job 962 | join 963 | joke 964 | journey 965 | joy 966 | judge 967 | juice 968 | jump 969 | jungle 970 | junior 971 | junk 972 | just 973 | kangaroo 974 | keen 975 | keep 976 | ketchup 977 | key 978 | kick 979 | kid 980 | kidney 981 | kind 982 | kingdom 983 | kiss 984 | kit 985 | kitchen 986 | kite 987 | kitten 988 | kiwi 989 | knee 990 | knife 991 | knock 992 | know 993 | lab 994 | label 995 | labor 996 | ladder 997 | lady 998 | lake 999 | lamp 1000 | language 1001 | laptop 1002 | large 1003 | later 1004 | latin 1005 | laugh 1006 | laundry 1007 | lava 1008 | law 1009 | lawn 1010 | lawsuit 1011 | layer 1012 | lazy 1013 | leader 1014 | leaf 1015 | learn 1016 | leave 1017 | lecture 1018 | left 1019 | leg 1020 | legal 1021 | legend 1022 | leisure 1023 | lemon 1024 | lend 1025 | length 1026 | lens 1027 | leopard 1028 | lesson 1029 | letter 1030 | level 1031 | liar 1032 | liberty 1033 | library 1034 | license 1035 | life 1036 | lift 1037 | light 1038 | like 1039 | limb 1040 | limit 1041 | link 1042 | lion 1043 | liquid 1044 | list 1045 | little 1046 | live 1047 | lizard 1048 | load 1049 | loan 1050 | lobster 1051 | local 1052 | lock 1053 | logic 1054 | lonely 1055 | long 1056 | loop 1057 | lottery 1058 | loud 1059 | lounge 1060 | love 1061 | loyal 1062 | lucky 1063 | luggage 1064 | lumber 1065 | lunar 1066 | lunch 1067 | luxury 1068 | lyrics 1069 | machine 1070 | mad 1071 | magic 1072 | magnet 1073 | maid 1074 | mail 
1075 | main 1076 | major 1077 | make 1078 | mammal 1079 | man 1080 | manage 1081 | mandate 1082 | mango 1083 | mansion 1084 | manual 1085 | maple 1086 | marble 1087 | march 1088 | margin 1089 | marine 1090 | market 1091 | marriage 1092 | mask 1093 | mass 1094 | master 1095 | match 1096 | material 1097 | math 1098 | matrix 1099 | matter 1100 | maximum 1101 | maze 1102 | meadow 1103 | mean 1104 | measure 1105 | meat 1106 | mechanic 1107 | medal 1108 | media 1109 | melody 1110 | melt 1111 | member 1112 | memory 1113 | mention 1114 | menu 1115 | mercy 1116 | merge 1117 | merit 1118 | merry 1119 | mesh 1120 | message 1121 | metal 1122 | method 1123 | middle 1124 | midnight 1125 | milk 1126 | million 1127 | mimic 1128 | mind 1129 | minimum 1130 | minor 1131 | minute 1132 | miracle 1133 | mirror 1134 | misery 1135 | miss 1136 | mistake 1137 | mix 1138 | mixed 1139 | mixture 1140 | mobile 1141 | model 1142 | modify 1143 | mom 1144 | moment 1145 | monitor 1146 | monkey 1147 | monster 1148 | month 1149 | moon 1150 | moral 1151 | more 1152 | morning 1153 | mosquito 1154 | mother 1155 | motion 1156 | motor 1157 | mountain 1158 | mouse 1159 | move 1160 | movie 1161 | much 1162 | muffin 1163 | mule 1164 | multiply 1165 | muscle 1166 | museum 1167 | mushroom 1168 | music 1169 | must 1170 | mutual 1171 | myself 1172 | mystery 1173 | myth 1174 | naive 1175 | name 1176 | napkin 1177 | narrow 1178 | nasty 1179 | nation 1180 | nature 1181 | near 1182 | neck 1183 | need 1184 | negative 1185 | neglect 1186 | neither 1187 | nephew 1188 | nerve 1189 | nest 1190 | net 1191 | network 1192 | neutral 1193 | never 1194 | news 1195 | next 1196 | nice 1197 | night 1198 | noble 1199 | noise 1200 | nominee 1201 | noodle 1202 | normal 1203 | north 1204 | nose 1205 | notable 1206 | note 1207 | nothing 1208 | notice 1209 | novel 1210 | now 1211 | nuclear 1212 | number 1213 | nurse 1214 | nut 1215 | oak 1216 | obey 1217 | object 1218 | oblige 1219 | obscure 1220 | observe 1221 | obtain 1222 | obvious 1223 | occur 1224 | ocean 1225 | october 1226 | odor 1227 | off 1228 | offer 1229 | office 1230 | often 1231 | oil 1232 | okay 1233 | old 1234 | olive 1235 | olympic 1236 | omit 1237 | once 1238 | one 1239 | onion 1240 | online 1241 | only 1242 | open 1243 | opera 1244 | opinion 1245 | oppose 1246 | option 1247 | orange 1248 | orbit 1249 | orchard 1250 | order 1251 | ordinary 1252 | organ 1253 | orient 1254 | original 1255 | orphan 1256 | ostrich 1257 | other 1258 | outdoor 1259 | outer 1260 | output 1261 | outside 1262 | oval 1263 | oven 1264 | over 1265 | own 1266 | owner 1267 | oxygen 1268 | oyster 1269 | ozone 1270 | pact 1271 | paddle 1272 | page 1273 | pair 1274 | palace 1275 | palm 1276 | panda 1277 | panel 1278 | panic 1279 | panther 1280 | paper 1281 | parade 1282 | parent 1283 | park 1284 | parrot 1285 | party 1286 | pass 1287 | patch 1288 | path 1289 | patient 1290 | patrol 1291 | pattern 1292 | pause 1293 | pave 1294 | payment 1295 | peace 1296 | peanut 1297 | pear 1298 | peasant 1299 | pelican 1300 | pen 1301 | penalty 1302 | pencil 1303 | people 1304 | pepper 1305 | perfect 1306 | permit 1307 | person 1308 | pet 1309 | phone 1310 | photo 1311 | phrase 1312 | physical 1313 | piano 1314 | picnic 1315 | picture 1316 | piece 1317 | pig 1318 | pigeon 1319 | pill 1320 | pilot 1321 | pink 1322 | pioneer 1323 | pipe 1324 | pistol 1325 | pitch 1326 | pizza 1327 | place 1328 | planet 1329 | plastic 1330 | plate 1331 | play 1332 | please 1333 | pledge 1334 | pluck 1335 | plug 1336 | plunge 1337 | poem 1338 | poet 1339 | point 
1340 | polar 1341 | pole 1342 | police 1343 | pond 1344 | pony 1345 | pool 1346 | popular 1347 | portion 1348 | position 1349 | possible 1350 | post 1351 | potato 1352 | pottery 1353 | poverty 1354 | powder 1355 | power 1356 | practice 1357 | praise 1358 | predict 1359 | prefer 1360 | prepare 1361 | present 1362 | pretty 1363 | prevent 1364 | price 1365 | pride 1366 | primary 1367 | print 1368 | priority 1369 | prison 1370 | private 1371 | prize 1372 | problem 1373 | process 1374 | produce 1375 | profit 1376 | program 1377 | project 1378 | promote 1379 | proof 1380 | property 1381 | prosper 1382 | protect 1383 | proud 1384 | provide 1385 | public 1386 | pudding 1387 | pull 1388 | pulp 1389 | pulse 1390 | pumpkin 1391 | punch 1392 | pupil 1393 | puppy 1394 | purchase 1395 | purity 1396 | purpose 1397 | purse 1398 | push 1399 | put 1400 | puzzle 1401 | pyramid 1402 | quality 1403 | quantum 1404 | quarter 1405 | question 1406 | quick 1407 | quit 1408 | quiz 1409 | quote 1410 | rabbit 1411 | raccoon 1412 | race 1413 | rack 1414 | radar 1415 | radio 1416 | rail 1417 | rain 1418 | raise 1419 | rally 1420 | ramp 1421 | ranch 1422 | random 1423 | range 1424 | rapid 1425 | rare 1426 | rate 1427 | rather 1428 | raven 1429 | raw 1430 | razor 1431 | ready 1432 | real 1433 | reason 1434 | rebel 1435 | rebuild 1436 | recall 1437 | receive 1438 | recipe 1439 | record 1440 | recycle 1441 | reduce 1442 | reflect 1443 | reform 1444 | refuse 1445 | region 1446 | regret 1447 | regular 1448 | reject 1449 | relax 1450 | release 1451 | relief 1452 | rely 1453 | remain 1454 | remember 1455 | remind 1456 | remove 1457 | render 1458 | renew 1459 | rent 1460 | reopen 1461 | repair 1462 | repeat 1463 | replace 1464 | report 1465 | require 1466 | rescue 1467 | resemble 1468 | resist 1469 | resource 1470 | response 1471 | result 1472 | retire 1473 | retreat 1474 | return 1475 | reunion 1476 | reveal 1477 | review 1478 | reward 1479 | rhythm 1480 | rib 1481 | ribbon 1482 | rice 1483 | rich 1484 | ride 1485 | ridge 1486 | rifle 1487 | right 1488 | rigid 1489 | ring 1490 | riot 1491 | ripple 1492 | risk 1493 | ritual 1494 | rival 1495 | river 1496 | road 1497 | roast 1498 | robot 1499 | robust 1500 | rocket 1501 | romance 1502 | roof 1503 | rookie 1504 | room 1505 | rose 1506 | rotate 1507 | rough 1508 | round 1509 | route 1510 | royal 1511 | rubber 1512 | rude 1513 | rug 1514 | rule 1515 | run 1516 | runway 1517 | rural 1518 | sad 1519 | saddle 1520 | sadness 1521 | safe 1522 | sail 1523 | salad 1524 | salmon 1525 | salon 1526 | salt 1527 | salute 1528 | same 1529 | sample 1530 | sand 1531 | satisfy 1532 | satoshi 1533 | sauce 1534 | sausage 1535 | save 1536 | say 1537 | scale 1538 | scan 1539 | scare 1540 | scatter 1541 | scene 1542 | scheme 1543 | school 1544 | science 1545 | scissors 1546 | scorpion 1547 | scout 1548 | scrap 1549 | screen 1550 | script 1551 | scrub 1552 | sea 1553 | search 1554 | season 1555 | seat 1556 | second 1557 | secret 1558 | section 1559 | security 1560 | seed 1561 | seek 1562 | segment 1563 | select 1564 | sell 1565 | seminar 1566 | senior 1567 | sense 1568 | sentence 1569 | series 1570 | service 1571 | session 1572 | settle 1573 | setup 1574 | seven 1575 | shadow 1576 | shaft 1577 | shallow 1578 | share 1579 | shed 1580 | shell 1581 | sheriff 1582 | shield 1583 | shift 1584 | shine 1585 | ship 1586 | shiver 1587 | shock 1588 | shoe 1589 | shoot 1590 | shop 1591 | short 1592 | shoulder 1593 | shove 1594 | shrimp 1595 | shrug 1596 | shuffle 1597 | shy 1598 | sibling 1599 | sick 1600 | side 
1601 | siege 1602 | sight 1603 | sign 1604 | silent 1605 | silk 1606 | silly 1607 | silver 1608 | similar 1609 | simple 1610 | since 1611 | sing 1612 | siren 1613 | sister 1614 | situate 1615 | six 1616 | size 1617 | skate 1618 | sketch 1619 | ski 1620 | skill 1621 | skin 1622 | skirt 1623 | skull 1624 | slab 1625 | slam 1626 | sleep 1627 | slender 1628 | slice 1629 | slide 1630 | slight 1631 | slim 1632 | slogan 1633 | slot 1634 | slow 1635 | slush 1636 | small 1637 | smart 1638 | smile 1639 | smoke 1640 | smooth 1641 | snack 1642 | snake 1643 | snap 1644 | sniff 1645 | snow 1646 | soap 1647 | soccer 1648 | social 1649 | sock 1650 | soda 1651 | soft 1652 | solar 1653 | soldier 1654 | solid 1655 | solution 1656 | solve 1657 | someone 1658 | song 1659 | soon 1660 | sorry 1661 | sort 1662 | soul 1663 | sound 1664 | soup 1665 | source 1666 | south 1667 | space 1668 | spare 1669 | spatial 1670 | spawn 1671 | speak 1672 | special 1673 | speed 1674 | spell 1675 | spend 1676 | sphere 1677 | spice 1678 | spider 1679 | spike 1680 | spin 1681 | spirit 1682 | split 1683 | spoil 1684 | sponsor 1685 | spoon 1686 | sport 1687 | spot 1688 | spray 1689 | spread 1690 | spring 1691 | spy 1692 | square 1693 | squeeze 1694 | squirrel 1695 | stable 1696 | stadium 1697 | staff 1698 | stage 1699 | stairs 1700 | stamp 1701 | stand 1702 | start 1703 | state 1704 | stay 1705 | steak 1706 | steel 1707 | stem 1708 | step 1709 | stereo 1710 | stick 1711 | still 1712 | sting 1713 | stock 1714 | stomach 1715 | stone 1716 | stool 1717 | story 1718 | stove 1719 | strategy 1720 | street 1721 | strike 1722 | strong 1723 | struggle 1724 | student 1725 | stuff 1726 | stumble 1727 | style 1728 | subject 1729 | submit 1730 | subway 1731 | success 1732 | such 1733 | sudden 1734 | suffer 1735 | sugar 1736 | suggest 1737 | suit 1738 | summer 1739 | sun 1740 | sunny 1741 | sunset 1742 | super 1743 | supply 1744 | supreme 1745 | sure 1746 | surface 1747 | surge 1748 | surprise 1749 | surround 1750 | survey 1751 | suspect 1752 | sustain 1753 | swallow 1754 | swamp 1755 | swap 1756 | swarm 1757 | swear 1758 | sweet 1759 | swift 1760 | swim 1761 | swing 1762 | switch 1763 | sword 1764 | symbol 1765 | symptom 1766 | syrup 1767 | system 1768 | table 1769 | tackle 1770 | tag 1771 | tail 1772 | talent 1773 | talk 1774 | tank 1775 | tape 1776 | target 1777 | task 1778 | taste 1779 | tattoo 1780 | taxi 1781 | teach 1782 | team 1783 | tell 1784 | ten 1785 | tenant 1786 | tennis 1787 | tent 1788 | term 1789 | test 1790 | text 1791 | thank 1792 | that 1793 | theme 1794 | then 1795 | theory 1796 | there 1797 | they 1798 | thing 1799 | this 1800 | thought 1801 | three 1802 | thrive 1803 | throw 1804 | thumb 1805 | thunder 1806 | ticket 1807 | tide 1808 | tiger 1809 | tilt 1810 | timber 1811 | time 1812 | tiny 1813 | tip 1814 | tired 1815 | tissue 1816 | title 1817 | toast 1818 | tobacco 1819 | today 1820 | toddler 1821 | toe 1822 | together 1823 | toilet 1824 | token 1825 | tomato 1826 | tomorrow 1827 | tone 1828 | tongue 1829 | tonight 1830 | tool 1831 | tooth 1832 | top 1833 | topic 1834 | topple 1835 | torch 1836 | tornado 1837 | tortoise 1838 | toss 1839 | total 1840 | tourist 1841 | toward 1842 | tower 1843 | town 1844 | toy 1845 | track 1846 | trade 1847 | traffic 1848 | tragic 1849 | train 1850 | transfer 1851 | trap 1852 | trash 1853 | travel 1854 | tray 1855 | treat 1856 | tree 1857 | trend 1858 | trial 1859 | tribe 1860 | trick 1861 | trigger 1862 | trim 1863 | trip 1864 | trophy 1865 | trouble 1866 | truck 1867 | true 1868 | truly 1869 
| trumpet 1870 | trust 1871 | truth 1872 | try 1873 | tube 1874 | tuition 1875 | tumble 1876 | tuna 1877 | tunnel 1878 | turkey 1879 | turn 1880 | turtle 1881 | twelve 1882 | twenty 1883 | twice 1884 | twin 1885 | twist 1886 | two 1887 | type 1888 | typical 1889 | ugly 1890 | umbrella 1891 | unable 1892 | unaware 1893 | uncle 1894 | uncover 1895 | under 1896 | undo 1897 | unfair 1898 | unfold 1899 | unhappy 1900 | uniform 1901 | unique 1902 | unit 1903 | universe 1904 | unknown 1905 | unlock 1906 | until 1907 | unusual 1908 | unveil 1909 | update 1910 | upgrade 1911 | uphold 1912 | upon 1913 | upper 1914 | upset 1915 | urban 1916 | urge 1917 | usage 1918 | use 1919 | used 1920 | useful 1921 | useless 1922 | usual 1923 | utility 1924 | vacant 1925 | vacuum 1926 | vague 1927 | valid 1928 | valley 1929 | valve 1930 | van 1931 | vanish 1932 | vapor 1933 | various 1934 | vast 1935 | vault 1936 | vehicle 1937 | velvet 1938 | vendor 1939 | venture 1940 | venue 1941 | verb 1942 | verify 1943 | version 1944 | very 1945 | vessel 1946 | veteran 1947 | viable 1948 | vibrant 1949 | vicious 1950 | victory 1951 | video 1952 | view 1953 | village 1954 | vintage 1955 | violin 1956 | virtual 1957 | virus 1958 | visa 1959 | visit 1960 | visual 1961 | vital 1962 | vivid 1963 | vocal 1964 | voice 1965 | void 1966 | volcano 1967 | volume 1968 | vote 1969 | voyage 1970 | wage 1971 | wagon 1972 | wait 1973 | walk 1974 | wall 1975 | walnut 1976 | want 1977 | warfare 1978 | warm 1979 | warrior 1980 | wash 1981 | wasp 1982 | waste 1983 | water 1984 | wave 1985 | way 1986 | wealth 1987 | weapon 1988 | wear 1989 | weasel 1990 | weather 1991 | web 1992 | wedding 1993 | weekend 1994 | weird 1995 | welcome 1996 | west 1997 | wet 1998 | whale 1999 | what 2000 | wheat 2001 | wheel 2002 | when 2003 | where 2004 | whip 2005 | whisper 2006 | wide 2007 | width 2008 | wife 2009 | wild 2010 | will 2011 | win 2012 | window 2013 | wine 2014 | wing 2015 | wink 2016 | winner 2017 | winter 2018 | wire 2019 | wisdom 2020 | wise 2021 | wish 2022 | witness 2023 | wolf 2024 | woman 2025 | wonder 2026 | wood 2027 | wool 2028 | word 2029 | work 2030 | world 2031 | worry 2032 | worth 2033 | wrap 2034 | wreck 2035 | wrestle 2036 | wrist 2037 | write 2038 | wrong 2039 | yard 2040 | year 2041 | yellow 2042 | you 2043 | young 2044 | youth 2045 | zebra 2046 | zero 2047 | zone 2048 | zoo -------------------------------------------------------------------------------- /training/dataset_creator/sd_extractor.py: -------------------------------------------------------------------------------- 1 | from diffusers import StableDiffusionPipeline 2 | from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer 3 | import torch 4 | import gc 5 | import os 6 | import nltk 7 | import random 8 | import unicodedata 9 | import re 10 | import argparse 11 | from nltk.corpus import stopwords 12 | 13 | file_path = os.path.abspath(os.path.dirname(__file__)) 14 | 15 | def parse_args(args=None): 16 | parser = argparse.ArgumentParser() 17 | parser.add_argument("--pretrained_model_name_or_path", type=str, default="stabilityai/stable-diffusion-2-1-base") 18 | parser.add_argument("--output_folder", type=str, default=os.path.join(file_path, "sd_extracted")) 19 | parser.add_argument("--width", type=int, default=512) 20 | parser.add_argument("--height", type=int, default=512) 21 | return parser.parse_args(args) 22 | 23 | def slugify(value): 24 | """ 25 | Converts to lowercase, removes non-word characters (alphanumerics and 26 | underscores) and converts spaces 
to hyphens. Also strips leading and 27 | trailing whitespace. 28 | """ 29 | value = unicodedata.normalize('NFKD', value).encode('ascii', 'ignore').decode('ascii') 30 | value = re.sub('[^\w\s-]', '', value).strip().lower() 31 | return re.sub('[-\s]+', '-', value) 32 | 33 | imagenet_templates_small = [ 34 | "{}, realistic photo", 35 | "{}, realistic render", 36 | "{}, painting", 37 | "{}, anime", 38 | "{}, greg ruthkowski", 39 | "{}, cartoon", 40 | "{}, vector art", 41 | "{}, clip art" 42 | ] 43 | 44 | if __name__ == "__main__": 45 | args = parse_args() 46 | nltk.download('stopwords') 47 | model_id_or_path = args.pretrained_model_name_or_path 48 | pipeline = StableDiffusionPipeline.from_pretrained( 49 | model_id_or_path, 50 | revision="fp16", 51 | torch_dtype=torch.float16, 52 | ) 53 | tokenizer = CLIPTokenizer.from_pretrained( 54 | model_id_or_path, 55 | subfolder="tokenizer", 56 | ) 57 | text_encoder = CLIPTextModel.from_pretrained( 58 | model_id_or_path, subfolder="text_encoder" 59 | ) 60 | token_embeds = text_encoder.get_input_embeddings().weight.data 61 | def dummy(images, **kwargs): 62 | return images, False 63 | # pipeline.safety_checker = dummy 64 | pipeline = pipeline.to("cuda") 65 | stopwords_english = stopwords.words('english') 66 | tokens_to_search = [] 67 | 68 | common_english_words = {} 69 | with open(os.path.join(file_path, "bip39.txt"), 'r') as f: 70 | lines = f.readlines() 71 | for line in lines: 72 | if len(line) > 0: 73 | common_english_words[line.strip()] = True 74 | 75 | for token_id in range(token_embeds.shape[0]): 76 | token_name = tokenizer.decode(token_id) 77 | token_id = tokenizer.encode(token_name, add_special_tokens=False) 78 | if len(token_id) > 1: 79 | continue 80 | 81 | if len(token_name) > 3 and token_name.isalnum() and not token_name in stopwords_english and token_name in common_english_words: 82 | tokens_to_search.append(token_name) 83 | 84 | random.seed(80085) 85 | random.shuffle(tokens_to_search) 86 | 87 | for token_idx, token_name in enumerate(tokens_to_search): 88 | token_id = tokenizer.encode(token_name, add_special_tokens=False) 89 | if len(token_id) > 1: 90 | raise Exception("Need single token!") 91 | token_type = "train" 92 | if len(tokens_to_search) - token_idx <= 4: 93 | token_type = "val" 94 | image_output_folder = os.path.join(args.output_folder, token_type, token_name) 95 | if os.path.exists(image_output_folder): 96 | print(f"Skipping {image_output_folder} because it already exists") 97 | continue 98 | learned_embeds = token_embeds[token_id][0] 99 | concept_images_folder = os.path.join(image_output_folder, 'concept_images') 100 | os.makedirs(concept_images_folder, exist_ok = True) 101 | learned_embeds_dict = {token_name: learned_embeds.detach().cpu()} 102 | torch.save(learned_embeds_dict, os.path.join(image_output_folder, "learned_embeds.bin")) 103 | images_per_prompt = 2 104 | for image_idx in range(images_per_prompt): 105 | for text in imagenet_templates_small: 106 | text = text.format(token_name) 107 | print(f"Doing {token_name} with prompt: '{text}'...") 108 | image = pipeline( 109 | text, 110 | num_inference_steps=50, 111 | guidance_scale=9, 112 | width=args.width, 113 | height=args.height 114 | ).images[0] 115 | image.save(os.path.join(concept_images_folder, f"image_{slugify(text)}_{image_idx}.png")) 116 | -------------------------------------------------------------------------------- /training/get_extrema.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | 4 | def 
get_min_weights(embed_model, current) -> float: 5 | min_model = min(embed_model).item() 6 | if current is None: 7 | return min_model 8 | return min(current, min_model) 9 | 10 | def get_max_weights(embed_model, current) -> float: 11 | max_model = max(embed_model).item() 12 | if current is None: 13 | return max_model 14 | return max(current, max_model) 15 | 16 | def get_extrema(data_loader): 17 | min_weight = None 18 | max_weight = None 19 | 20 | for _, embed_batch in data_loader: 21 | for i in range(embed_batch.shape[0]): 22 | embed_model = embed_batch[i] 23 | min_weight = get_min_weights(embed_model, min_weight) 24 | max_weight = get_max_weights(embed_model, max_weight) 25 | return min_weight, max_weight -------------------------------------------------------------------------------- /training/lora_dataset_creator/create_dataset.py: -------------------------------------------------------------------------------- 1 | 2 | import json 3 | from clip_retrieval.clip_client import ClipClient, Modality 4 | import os 5 | import pathlib 6 | from pebble import ProcessPool, ThreadPool 7 | from concurrent.futures import TimeoutError 8 | import sys 9 | import colorsys 10 | import traceback 11 | import os 12 | import cv2 13 | import random 14 | import re 15 | import numpy as np 16 | import requests 17 | import io 18 | import time 19 | import math 20 | from PIL import Image 21 | import torch 22 | import torchvision 23 | import torch.nn.functional as F 24 | import torchvision.transforms.functional as TF 25 | import multiprocessing 26 | import subprocess 27 | 28 | file_path = os.path.abspath(os.path.dirname(__file__)) 29 | prompts = [] 30 | 31 | with open(os.path.join(file_path, "lora_words.txt"), 'r') as f: 32 | lines = f.readlines() 33 | for line in lines: 34 | if len(line) > 0: 35 | prompts.append(line.strip()) 36 | 37 | def estimate_noise(img_tensor): 38 | greyscaler = torchvision.transforms.Grayscale() 39 | img_tensor = greyscaler(img_tensor) 40 | W, H = img_tensor.squeeze().shape 41 | K = torch.tensor( 42 | [ 43 | [ 1, -2, 1], 44 | [-2, 4, -2], 45 | [ 1, -2, 1] 46 | ], 47 | device = img_tensor.device, 48 | dtype = torch.float32 49 | ).unsqueeze(0).unsqueeze(0) 50 | torch_conv = F.conv2d(img_tensor, K, bias=None, stride=(1, 1), padding=1) 51 | sigma = torch.sum(torch.abs(torch_conv)) 52 | sigma = sigma * math.sqrt(0.5 * math.pi) / (6 * (W - 2) * (H - 2)) 53 | return sigma 54 | 55 | def variance_of_laplacian(image): 56 | # compute the Laplacian of the image and then return the focus 57 | # measure, which is simply the variance of the Laplacian 58 | return cv2.Laplacian(image, cv2.CV_64F).var() 59 | 60 | def image_filter(img: Image, skip = []) -> str: 61 | if not 'size' in skip: 62 | min_size = 128 63 | if img.size[0] < min_size or img.size[1] < min_size: 64 | return f"Size too small! ({img.size[0]}x{img.size[1]})" 65 | sigma = estimate_noise(TF.to_tensor(img)) 66 | if sigma > 0.1: 67 | return f"Skipped as image is too noisy! (sigma: {sigma})" 68 | if not 'lapvar' in skip: 69 | lapvar = variance_of_laplacian(np.array(img)[:, :, ::-1]) 70 | if lapvar < 50: 71 | return f"Skipped as image is too blurry! 
(lapvar: {lapvar})" 72 | return None 73 | 74 | def task_done(future): 75 | try: 76 | result = future.result() # blocks until results are ready 77 | except TimeoutError as error: 78 | print("Function took longer than %d seconds" % error.args[1]) 79 | except Exception as error: 80 | print("Function raised %s" % error) 81 | print(error.traceback) # traceback of the function 82 | 83 | def download_image_from_row_worker(prompt: str, row, count: int, images_folder, headers): 84 | try: 85 | image_name = f"{prompt}_{count}" 86 | req = requests.get(row['url'], headers=headers) 87 | with Image.open(io.BytesIO(req.content)) as img: 88 | img = img.convert('RGB') 89 | filter_result = image_filter(img) 90 | if filter_result != None: 91 | print(f"Image filtered: {filter_result}. [{row['url']}]") 92 | return 93 | min_size = min(img.size[0], img.size[1]) 94 | image_path = os.path.join(images_folder, f"{image_name}.png") 95 | img.save(image_path) 96 | print(f"Saved {image_path}") 97 | except KeyboardInterrupt: 98 | print('KeyboardInterrupt exception is caught, stopping') 99 | return 100 | 101 | def download_images(prompt, images_folder): 102 | images_folder = os.path.join(images_folder, prompt, "images") 103 | 104 | if not os.path.exists(images_folder): 105 | os.makedirs(images_folder) 106 | 107 | client = ClipClient(url="https://knn5.laion.ai/knn-service", indice_name="laion5B", num_images=100, aesthetic_weight=0.2) 108 | result = client.query(text=prompt) 109 | result = list(filter(lambda item: item['url'].endswith(".png") or item['url'].endswith(".jpg") or item['url'].endswith(".webp"), result)) 110 | print(f"Making training database for {prompt}. {len(result)} candidates") 111 | headers = { 112 | 'User-Agent': 'Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/103.0.0.0 Safari/537.36', 113 | } 114 | count = 0 115 | with ProcessPool(max_workers=4, max_tasks=len(result)) as pool: 116 | try: 117 | for row in result: 118 | future = pool.schedule(download_image_from_row_worker, args=(prompt, row, count, images_folder, headers), timeout=60) 119 | future.add_done_callback(task_done) 120 | count += 1 121 | except KeyboardInterrupt: 122 | print("Keyboard interrupt, closing pool") 123 | pool.close() 124 | pool.stop() 125 | 126 | def main(): 127 | images_folder = os.path.join(file_path, "lora_dataset") 128 | for prompt in prompts: 129 | download_images(prompt, images_folder) 130 | 131 | if __name__ == "__main__": 132 | main() -------------------------------------------------------------------------------- /training/lora_dataset_creator/lora_words.txt: -------------------------------------------------------------------------------- 1 | tiger 2 | dragon 3 | toothless 4 | veemon 5 | fox 6 | blue fox 7 | ice fox 8 | mountain 9 | simpsons 10 | wallmart 11 | basket 12 | pikachu 13 | detective pikachu 14 | hunger games 15 | tabaluga 16 | lugia 17 | circus 18 | jungle 19 | bugatti chiron 20 | bitcoin -------------------------------------------------------------------------------- /training/lora_dataset_creator/split_data.py: -------------------------------------------------------------------------------- 1 | import os 2 | import subprocess 3 | import argparse 4 | import time 5 | import random 6 | import shutil 7 | 8 | file_path = os.path.abspath(os.path.dirname(__file__)) 9 | 10 | def parse_args(args=None): 11 | parser = argparse.ArgumentParser() 12 | parser.add_argument("--input_folder", type=str, default=os.path.join(file_path, "lora_dataset")) 13 | 
parser.add_argument("--val_amount", type=int, default=4) 14 | return parser.parse_args(args) 15 | 16 | def move_files(input_folder, files, type: str): 17 | target_path = os.path.join(input_folder, type) 18 | for file in files: 19 | source_folder = os.path.join(input_folder, file) 20 | target_folder = os.path.join(target_path, file) 21 | shutil.move(source_folder, target_folder) 22 | 23 | def main(): 24 | args = parse_args() 25 | files = os.listdir(args.input_folder) 26 | if "val" in files and "train" in files: 27 | print("Already splitted up!") 28 | return 29 | random.shuffle(files) 30 | val_files = files[:args.val_amount] 31 | train_files = files[args.val_amount:] 32 | move_files(args.input_folder, val_files, "val") 33 | move_files(args.input_folder, train_files, "train") 34 | 35 | if __name__ == "__main__": 36 | main() -------------------------------------------------------------------------------- /training/lora_dataset_creator/train_loras.py: -------------------------------------------------------------------------------- 1 | import os 2 | import subprocess 3 | import argparse 4 | import time 5 | 6 | file_path = os.path.abspath(os.path.dirname(__file__)) 7 | 8 | def parse_args(args=None): 9 | parser = argparse.ArgumentParser() 10 | parser.add_argument("--pretrained_model_name_or_path", type=str, default="stabilityai/stable-diffusion-2-1-base") 11 | parser.add_argument("--input_folder", type=str, default=os.path.join(file_path, "lora_dataset")) 12 | parser.add_argument("--width", type=int, default=512) 13 | parser.add_argument("--height", type=int, default=512) 14 | return parser.parse_args(args) 15 | 16 | def main(): 17 | args = parse_args() 18 | image_folders = os.listdir(args.input_folder) 19 | for image_folder in image_folders: 20 | print(f"Processing: {image_folder}") 21 | output_path = os.path.join(args.input_folder, image_folder, "models") 22 | image_folder = os.path.join(args.input_folder, image_folder, "images") 23 | os.makedirs(output_path, exist_ok=True) 24 | 25 | cmd = ['lora_pti', 26 | '--pretrained_model_name_or_path=' + args.pretrained_model_name_or_path, 27 | '--instance_data_dir=' + image_folder, 28 | '--output_dir=' + output_path, 29 | '--train_text_encoder', 30 | '--resolution=512', 31 | '--train_batch_size=1', 32 | '--gradient_accumulation_steps=4', 33 | '--scale_lr', 34 | '--learning_rate_unet=1e-4', 35 | '--learning_rate_text=1e-5', 36 | '--learning_rate_ti=5e-4', 37 | '--color_jitter', 38 | '--lr_scheduler="linear"', 39 | '--lr_warmup_steps=0', 40 | '--placeholder_tokens=""', 41 | '--use_template="object"', 42 | '--save_steps=100', 43 | '--max_train_steps_ti=1000', 44 | '--max_train_steps_tuning=1000', 45 | '--perform_inversion=True', 46 | '--clip_ti_decay', 47 | '--weight_decay_ti=0.000', 48 | '--weight_decay_lora=0.001', 49 | '--continue_inversion', 50 | '--continue_inversion_lr=1e-4', 51 | '--device="cuda:0"', 52 | '--lora_rank=1'] 53 | 54 | subprocess.run(cmd) 55 | time.sleep(2) 56 | 57 | if __name__ == "__main__": 58 | main() -------------------------------------------------------------------------------- /training/train.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import math 3 | import torch 4 | import torchvision 5 | from torchvision import transforms 6 | import pytorch_lightning as pl 7 | import torchmetrics 8 | import os 9 | from tqdm import tqdm 10 | from imgaug import augmenters as iaa 11 | import numpy as np 12 | import random 13 | from functools import partial 14 | from torch.utils.data 
import Dataset, DataLoader 15 | from PIL import Image 16 | from PIL import ImageOps 17 | from pytorch_lightning.callbacks import LearningRateMonitor 18 | import traceback 19 | import sys 20 | from get_extrema import get_extrema 21 | from leap_sd import LM 22 | 23 | def parse_args(args=None): 24 | parser = argparse.ArgumentParser() 25 | parser.add_argument("--batch_size", type=int, default=4) 26 | parser.add_argument("--learning_rate", type=float, default=1e-5) 27 | parser.add_argument("--weight_decay", type=float, default=0.0001) 28 | parser.add_argument("--logging", type=str, default="tensorboard") 29 | parser.add_argument("--latent_dim_size", type=int, default=1024) 30 | parser.add_argument("--dropout_p", type=float, default=0.01) 31 | file_path = os.path.abspath(os.path.dirname(__file__)) 32 | parser.add_argument("--dataset_path", type=str, default=os.path.join(file_path, "dataset_creator/sd_extracted")) 33 | parser = pl.Trainer.add_argparse_args(parser) 34 | return parser.parse_args(args) 35 | 36 | def get_datamodule(path: str, batch_size: int): 37 | train_transforms = transforms.Compose( 38 | [ 39 | iaa.Resize({"shorter-side": (128, 256), "longer-side": "keep-aspect-ratio"}).augment_image, 40 | iaa.CropToFixedSize(width=128, height=128).augment_image, 41 | iaa.Sometimes(0.8, iaa.Sequential([ 42 | iaa.flip.Fliplr(p=0.5), 43 | iaa.flip.Flipud(p=0.5), 44 | iaa.Sometimes( 45 | 0.5, 46 | iaa.Sequential([ 47 | iaa.ShearX((-20, 20)), 48 | iaa.ShearY((-20, 20)) 49 | ]) 50 | ), 51 | iaa.GaussianBlur(sigma=(0.0, 0.05)), 52 | iaa.MultiplyBrightness(mul=(0.65, 1.35)), 53 | iaa.AdditiveGaussianNoise(loc=0, scale=(0.0, 0.05*255), per_channel=0.5), 54 | ], random_order=True)).augment_image, 55 | np.copy, 56 | transforms.ToTensor(), 57 | transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)) 58 | ] 59 | ) 60 | test_transforms = transforms.Compose( 61 | [ 62 | iaa.Resize({"shorter-side": (128, 256), "longer-side": "keep-aspect-ratio"}).augment_image, 63 | iaa.CropToFixedSize(width=128, height=128).augment_image, 64 | transforms.ToTensor(), 65 | transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)) 66 | ] 67 | ) 68 | 69 | class ImageWeightDataset(Dataset): 70 | def __init__(self, path, transform): 71 | self.path = path 72 | self.files = os.listdir(self.path) 73 | self.transform = transform 74 | self.num_images = 4 75 | 76 | def __getitem__(self, index): 77 | full_path = os.path.join(self.path, self.files[index]) 78 | try: 79 | images_path = os.path.join(full_path, "concept_images") 80 | image_names = os.listdir(images_path) 81 | random.shuffle(image_names) 82 | image_names = image_names[:random.randint(1, self.num_images)] 83 | image_names_len = len(image_names) 84 | if image_names_len < self.num_images: 85 | for i in range(self.num_images - image_names_len): 86 | image_names.append(image_names[i % image_names_len]) 87 | 88 | images = None 89 | for image_name in image_names: 90 | image = Image.open(os.path.join(images_path, image_name)).convert("RGB") 91 | image = ImageOps.exif_transpose(image) 92 | image = self.transform(np.array(image)).unsqueeze(0) 93 | if images is None: 94 | images = image 95 | else: 96 | images = torch.cat((images, image), 0) 97 | 98 | loaded_learned_embeds = torch.load(os.path.join(full_path, "learned_embeds.bin"), map_location="cpu") 99 | embed_model = loaded_learned_embeds[list(loaded_learned_embeds.keys())[0]].detach() 100 | embed_model = embed_model.to(torch.float32) 101 | return images, embed_model 102 | except: 103 | print(f"Error with {full_path}!") 104 | 
traceback.print_exception(*sys.exc_info()) 105 | 106 | def __len__(self): 107 | return len(self.files) 108 | 109 | class ImageWeights(pl.LightningDataModule): 110 | def __init__(self, data_folder: str, batch_size: int): 111 | super().__init__() 112 | self.num_workers = 16 113 | self.data_folder = data_folder 114 | self.batch_size = batch_size 115 | self.overfit = False 116 | self.num_samples = len(os.listdir(os.path.join(self.data_folder, "train"))) 117 | if self.overfit: 118 | self.num_samples = 250 119 | 120 | def prepare_data(self): 121 | pass 122 | 123 | def setup(self, stage): 124 | pass 125 | 126 | def train_dataloader(self): 127 | dataset = ImageWeightDataset(os.path.join(self.data_folder, "train"), transform = train_transforms) 128 | if self.overfit: 129 | file_list = dataset.files[:1] 130 | print("Overfit! Using only:", file_list) 131 | dataset.files = file_list * 250 132 | return DataLoader(dataset, num_workers = self.num_workers, batch_size = self.batch_size, shuffle=True) 133 | 134 | def val_dataloader(self): 135 | return DataLoader(ImageWeightDataset(os.path.join(self.data_folder, "val"), transform = test_transforms), num_workers = self.num_workers, batch_size = self.batch_size) 136 | 137 | def test_dataloader(self): 138 | return DataLoader(ImageWeightDataset(os.path.join(self.data_folder, "test"), transform = test_transforms), num_workers = self.num_workers, batch_size = self.batch_size) 139 | 140 | def teardown(self, stage): 141 | # clean up after fit or test 142 | # called on every process in DDP 143 | pass 144 | 145 | dm = ImageWeights(path, batch_size = batch_size) 146 | 147 | return dm 148 | 149 | if __name__ == "__main__": 150 | torch.autograd.set_detect_anomaly(True) 151 | pl.seed_everything(1) 152 | args = parse_args() 153 | 154 | # Add some dm attributes to args Namespace 155 | args.image_size = 128 156 | args.patch_size = 32 157 | args.input_shape = (3, 128, 128) 158 | 159 | # compute total number of steps 160 | batch_size = args.batch_size * args.gpus if args.gpus > 0 else args.batch_size 161 | 162 | dm = get_datamodule(batch_size = batch_size, path = args.dataset_path) 163 | 164 | min_weight, max_weight = get_extrema(dm.train_dataloader()) 165 | print(f"Extrema of entire training set: {min_weight} <> {max_weight}") 166 | args.min_weight = min_weight 167 | args.max_weight = max_weight 168 | 169 | args.steps = dm.num_samples // batch_size * args.max_epochs 170 | 171 | # Init Lightning Module 172 | lm = LM(**vars(args)) 173 | lm.train() 174 | 175 | # Init callbacks 176 | if args.logging != "none": 177 | lr_monitor = LearningRateMonitor(logging_interval='step') 178 | args.callbacks = [lr_monitor] 179 | if args.logging == "wandb": 180 | from pytorch_lightning.loggers import WandbLogger 181 | args.logger = WandbLogger(project="LEAP") 182 | else: 183 | args.checkpoint_callback = False 184 | args.logger = False 185 | 186 | # Set up Trainer 187 | trainer = pl.Trainer.from_argparse_args(args) 188 | 189 | # Train! 
190 | trainer.fit(lm, dm) 191 | -------------------------------------------------------------------------------- /training/train_lora.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import math 3 | import torch 4 | import torchvision 5 | from torchvision import transforms 6 | import pytorch_lightning as pl 7 | import torchmetrics 8 | import os 9 | from tqdm import tqdm 10 | from imgaug import augmenters as iaa 11 | import numpy as np 12 | import random 13 | from functools import partial 14 | from torch.utils.data import Dataset, DataLoader 15 | from PIL import Image 16 | from PIL import ImageOps 17 | from pytorch_lightning.callbacks import LearningRateMonitor 18 | import traceback 19 | import sys 20 | from get_extrema import get_extrema 21 | from leap_sd import LM 22 | from safetensors import safe_open 23 | 24 | def parse_args(args=None): 25 | parser = argparse.ArgumentParser() 26 | parser.add_argument("--batch_size", type=int, default=4) 27 | parser.add_argument("--learning_rate", type=float, default=1e-5) 28 | parser.add_argument("--weight_decay", type=float, default=0.0001) 29 | parser.add_argument("--logging", type=str, default="tensorboard") 30 | parser.add_argument("--latent_dim_size", type=int, default=509248) 31 | parser.add_argument("--dropout_p", type=float, default=0.01) 32 | file_path = os.path.abspath(os.path.dirname(__file__)) 33 | parser.add_argument("--dataset_path", type=str, default=os.path.join(file_path, "lora_dataset_creator/lora_dataset")) 34 | parser = pl.Trainer.add_argparse_args(parser) 35 | return parser.parse_args(args) 36 | 37 | def get_datamodule(path: str, batch_size: int): 38 | train_transforms = transforms.Compose( 39 | [ 40 | iaa.Resize({"shorter-side": (128, 256), "longer-side": "keep-aspect-ratio"}).augment_image, 41 | iaa.CropToFixedSize(width=128, height=128).augment_image, 42 | iaa.Sometimes(0.8, iaa.Sequential([ 43 | iaa.flip.Fliplr(p=0.5), 44 | iaa.flip.Flipud(p=0.5), 45 | iaa.Sometimes( 46 | 0.5, 47 | iaa.Sequential([ 48 | iaa.ShearX((-20, 20)), 49 | iaa.ShearY((-20, 20)) 50 | ]) 51 | ), 52 | iaa.GaussianBlur(sigma=(0.0, 0.05)), 53 | iaa.MultiplyBrightness(mul=(0.65, 1.35)), 54 | iaa.AdditiveGaussianNoise(loc=0, scale=(0.0, 0.05*255), per_channel=0.5), 55 | ], random_order=True)).augment_image, 56 | np.copy, 57 | transforms.ToTensor(), 58 | transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)) 59 | ] 60 | ) 61 | test_transforms = transforms.Compose( 62 | [ 63 | iaa.Resize({"shorter-side": (128, 256), "longer-side": "keep-aspect-ratio"}).augment_image, 64 | iaa.CropToFixedSize(width=128, height=128).augment_image, 65 | transforms.ToTensor(), 66 | transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)) 67 | ] 68 | ) 69 | 70 | class ImageWeightDataset(Dataset): 71 | def __init__(self, path, transform): 72 | self.path = path 73 | self.files = os.listdir(self.path) 74 | self.transform = transform 75 | self.num_images = 4 76 | 77 | def __getitem__(self, index): 78 | full_path = os.path.join(self.path, self.files[index]) 79 | try: 80 | images_path = os.path.join(full_path, "images") 81 | image_names = os.listdir(images_path) 82 | random.shuffle(image_names) 83 | image_names = image_names[:random.randint(1, self.num_images)] 84 | image_names_len = len(image_names) 85 | if image_names_len < self.num_images: 86 | for i in range(self.num_images - image_names_len): 87 | image_names.append(image_names[i % image_names_len]) 88 | 89 | images = None 90 | for image_name in image_names: 91 | image = 
Image.open(os.path.join(images_path, image_name)).convert("RGB") 92 | image = ImageOps.exif_transpose(image) 93 | image = self.transform(np.array(image)).unsqueeze(0) 94 | if images is None: 95 | images = image 96 | else: 97 | images = torch.cat((images, image), 0) 98 | 99 | model_path = os.path.join(full_path, "models") 100 | with safe_open(os.path.join(model_path, "step_1000.safetensors"), framework="pt") as f: 101 | tensor = None 102 | for k in f.keys(): 103 | if tensor is None: 104 | tensor = f.get_tensor(k).flatten() 105 | else: 106 | tensor = torch.cat((tensor, f.get_tensor(k).flatten()), 0) 107 | return images, tensor 108 | except: 109 | print(f"Error with {full_path}!") 110 | traceback.print_exception(*sys.exc_info()) 111 | 112 | def __len__(self): 113 | return len(self.files) 114 | 115 | class ImageWeights(pl.LightningDataModule): 116 | def __init__(self, data_folder: str, batch_size: int): 117 | super().__init__() 118 | self.num_workers = 16 119 | self.data_folder = data_folder 120 | self.batch_size = batch_size 121 | self.overfit = False 122 | self.num_samples = len(os.listdir(os.path.join(self.data_folder, "train"))) 123 | if self.overfit: 124 | self.num_samples = 250 125 | 126 | def prepare_data(self): 127 | pass 128 | 129 | def setup(self, stage): 130 | pass 131 | 132 | def train_dataloader(self): 133 | dataset = ImageWeightDataset(os.path.join(self.data_folder, "train"), transform = train_transforms) 134 | if self.overfit: 135 | file_list = dataset.files[:1] 136 | print("Overfit! Using only:", file_list) 137 | dataset.files = file_list * 250 138 | return DataLoader(dataset, num_workers = self.num_workers, batch_size = self.batch_size, shuffle=True) 139 | 140 | def val_dataloader(self): 141 | return DataLoader(ImageWeightDataset(os.path.join(self.data_folder, "val"), transform = test_transforms), num_workers = self.num_workers, batch_size = self.batch_size) 142 | 143 | def test_dataloader(self): 144 | return DataLoader(ImageWeightDataset(os.path.join(self.data_folder, "test"), transform = test_transforms), num_workers = self.num_workers, batch_size = self.batch_size) 145 | 146 | def teardown(self, stage): 147 | # clean up after fit or test 148 | # called on every process in DDP 149 | pass 150 | 151 | dm = ImageWeights(path, batch_size = batch_size) 152 | 153 | return dm 154 | 155 | if __name__ == "__main__": 156 | torch.autograd.set_detect_anomaly(True) 157 | pl.seed_everything(1) 158 | args = parse_args() 159 | 160 | # Add some dm attributes to args Namespace 161 | args.image_size = 128 162 | args.patch_size = 32 163 | args.input_shape = (3, 128, 128) 164 | 165 | # compute total number of steps 166 | batch_size = args.batch_size * args.gpus if args.gpus > 0 else args.batch_size 167 | 168 | dm = get_datamodule(batch_size = batch_size, path = args.dataset_path) 169 | 170 | min_weight, max_weight = get_extrema(dm.train_dataloader()) 171 | print(f"Extrema of entire training set: {min_weight} <> {max_weight}") 172 | args.min_weight = min_weight 173 | args.max_weight = max_weight 174 | 175 | args.steps = dm.num_samples // batch_size * args.max_epochs 176 | 177 | # Init Lightning Module 178 | lm = LM(**vars(args)) 179 | lm.train() 180 | 181 | # Init callbacks 182 | if args.logging != "none": 183 | lr_monitor = LearningRateMonitor(logging_interval='step') 184 | args.callbacks = [lr_monitor] 185 | if args.logging == "wandb": 186 | from pytorch_lightning.loggers import WandbLogger 187 | args.logger = WandbLogger(project="LEAP") 188 | else: 189 | args.checkpoint_callback = False 190 | 
args.logger = False 191 | 192 | # Set up Trainer 193 | trainer = pl.Trainer.from_argparse_args(args) 194 | 195 | # Train! 196 | trainer.fit(lm, dm) 197 | --------------------------------------------------------------------------------
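
For reference: each learned_embeds.bin written by training/dataset_creator/sd_extractor.py (and read back by training/train.py as the regression target) is a plain torch-saved dict with a single {token_name: embedding tensor} entry, the same layout used by standard textual-inversion tooling. A minimal sketch of loading such a file into a diffusers pipeline for inference follows; the example path, the chosen token, and the diffusers/transformers calls are standard-library assumptions for illustration, not code from this repository.

import torch
from diffusers import StableDiffusionPipeline

# Minimal sketch: load the base model the embedding was extracted from.
pipe = StableDiffusionPipeline.from_pretrained("stabilityai/stable-diffusion-2-1-base")
# pipe = pipe.to("cuda")  # optional, if a GPU is available

# learned_embeds.bin holds one entry: {token_name: embedding tensor}.
# The path below is an assumed example; sd_extractor.py writes to
# <output_folder>/<train|val>/<token_name>/learned_embeds.bin.
learned = torch.load("sd_extracted/train/dragon/learned_embeds.bin", map_location="cpu")
token, embedding = next(iter(learned.items()))

# Register the token (a no-op if it is already in the vocabulary) and write
# the embedding into the text encoder's input-embedding table.
pipe.tokenizer.add_tokens(token)
pipe.text_encoder.resize_token_embeddings(len(pipe.tokenizer))
token_id = pipe.tokenizer.convert_tokens_to_ids(token)
pipe.text_encoder.get_input_embeddings().weight.data[token_id] = embedding

image = pipe(f"{token}, realistic photo", num_inference_steps=50).images[0]
image.save("sample.png")

Because sd_extractor.py only keeps tokens that already exist in the CLIP vocabulary, add_tokens and resize_token_embeddings are no-ops for its own output and the assignment simply restores the original embedding; the same loading pattern applies to any embedding saved in this one-entry format.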