├── LICENSE ├── code ├── Visual Text Sampling │ ├── text2img_sd.py │ ├── text2img_sdxl.py │ ├── sample_prompts_gpt.py │ └── sample_prompts_llama.py ├── others │ ├── blip_caption.py │ ├── train_text_to_image.py │ ├── train_text_to_image_sdxl.py │ └── train_dreambooth.py ├── VLEU Calculation │ ├── cal_vleu_openclip.py │ └── cal_vleu_clip.py └── Text-Image Scoring │ └── clip_score.py ├── README.md ├── .gitignore └── data └── Constrained Subject └── teddy_bear.json /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2024 nameless 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /code/Visual Text Sampling/text2img_sd.py: -------------------------------------------------------------------------------- 1 | from diffusers import StableDiffusionPipeline 2 | import argparse 3 | import torch 4 | import json 5 | import os 6 | 7 | 8 | if __name__ == '__main__': 9 | parser = argparse.ArgumentParser() 10 | parser.add_argument('--model', type=str) 11 | parser.add_argument('--prompt', type=str) 12 | parser.add_argument('--prompt_json_path', type=str) 13 | parser.add_argument('--output_dir', type=str) 14 | parser.add_argument('--num', type=int) 15 | parser.add_argument('--batch_size', type=int, default=1) 16 | args = parser.parse_args() 17 | 18 | pipe = StableDiffusionPipeline.from_pretrained(args.model, torch_dtype=torch.float16, safety_checker=None) 19 | pipe = pipe.to("cuda") 20 | 21 | if args.prompt_json_path: 22 | prompts = json.load(open(args.prompt_json_path,'r',encoding='utf-8')) 23 | 24 | os.makedirs(args.output_dir, exist_ok=True) 25 | 26 | for i in range(0, args.num, args.batch_size): 27 | size = min(args.batch_size, args.num - i) 28 | if args.prompt_json_path: 29 | images = pipe(prompts[i:i+size]).images 30 | else: 31 | images = pipe([args.prompt] * size).images 32 | for j, image in enumerate(images): 33 | image.save(f'{args.output_dir}/{i+j}.jpg') 34 | -------------------------------------------------------------------------------- /code/Visual Text Sampling/text2img_sdxl.py: -------------------------------------------------------------------------------- 1 | from diffusers import DiffusionPipeline 2 | import argparse 3 | import torch 4 | import json, os 5 | 6 | 7 | if __name__ == '__main__': 8 | parser = argparse.ArgumentParser() 9 | parser.add_argument('--model', type=str) 10 | 
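    # --start (declared below) lets generation resume from, or be sharded at, an
    # offset into the prompt list; images are saved as {output_dir}/{index}.jpg
    # using the absolute prompt index, so ranges produced by separate runs do
    # not collide.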
parser.add_argument('--prompt', type=str) 11 | parser.add_argument('--prompt_json_path', type=str) 12 | parser.add_argument('--output_dir', type=str) 13 | parser.add_argument('--start', type=int, default=0) 14 | parser.add_argument('--num', type=int) 15 | parser.add_argument('--batch_size', type=int, default=1) 16 | args = parser.parse_args() 17 | 18 | pipe = DiffusionPipeline.from_pretrained(args.model, torch_dtype=torch.float16, safety_checker=None) 19 | pipe = pipe.to("cuda") 20 | 21 | if args.prompt_json_path: 22 | prompts = json.load(open(args.prompt_json_path,'r',encoding='utf-8')) 23 | 24 | os.makedirs(args.output_dir, exist_ok=True) 25 | 26 | end = args.start + args.num 27 | for i in range(args.start, end, args.batch_size): 28 | size = min(args.batch_size, end - i) 29 | if args.prompt_json_path: 30 | images = pipe(prompts[i:i+size]).images 31 | else: 32 | images = pipe([args.prompt] * size).images 33 | for j, image in enumerate(images): 34 | image.save(f'{args.output_dir}/{i+j}.jpg') 35 | -------------------------------------------------------------------------------- /code/others/blip_caption.py: -------------------------------------------------------------------------------- 1 | import os 2 | import re 3 | import random 4 | import fnmatch 5 | import argparse 6 | import torch 7 | from PIL import Image, ImageOps 8 | from transformers import BlipProcessor, BlipForConditionalGeneration 9 | import threading 10 | 11 | 12 | def blip_caption(model_path, image_paths, device): 13 | # Load the model and processor 14 | processor = BlipProcessor.from_pretrained(model_path) 15 | model = BlipForConditionalGeneration.from_pretrained(model_path, torch_dtype=torch.float16).to(device) 16 | 17 | for img_path in image_paths: 18 | if os.path.exists(img_path + '.txt'): 19 | continue 20 | # Open and process the image 21 | raw_image = ImageOps.exif_transpose(Image.open(img_path).convert('RGB')) 22 | 23 | # Conditional image captioning 24 | inputs = processor(raw_image, text='this is', return_tensors="pt").to(device, torch.float16) 25 | 26 | out = model.generate(**inputs) 27 | caption = processor.decode(out[0], skip_special_tokens=True) 28 | 29 | caption = re.sub(r'^this is\s*','',caption).capitalize() 30 | print(caption) 31 | 32 | # Write the caption to a .txt file 33 | with open(img_path + '.txt', 'w') as f: 34 | f.write(caption) 35 | 36 | def random_partition(input_list, num_partitions): 37 | random.shuffle(input_list) 38 | list_length = len(input_list) 39 | partition_size = list_length // num_partitions 40 | random_partitioned_lists = [input_list[i * partition_size: (i + 1) * partition_size] for i in range(num_partitions)] 41 | return random_partitioned_lists 42 | 43 | def find_all_images(root_dir): 44 | img_files = [] 45 | for foldername, subfolders, filenames in os.walk(root_dir): 46 | for extension in ['*.jpg', '*.jpeg', '*.png']: 47 | for filename in fnmatch.filter(filenames, extension): 48 | file_path = os.path.join(foldername, filename) 49 | img_files.append(file_path) 50 | return img_files 51 | 52 | 53 | if __name__ == '__main__': 54 | parser = argparse.ArgumentParser() 55 | parser.add_argument('model_path', type=str, default='Salesforce/blip-image-captioning-large') 56 | parser.add_argument('data_dir', type=str) 57 | args = parser.parse_args() 58 | 59 | image_paths = find_all_images(args.data_dir) 60 | threads = [] 61 | random.shuffle(image_paths) 62 | threads.append(threading.Thread(target=blip_caption, args=(args.model_path, image_paths, 'cuda'))) 63 | 64 | for thread in threads: 65 | 
thread.start() 66 | for thread in threads: 67 | thread.join() 68 | -------------------------------------------------------------------------------- /code/VLEU Calculation/cal_vleu_openclip.py: -------------------------------------------------------------------------------- 1 | import os 2 | import json 3 | import argparse 4 | import torch 5 | import torch.nn.functional as F 6 | import numpy as np 7 | from PIL import Image 8 | from tqdm import tqdm 9 | import open_clip 10 | 11 | 12 | device = 'cuda' 13 | 14 | def calculate_vleu(image_dir, prompts, temperature=0.01): 15 | with torch.no_grad(): 16 | text_embs = [] 17 | for prompt in tqdm(prompts): 18 | inputs = tokenizer([prompt]).to(device) 19 | outputs = model.encode_text(inputs) 20 | outputs /= outputs.norm(dim=-1, keepdim=True) 21 | text_embs.append(outputs) 22 | 23 | img_embs = [] 24 | img_paths = [os.path.join(image_dir, f) for f in os.listdir(image_dir)] 25 | for img_path in tqdm(img_paths): 26 | image = Image.open(img_path) 27 | inputs = preprocess(image).unsqueeze(0).to(device) 28 | outputs = model.encode_image(inputs) 29 | outputs /= outputs.norm(dim=-1, keepdim=True) 30 | img_embs.append(outputs) 31 | 32 | prob_matrix = [] 33 | for i in range(len(img_embs)): 34 | cosine_sim = [] 35 | for j in range(len(text_embs)): 36 | cosine_sim.append(img_embs[i] @ text_embs[j].T) 37 | prob = F.softmax(torch.tensor(cosine_sim) / temperature, dim=0) 38 | prob_matrix.append(prob) 39 | 40 | prob_matrix = torch.stack(prob_matrix) 41 | 42 | # marginal distribution for text embeddings 43 | text_emb_marginal_distribution = prob_matrix.sum(axis=0) / prob_matrix.shape[0] 44 | 45 | # KL divergence for each image 46 | image_kl_divergences = [] 47 | for i in range(prob_matrix.shape[0]): 48 | kl_divergence = (prob_matrix[i, :] * torch.log(prob_matrix[i, :] / text_emb_marginal_distribution)).sum().item() 49 | image_kl_divergences.append(kl_divergence) 50 | 51 | vleu_score = np.exp(sum(image_kl_divergences) / prob_matrix.shape[0]) 52 | return vleu_score 53 | 54 | 55 | if __name__ == '__main__': 56 | parser = argparse.ArgumentParser() 57 | parser.add_argument('--model', type=str) 58 | parser.add_argument('--pretrained', type=str) 59 | parser.add_argument('--prompt_json_path', type=str) 60 | parser.add_argument('--image_dir', type=str) 61 | args = parser.parse_args() 62 | 63 | model, _, preprocess = open_clip.create_model_and_transforms(args.model, pretrained=args.pretrained) 64 | tokenizer = open_clip.get_tokenizer(args.model) 65 | model = model.to(device) 66 | 67 | with open(args.prompt_json_path, 'r', encoding='utf-8') as f: 68 | prompts = json.load(f) 69 | 70 | score = calculate_vleu(args.image_dir, prompts, 0.01) 71 | print(score) 72 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # VLEU: a Method for Automatic Evaluation for Generalizability of Text-to-Image Models 2 | 3 | This repository contains the code and data for the paper "VLEU: a Method for Automatic Evaluation for Generalizability of Text-to-Image Models" presented at EMNLP 2024. 4 | 5 | ## Files 6 | 7 | - `code/`: Contains the Python scripts and notebooks used for the project. 8 | - `data/`: Contains the datasets used for training and evaluation. 9 | 10 | ## Quick Start 11 | 12 | ### Sampling Prompts 13 | 14 | #### 1. 
GPT 15 | 16 | `code/Visual Text Sampling/sample_prompts_gpt.py` 17 | ```sh 18 | python sample_prompts_gpt.py \ 19 | --api_key YOUR_KEY \ 20 | --n_prompts 1000 \ 21 | --output prompts.json \ 22 | --n_threads 50 \ 23 | --step 30 \ 24 | --key_word dog 25 | ``` 26 | 27 | #### 2. LLaMA 28 | 29 | `code/Visual Text Sampling/sample_prompts_llama.py` 30 | ```sh 31 | python sample_prompts_llama.py \ 32 | --model meta-llama/Llama-2-13b-chat-hf \ 33 | --n_prompts 1000 \ 34 | --output prompts.json \ 35 | --num_return_sequences 2 36 | ``` 37 | 38 | ### T2I Generation 39 | 40 | #### 1. Stable Diffusion 41 | 42 | `code/Visual Text Sampling/text2img_sd.py` 43 | ```sh 44 | python text2img_sd.py \ 45 | --model stabilityai/stable-diffusion-2-1 \ 46 | --prompt_json_path prompts.json \ 47 | --output_dir image_output \ 48 | --num 1000 \ 49 | --batch_size 4 50 | ``` 51 | 52 | #### 2. Stable Diffusion XL 53 | 54 | `code/Visual Text Sampling/text2img_sdxl.py` 55 | ```sh 56 | python text2img_sdxl.py \ 57 | --model stabilityai/stable-diffusion-xl-base-1.0 \ 58 | --prompt_json_path prompts.json \ 59 | --output_dir image_output \ 60 | --num 1000 \ 61 | --batch_size 4 62 | ``` 63 | 64 | ### VLEU Calculation 65 | 66 | #### 1. CLIP 67 | 68 | `code/VLEU Calculation/cal_vleu_clip.py` 69 | ```sh 70 | python cal_vleu_clip.py \ 71 | --model openai/clip-vit-base-patch16 \ 72 | --prompt_json_path prompts.json \ 73 | --image_dir image_output 74 | ``` 75 | 76 | #### 2. OpenCLIP 77 | 78 | `code/VLEU Calculation/cal_vleu_openclip.py` 79 | ```sh 80 | python cal_vleu_openclip.py \ 81 | --model ViT-L-14 \ 82 | --pretrained open_clip_pytorch_model.bin \ 83 | --prompt_json_path prompts.json \ 84 | --image_dir image_output 85 | ``` 86 | 87 | ## Citation 88 | 89 | If you use this code or data in your research, please cite our paper: 90 | 91 | ```bibtex 92 | @misc{cao2024vleumethodautomaticevaluation, 93 | title={VLEU: a Method for Automatic Evaluation for Generalizability of Text-to-Image Models}, 94 | author={Jingtao Cao and Zheng Zhang and Hongru Wang and Kam-Fai Wong}, 95 | year={2024}, 96 | eprint={2409.14704}, 97 | archivePrefix={arXiv}, 98 | primaryClass={cs.CV}, 99 | url={https://arxiv.org/abs/2409.14704}, 100 | } 101 | ``` -------------------------------------------------------------------------------- /code/Visual Text Sampling/sample_prompts_gpt.py: -------------------------------------------------------------------------------- 1 | import os, json 2 | import argparse 3 | from threading import Thread 4 | from langchain.schema import HumanMessage, SystemMessage, AIMessage 5 | from langchain.chat_models import ChatOpenAI 6 | 7 | 8 | def get_prompts(num, key_word=None, property=None): 9 | for i in range(0, num, args.step): 10 | if key_word and property: 11 | system_input = f'Please imagine a picture of {key_word} and describe it in one sentence, making sure to include the word "{key_word}" and words about {property}.' 12 | elif key_word: 13 | system_input = f'Please imagine a picture of {key_word} and describe it in one sentence, making sure to include the word "{key_word}".' 14 | else: 15 | system_input = 'Please imagine a random picture and describe it in one sentence.' 
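        # Seed a fresh chat with the system instruction, then repeatedly append
        # the previous reply and an "Again" turn so the model produces a new
        # description each iteration; replies that omit the required key word
        # are resampled before being kept.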
16 | human_input = [ 17 | SystemMessage(content=system_input) 18 | ] 19 | ai_output = llm(human_input) 20 | n = 0 21 | limit = min(args.step, num-i) 22 | while True: 23 | human_input.append(AIMessage(content=ai_output.content)) 24 | human_input.append(HumanMessage(content='Again')) 25 | ai_output = llm(human_input) 26 | while key_word and key_word not in ai_output.content: 27 | ai_output = llm(human_input) 28 | prompts.append(ai_output.content) 29 | print(ai_output.content) 30 | n += 1 31 | if n >= limit: 32 | break 33 | 34 | 35 | if __name__ == '__main__': 36 | parser = argparse.ArgumentParser() 37 | parser.add_argument('--api_key', type=str) 38 | parser.add_argument('--api_base', type=str, default=None) 39 | parser.add_argument('--n_prompts', type=int) 40 | parser.add_argument('--key_word', type=str, default=None) 41 | parser.add_argument('--output', type=str) 42 | parser.add_argument('--model_name', type=str, default='gpt-3.5-turbo') 43 | parser.add_argument('--temperature', type=float, default=0.3) 44 | parser.add_argument('--n_threads', type=int, default=1) 45 | parser.add_argument('--step', type=int, default=50) 46 | args = parser.parse_args() 47 | # global environment 48 | os.environ['OPENAI_API_KEY'] = args.api_key 49 | if args.api_base: 50 | os.environ['OPENAI_API_BASE'] = args.api_base 51 | # llm initialization 52 | llm = ChatOpenAI(model_name=args.model_name, temperature=args.temperature) 53 | # get prompts 54 | prompts = [] 55 | threads = [] 56 | step = args.n_prompts // args.n_threads 57 | for i in range(0, args.n_prompts, step): 58 | num = min(step, args.n_prompts-i) 59 | threads.append(Thread(target=get_prompts, args=(step, args.key_word))) 60 | for t in threads: 61 | t.start() 62 | for t in threads: 63 | t.join() 64 | with open(args.output,'w',encoding='utf-8') as f: 65 | json.dump(prompts, f) 66 | -------------------------------------------------------------------------------- /code/Visual Text Sampling/sample_prompts_llama.py: -------------------------------------------------------------------------------- 1 | import json 2 | import argparse 3 | import torch 4 | import transformers 5 | from transformers import AutoTokenizer 6 | from tqdm import tqdm 7 | 8 | 9 | def get_prompt(prompt, num_return_sequences): 10 | prompt_len = len(prompt) 11 | sequences = pipeline( 12 | prompt, 13 | do_sample=True, 14 | temperature=0.3, 15 | top_k=10, 16 | num_return_sequences=num_return_sequences, 17 | eos_token_id=tokenizer.eos_token_id, 18 | max_length=2048 19 | ) 20 | responses = [seq['generated_text'][prompt_len:].strip() for seq in sequences] 21 | return responses 22 | 23 | 24 | if __name__ == '__main__': 25 | parser = argparse.ArgumentParser() 26 | parser.add_argument('--model', type=str) 27 | parser.add_argument('--n_prompts', type=int) 28 | parser.add_argument('--key_word', type=str, default=None) 29 | parser.add_argument('--output', type=str) 30 | parser.add_argument('--num_return_sequences', type=int, default=3) 31 | args = parser.parse_args() 32 | 33 | tokenizer = AutoTokenizer.from_pretrained(args.model) 34 | pipeline = transformers.pipeline( 35 | "text-generation", 36 | model=args.model, 37 | torch_dtype=torch.bfloat16, 38 | device_map="auto", 39 | ) 40 | 41 | system_prompt = '' 42 | 43 | if args.key_word: 44 | prompt = f'[INST] <>\n\n<>\n\nPlease imagine a random picture and describe it in one sentence. [/INST] A lone wolf stands proudly atop a snow-covered mountain peak, its piercing gaze reflecting both strength and solitude. 
[INST] Again [/INST]' 45 | response = 'A lone tree stands tall in a vast, snowy landscape, its branches adorned with delicate icicles glistening in the winter sunlight.' 46 | else: 47 | prompt = f'[INST] <>\n\n<>\n\nPlease imagine a random picture and describe it in one sentence:' 48 | response = 'A lone tree stands tall in a vast, snowy landscape, its branches adorned with delicate icicles glistening in the winter sunlight.' 49 | 50 | prompts = [] 51 | prompt_stack = [] 52 | 53 | user_msg = 'Again' 54 | prompt += f' {response} [INST] {user_msg} [/INST]' 55 | prompt_stack.append(prompt) 56 | 57 | n_pad_turn = 3 58 | n_pad = args.num_return_sequences ** (n_pad_turn + 1) - args.num_return_sequences 59 | n_prompts = args.n_prompts + n_pad 60 | 61 | progress_bar = tqdm(total=n_prompts) 62 | while len(prompts) < n_prompts: 63 | prompt = prompt_stack.pop(0) 64 | responses = get_prompt(prompt, args.num_return_sequences) 65 | prompts.extend(responses) 66 | prompt_stack.extend([f'{prompt} {resp} [INST] {user_msg} [/INST]' for resp in responses]) 67 | progress_bar.update(len(responses)) 68 | 69 | with open(args.output, 'w', encoding='utf-8') as f: 70 | json.dump(prompts[n_pad:n_prompts], f) 71 | -------------------------------------------------------------------------------- /code/VLEU Calculation/cal_vleu_clip.py: -------------------------------------------------------------------------------- 1 | import os 2 | import json 3 | import argparse 4 | import torch 5 | import torch.nn.functional as F 6 | import numpy as np 7 | from PIL import Image 8 | from tqdm import tqdm 9 | from transformers import CLIPProcessor, CLIPModel 10 | 11 | 12 | device = 'cuda' 13 | 14 | def calculate_vleu(image_dir, prompts, temperature=0.01): 15 | with torch.no_grad(): 16 | text_embs = [] 17 | for prompt in tqdm(prompts): 18 | inputs = processor([prompt], return_tensors='pt', truncation=True) 19 | inputs['input_ids'] = inputs['input_ids'].to(device) 20 | inputs['attention_mask'] = inputs['attention_mask'].to(device) 21 | outputs = model.get_text_features(**inputs) 22 | outputs /= outputs.norm(dim=-1, keepdim=True) 23 | text_embs.append(outputs) 24 | 25 | img_embs = [] 26 | img_paths = [os.path.join(image_dir, f) for f in os.listdir(image_dir)] 27 | for img_path in tqdm(img_paths): 28 | image = Image.open(img_path) 29 | inputs = processor(images=image, return_tensors='pt') 30 | inputs['pixel_values'] = inputs['pixel_values'].to(device) 31 | outputs = model.get_image_features(**inputs) 32 | outputs /= outputs.norm(dim=-1, keepdim=True) 33 | img_embs.append(outputs) 34 | 35 | prob_matrix = [] 36 | for i in range(len(img_embs)): 37 | cosine_sim = [] 38 | for j in range(len(text_embs)): 39 | cosine_sim.append(img_embs[i] @ text_embs[j].T) 40 | prob = F.softmax(torch.tensor(cosine_sim) / temperature, dim=0) 41 | prob_matrix.append(prob) 42 | 43 | prob_matrix = torch.stack(prob_matrix) 44 | 45 | # marginal distribution for text embeddings 46 | text_emb_marginal_distribution = prob_matrix.sum(axis=0) / prob_matrix.shape[0] 47 | 48 | # KL divergence for each image 49 | image_kl_divergences = [] 50 | for i in range(prob_matrix.shape[0]): 51 | kl_divergence = (prob_matrix[i, :] * torch.log(prob_matrix[i, :] / text_emb_marginal_distribution)).sum().item() 52 | image_kl_divergences.append(kl_divergence) 53 | 54 | vleu_score = np.exp(sum(image_kl_divergences) / prob_matrix.shape[0]) 55 | return vleu_score 56 | 57 | 58 | if __name__ == '__main__': 59 | parser = argparse.ArgumentParser() 60 | parser.add_argument('--model', type=str) 61 | 
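    # VLEU (see calculate_vleu above): each image gets a softmax distribution
    # over the sampled prompts, P(t_j | v_i) = softmax_j(cos(v_i, t_j) / T);
    # the score is exp of the mean KL divergence between these per-image
    # distributions and their marginal averaged over all images.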
parser.add_argument('--prompt_json_path', type=str) 62 | parser.add_argument('--image_dir', type=str) 63 | args = parser.parse_args() 64 | 65 | model = CLIPModel.from_pretrained(args.model, local_files_only=True).to(device) 66 | processor = CLIPProcessor.from_pretrained(args.model, local_files_only=True) 67 | 68 | with open(args.prompt_json_path, 'r', encoding='utf-8') as f: 69 | prompts = json.load(f) 70 | 71 | score = calculate_vleu(args.image_dir, prompts, 0.01) 72 | print(score) 73 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | share/python-wheels/ 24 | *.egg-info/ 25 | .installed.cfg 26 | *.egg 27 | MANIFEST 28 | 29 | # PyInstaller 30 | # Usually these files are written by a python script from a template 31 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 32 | *.manifest 33 | *.spec 34 | 35 | # Installer logs 36 | pip-log.txt 37 | pip-delete-this-directory.txt 38 | 39 | # Unit test / coverage reports 40 | htmlcov/ 41 | .tox/ 42 | .nox/ 43 | .coverage 44 | .coverage.* 45 | .cache 46 | nosetests.xml 47 | coverage.xml 48 | *.cover 49 | *.py,cover 50 | .hypothesis/ 51 | .pytest_cache/ 52 | cover/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | .pybuilder/ 76 | target/ 77 | 78 | # Jupyter Notebook 79 | .ipynb_checkpoints 80 | 81 | # IPython 82 | profile_default/ 83 | ipython_config.py 84 | 85 | # pyenv 86 | # For a library or package, you might want to ignore these files since the code is 87 | # intended to run in multiple environments; otherwise, check them in: 88 | # .python-version 89 | 90 | # pipenv 91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 94 | # install all needed dependencies. 95 | #Pipfile.lock 96 | 97 | # poetry 98 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. 99 | # This is especially recommended for binary packages to ensure reproducibility, and is more 100 | # commonly ignored for libraries. 101 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control 102 | #poetry.lock 103 | 104 | # pdm 105 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. 106 | #pdm.lock 107 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it 108 | # in version control. 109 | # https://pdm.fming.dev/latest/usage/project/#working-with-version-control 110 | .pdm.toml 111 | .pdm-python 112 | .pdm-build/ 113 | 114 | # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow and github.com/pdm-project/pdm 115 | __pypackages__/ 116 | 117 | # Celery stuff 118 | celerybeat-schedule 119 | celerybeat.pid 120 | 121 | # SageMath parsed files 122 | *.sage.py 123 | 124 | # Environments 125 | .env 126 | .venv 127 | env/ 128 | venv/ 129 | ENV/ 130 | env.bak/ 131 | venv.bak/ 132 | 133 | # Spyder project settings 134 | .spyderproject 135 | .spyproject 136 | 137 | # Rope project settings 138 | .ropeproject 139 | 140 | # mkdocs documentation 141 | /site 142 | 143 | # mypy 144 | .mypy_cache/ 145 | .dmypy.json 146 | dmypy.json 147 | 148 | # Pyre type checker 149 | .pyre/ 150 | 151 | # pytype static type analyzer 152 | .pytype/ 153 | 154 | # Cython debug symbols 155 | cython_debug/ 156 | 157 | # PyCharm 158 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can 159 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore 160 | # and can be added to the global gitignore or merged into this file. For a more nuclear 161 | # option (not recommended) you can uncomment the following to ignore the entire idea folder. 162 | #.idea/ 163 | -------------------------------------------------------------------------------- /code/Text-Image Scoring/clip_score.py: -------------------------------------------------------------------------------- 1 | from diffusers import StableDiffusionPipeline, DiffusionPipeline 2 | from transformers import CLIPProcessor, CLIPModel 3 | import argparse 4 | import torch 5 | import torch.nn.functional as F 6 | import numpy as np 7 | import json 8 | import os 9 | 10 | 11 | def get_score(pipe, model, processor, prompts, batch_size=1, output_dir=None, temperature=0.01, device='cuda'): 12 | if output_dir: 13 | os.makedirs(output_dir, exist_ok=True) 14 | 15 | model = model.to(device) 16 | 17 | with torch.no_grad(): 18 | text_embs = [] 19 | img_embs = [] 20 | for prompt in prompts: 21 | inputs = processor([prompt], return_tensors='pt', truncation=True) 22 | inputs['input_ids'] = inputs['input_ids'].to(device) 23 | inputs['attention_mask'] = inputs['attention_mask'].to(device) 24 | outputs = model.get_text_features(**inputs) 25 | outputs /= outputs.norm(dim=-1, keepdim=True) 26 | text_embs.append(outputs) 27 | 28 | 29 | for i in range(0, len(prompts), batch_size): 30 | with torch.autocast("cuda"): 31 | images = pipe(prompts[i:i+batch_size]).images 32 | for j, image in enumerate(images): 33 | if output_dir: 34 | image.save(f'{output_dir}/{i+j}.jpg') 35 | inputs = processor(images=image, return_tensors='pt') 36 | inputs['pixel_values'] = inputs['pixel_values'].to(device) 37 | outputs = model.get_image_features(**inputs) 38 | outputs /= outputs.norm(dim=-1, keepdim=True) 39 | img_embs.append(outputs) 40 | 41 | prob_matrix = [] 42 | for i in range(len(img_embs)): 43 | cosine_sim = [] 44 | for j in range(len(text_embs)): 45 | cosine_sim.append(img_embs[i] @ text_embs[j].T) 46 | prob = F.softmax(torch.tensor(cosine_sim) / temperature, dim=0) 47 | prob_matrix.append(prob) 48 | 49 | prob_matrix = torch.stack(prob_matrix) 50 | 51 | # marginal distribution for text embeddings 52 | text_emb_marginal_distribution = prob_matrix.sum(axis=0) / prob_matrix.shape[0] 53 | 54 | # KL divergence for each image 55 | image_kl_divergences = [] 56 | for i in range(prob_matrix.shape[0]): 57 | kl_divergence = (prob_matrix[i, :] * torch.log(prob_matrix[i, :] / text_emb_marginal_distribution)).sum().item() 58 | image_kl_divergences.append(kl_divergence) 59 | 60 | vleu_score = 
np.exp(sum(image_kl_divergences) / prob_matrix.shape[0]) 61 | 62 | model = model.to('cpu') 63 | 64 | return vleu_score 65 | 66 | 67 | if __name__ == '__main__': 68 | parser = argparse.ArgumentParser() 69 | parser.add_argument('--clip_model', type=str) 70 | parser.add_argument('--model', type=str) 71 | parser.add_argument('--prompt_json_path', type=str) 72 | parser.add_argument('--image_dir', type=str) 73 | parser.add_argument('--batch_size', type=int, default=1) 74 | parser.add_argument('--n_prompt', type=int) 75 | parser.add_argument('--start', type=int, default=0) 76 | parser.add_argument('--sdxl', default=False, action='store_true') 77 | args = parser.parse_args() 78 | 79 | if args.sdxl: 80 | pipe = DiffusionPipeline.from_pretrained(args.model, torch_dtype=torch.float16, safety_checker=None) 81 | else: 82 | pipe = StableDiffusionPipeline.from_pretrained(args.model, torch_dtype=torch.float16, safety_checker=None) 83 | pipe = pipe.to("cuda") 84 | 85 | model = CLIPModel.from_pretrained(args.clip_model) 86 | processor = CLIPProcessor.from_pretrained(args.clip_model) 87 | 88 | with open(args.prompt_json_path,'r',encoding='utf-8') as f: 89 | prompts = json.load(f) 90 | 91 | final_score = get_score(pipe, model, processor, prompts[args.start:args.start+args.n_prompt], args.batch_size, args.image_dir) 92 | print(f'Score: {final_score}') 93 | -------------------------------------------------------------------------------- /data/Constrained Subject/teddy_bear.json: -------------------------------------------------------------------------------- 1 | ["A fluffy, cream-colored teddy bear with a playful expression, wearing a polka dot bowtie.", "Sitting on a shelf, a vintage teddy bear with worn fur and a stitched smile exudes nostalgia and warmth.", "A small, fluffy teddy bear with soft, beige fur, wearing a tiny red bow around its neck, waiting to bring comfort and joy.", "Resting against a pile of colorful pillows, a worn and well-loved teddy bear with a faded bowtie and a stitched smile exudes warmth and nostalgia.", "Sitting on a shelf, a small, worn teddy bear with patches on its fur and a faded red bow around its neck exudes a timeless charm.", "A small, fluffy teddy bear with golden fur, adorned with a red bowtie, sitting on a bed, inviting endless snuggles.", "A small, fluffy teddy bear with golden fur, wearing a red bowtie and holding a heart-shaped pillow.", "A small, fluffy teddy bear with golden fur, wearing a polka dot bowtie and holding a heart-shaped pillow.", "A small, plush teddy bear with soft, golden fur, wearing a tiny blue bow around its neck, inviting warm hugs and endless snuggles.", "Resting against a pile of soft pillows, a well-loved, threadbare teddy bear with floppy arms and a worn-out bowtie exudes a timeless charm.", "In the picture, a small child holds a worn-out teddy bear tightly, finding comfort in its familiar embrace.", "A small, fluffy teddy bear with golden fur, wearing a polka dot bowtie and holding a heart-shaped pillow.", "A fluffy, cream-colored teddy bear with velvety paws, sparkling button eyes, and a sweet embroidered smile.", "A small, fluffy teddy bear with golden fur, adorned with a satin ribbon around its neck, sitting peacefully on a shelf.", "A small, fluffy teddy bear with soft, light brown fur, wearing a tiny red bowtie and sitting on a plaid cushion, ready to be loved and cherished.", "Resting against a pile of soft pillows, a huggable teddy bear with velvety brown fur, a shiny red bowtie, and a friendly expression invites endless cuddles and 
comforting companionship.", "A cuddly, brown teddy bear with button eyes and a soft, velvety texture eagerly awaits a loving embrace.", "A small, fluffy teddy bear with golden fur, wearing a red bowtie and holding a heart-shaped pillow.", "A cuddly, cream-colored teddy bear with a satin bowtie, patiently sitting on a shelf, ready to be cherished.", "A small, adorable teddy bear with light brown fur, button eyes, and a tiny red bowtie.", "Nestled among a pile of colorful toys, a worn and well-loved teddy bear with button eyes and a faded bow exudes comfort and nostalgia.", "A small, fluffy teddy bear with soft, golden fur, wearing a tiny blue bow around its neck, eagerly awaiting a warm embrace.", "A small, fluffy teddy bear with golden fur, wearing a dainty bowtie, sitting on a shelf surrounded by other cherished toys.", "A small, fluffy teddy bear with light brown fur, adorable button eyes, and a sweet smile, eagerly awaiting a loving hug.", "A small, fluffy teddy bear with soft, caramel-colored fur, wearing a dainty bowtie and sitting on a vintage rocking chair.", "Resting against a pile of soft pillows, a huggable teddy bear with velvety brown fur and a playful bowtie eagerly awaits its next adventure.", "A small, fluffy teddy bear with soft, golden fur, wearing a tiny blue bow around its neck, inviting endless hugs and snuggles.", "In this adorable picture, a small child clutches a well-loved teddy bear tightly, finding solace and companionship in its comforting presence.", "A fluffy, cream-colored teddy bear with velvety paws, sparkling black eyes, and a sweet, embroidered smile.", "Resting against a pile of soft pillows, a huggable teddy bear with chocolate-brown fur and a friendly smile eagerly awaited its next cuddle.", "A small, fluffy teddy bear with golden fur, wearing a dainty bowtie and clutching a tiny bouquet of flowers.", "Resting against a pile of colorful pillows, a huggable teddy bear with soft, caramel fur and a velvety nose eagerly awaits its next adventure.", "Resting against a pile of soft pillows, a huggable teddy bear with golden fur and a red bowtie sat, exuding comfort and companionship.", "Resting on a child's bed, a worn-out teddy bear with faded fur and a patchwork heart stitched on its chest brings comfort and nostalgia.", "Resting against a pile of pastel pillows, a vintage teddy bear with worn fur and a faded bow exudes timeless charm and cherished memories.", "A fluffy, cream-colored teddy bear with button eyes, a stitched nose, and outstretched arms, ready for a warm embrace.", "A cuddly, blue teddy bear with a satin bow, patiently waiting for a warm hug.", "A small, fluffy teddy bear with golden fur, wearing a red bowtie and sitting on a pile of colorful cushions.", "Nestled on a shelf, a brand new teddy bear with fluffy brown fur and a satin bow eagerly awaits its first cuddle.", "A fluffy, cream-colored teddy bear with velvety paws, wearing a dainty pink bow around its neck, sitting on a pile of colorful, confetti-filled balloons.", "A small, plush teddy bear with soft, golden fur, hugging a tiny bouquet of colorful flowers.", "A small, adorable teddy bear with light brown fur, button eyes, and a tiny stitched nose, patiently waiting for a warm embrace.", "Resting against a stack of storybooks, a well-loved teddy bear with soft, chocolate-brown fur and a velvety nose exudes warmth and companionship.", "In the picture, a cuddly, light brown teddy bear with button eyes and a satin bowtie sits on a grassy meadow, surrounded by colorful wildflowers, 
exuding warmth and comfort.", "In a sunlit meadow, a worn and well-loved teddy bear, adorned with a faded bow and stitched smile, sits atop a picnic blanket, patiently waiting for a child's warm embrace.", "A fluffy, adorable teddy bear with a bowtie and a mischievous twinkle in its button eyes eagerly awaits a new friend to share endless hugs and adventures with.", "A small, fluffy teddy bear with soft, caramel-colored fur, wearing a tiny bowtie and eagerly waiting for a warm embrace.", "A huggable teddy bear with velvety, chocolate-brown fur, adorable button eyes, and a heart-shaped nose, inviting warm embraces and endless snuggles.", "Resting on a sunlit windowsill, a worn and well-loved teddy bear with matted fur and a faded bow sits patiently, a silent guardian of cherished childhood memories.", "A well-loved, worn-out teddy bear with faded brown fur, missing an eye, and patched up with colorful stitches.", "A fluffy, cream-colored teddy bear with velvety paws and a twinkling button nose, sitting patiently with arms outstretched for a warm embrace.", "A cuddly, light brown teddy bear with button eyes and a soft, velvety texture sits patiently on a child's bed, eagerly waiting for bedtime snuggles and sweet dreams.", "A vintage teddy bear with worn, caramel-colored fur, stitched paws, and a well-loved expression on its face.", "In the picture, a group of teddy bears of various colors and sizes gather for a tea party, their adorable expressions filled with joy and playfulness.", "A vintage, well-loved teddy bear with worn-out patches, stitched together with love, sitting proudly on a rocking chair, reminiscing about countless childhood adventures.", "Nestled among a sea of stuffed animals, a plush teddy bear with soft, chocolate-brown fur and a friendly expression invites endless hugs and companionship.", "Perched on a wooden rocking chair, a fluffy teddy bear with floppy ears and a satin ribbon around its neck patiently awaits its next adventure with a gleeful sparkle in its button eyes.", "A well-loved teddy bear, worn with time, featuring patches of different fabrics and a heart-shaped button on its chest, holding cherished memories within its tattered seams.", "A well-loved, worn-out teddy bear with matted fur, missing an eye, and a stitched-up arm, but still cherished for all the memories it holds.", "A small, fluffy teddy bear with soft, light brown fur, adorable button eyes, and a sweet, embroidered smile, sitting patiently with its paws outstretched for a warm embrace.", "A well-loved teddy bear with worn fur, missing an eye, and a stitched-up paw, exuding a timeless charm and cherished memories.", "A well-loved teddy bear with faded golden fur, missing an eye and a few stitches, proudly displaying the wear and tear of countless adventures and comforting embraces.", "Nestled among a sea of vibrant wildflowers, a small, velvety teddy bear with a bowtie and a twinkle in its eyes brings comfort and joy to all who lay eyes upon it.", "Perched on a child's bed, a plush, rosy-cheeked teddy bear with twinkling black eyes and velvety fur patiently waits for bedtime stories and gentle hugs.", "A small, adorable teddy bear with golden fur and a red bowtie sits on a shelf, patiently waiting to be cuddled and cherished.", "A well-loved teddy bear, with worn-out fur and a few missing stitches, but still radiating warmth and comfort.", "A well-loved teddy bear, worn from years of affection, with button eyes and a patchwork of colorful patches, holding onto cherished memories.", "A huggable 
teddy bear with velvety, chocolate-brown fur, holding a tiny bouquet of wildflowers in its paws.", "Perched on a sunlit windowsill, a plush teddy bear with velvety fur and a heart-shaped nose gazes out with a sense of warmth and companionship.", "A well-loved teddy bear with worn-out patches, a faded smile, and a heart-shaped patch on its chest that reads \"Teddy Bear Hugs Forever\".", "Perched on a shelf, a well-loved teddy bear with worn fur and a faded blue ribbon around its neck stood as a cherished reminder of childhood innocence.", "A small, adorable teddy bear with soft, golden fur, sitting on a bed with a mischievous twinkle in its button eyes.", "A vintage, well-loved teddy bear with faded brown fur, missing an eye, and patched with colorful fabric, sitting patiently on a worn-out armchair.", "A small, plush teddy bear with soft, golden fur, a tiny red bowtie, and a mischievous twinkle in its button eyes.", "A worn-out, well-loved teddy bear with patches sewn onto its faded brown fur, sitting proudly on a rocking chair, a testament to years of comforting hugs.", "Nestled among a sea of colorful toys, a cherished teddy bear with plush, golden fur and a heartwarming smile brings joy and comfort to its young owner.", "A small, plush teddy bear with soft, golden fur, wearing a tiny blue sweater and holding a miniature picnic basket.", "A vintage, well-loved teddy bear with worn-out fur, missing an eye, and a patched-up paw, carrying years of cherished memories.", "Perched on a child's bed, a cherished teddy bear with golden fur and a playful twinkle in its button eyes patiently awaited bedtime stories and sweet dreams.", "A well-loved teddy bear with worn-out fur, missing an eye, and patched up with colorful stitches, radiating a timeless charm.", "A small, adorable teddy bear with soft, golden fur, sitting on a bed with a mischievous twinkle in its button eyes.", "Sitting on a shelf, a small, plush teddy bear with a red bowtie and a mischievous twinkle in its button eyes eagerly awaits a new adventure.", "Sitting on a patchwork quilt, a vintage teddy bear with golden fur and a well-loved appearance exudes warmth and nostalgia.", "Perched on a shelf, a cherished teddy bear with golden fur, a stitched nose, and outstretched arms exudes timeless charm, ready to be a steadfast friend through all of life's adventures.", "Perched on a child's bed, a fluffy teddy bear with a mischievous twinkle in its button eyes eagerly awaits bedtime stories and snuggles.", "A well-loved, worn-out teddy bear with faded golden fur, missing one eye, and a patched-up paw, holding onto cherished memories and endless love.", "Perched on a sunlit windowsill, a small, rosy-cheeked teddy bear with velvety fur and a satin bow eagerly awaits its next adventure with a gleam of anticipation in its glassy eyes.", "Perched on a sunlit windowsill, a cherished teddy bear with golden fur and a satin bow around its neck gazes out at the world with a sense of timeless love and warmth.", "A small, adorable teddy bear with soft, golden fur, a tiny red bow around its neck, and a playful expression on its face.", "Perched on a windowsill, a charming teddy bear with golden fur and a satin bow around its neck gazes out at the world with a sense of wonder and innocence.", "A well-loved teddy bear, with worn-out patches and a slightly faded fur, holding a heart-shaped pillow that says \"I love you.\"", "In the picture, a small, adorable teddy bear with soft, caramel-colored fur and a heart-shaped nose sits on a child's bed, 
patiently waiting for bedtime cuddles.", "A small, fuzzy teddy bear with chocolate-brown fur, button eyes, and a stitched smile, sitting on a child's bed surrounded by a collection of other stuffed animals.", "A well-loved teddy bear, worn from years of snuggles, with button eyes and a stitched smile, radiating warmth and nostalgia.", "A small, well-loved teddy bear with faded fur and a slightly crooked smile rests peacefully on a worn-out armchair, embodying years of cherished memories and comforting hugs.", "A beautifully handcrafted teddy bear, adorned with a satin bow and a gentle expression, sits patiently on a shelf, ready to be cherished by its lucky owner.", "Perched atop a child's bed, a plush teddy bear with soft, chocolate-brown fur and a playful twinkle in its button eyes eagerly awaits bedtime cuddles.", "A huggable teddy bear, made of plush, caramel-colored fur, with a velvety nose, twinkling black eyes, and a gentle expression that radiates comfort and love.", "Perched on a child's bed, a well-loved teddy bear with worn fur and a heart-shaped patch on its paw exudes a timeless charm and holds cherished memories within its stitched embrace.", "A small, adorable teddy bear with soft, golden fur, wearing a tiny blue bowtie and clutching a miniature book in its paws."] -------------------------------------------------------------------------------- /code/others/train_text_to_image.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # coding=utf-8 3 | # Copyright 2023 The HuggingFace Inc. team. All rights reserved. 4 | # 5 | # Licensed under the Apache License, Version 2.0 (the "License"); 6 | # you may not use this file except in compliance with the License. 7 | # You may obtain a copy of the License at 8 | # 9 | # http://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
14 | # See the License for the specific language governing permissions and 15 | 16 | import argparse 17 | import logging 18 | import math 19 | import os 20 | import json 21 | import random 22 | import shutil 23 | from pathlib import Path 24 | 25 | import accelerate 26 | import datasets 27 | import numpy as np 28 | import torch 29 | import torch.nn.functional as F 30 | import torch.utils.checkpoint 31 | import transformers 32 | from accelerate import Accelerator 33 | from accelerate.logging import get_logger 34 | from accelerate.state import AcceleratorState 35 | from accelerate.utils import ProjectConfiguration, set_seed 36 | from datasets import load_dataset 37 | from huggingface_hub import create_repo, upload_folder 38 | from packaging import version 39 | from torchvision import transforms 40 | from tqdm.auto import tqdm 41 | from transformers import CLIPTextModel, CLIPTokenizer, CLIPProcessor, CLIPModel 42 | from transformers.utils import ContextManagers 43 | 44 | import diffusers 45 | from diffusers import AutoencoderKL, DDPMScheduler, StableDiffusionPipeline, UNet2DConditionModel 46 | from diffusers.optimization import get_scheduler 47 | from diffusers.training_utils import EMAModel, compute_snr 48 | from diffusers.utils import check_min_version, deprecate, is_wandb_available, make_image_grid 49 | from diffusers.utils.import_utils import is_xformers_available 50 | 51 | from clip_score import get_score 52 | 53 | 54 | if is_wandb_available(): 55 | import wandb 56 | 57 | 58 | # Will error if the minimal version of diffusers is not installed. Remove at your own risks. 59 | check_min_version("0.25.0.dev0") 60 | 61 | logger = get_logger(__name__, log_level="INFO") 62 | 63 | DATASET_NAME_MAPPING = { 64 | "lambdalabs/pokemon-blip-captions": ("image", "text"), 65 | } 66 | 67 | 68 | def save_model_card( 69 | args, 70 | repo_id: str, 71 | images=None, 72 | repo_folder=None, 73 | ): 74 | img_str = "" 75 | if len(images) > 0: 76 | image_grid = make_image_grid(images, 1, len(args.validation_prompts)) 77 | image_grid.save(os.path.join(repo_folder, "val_imgs_grid.png")) 78 | img_str += "![val_imgs_grid](./val_imgs_grid.png)\n" 79 | 80 | yaml = f""" 81 | --- 82 | license: creativeml-openrail-m 83 | base_model: {args.pretrained_model_name_or_path} 84 | datasets: 85 | - {args.dataset_name} 86 | tags: 87 | - stable-diffusion 88 | - stable-diffusion-diffusers 89 | - text-to-image 90 | - diffusers 91 | inference: true 92 | --- 93 | """ 94 | model_card = f""" 95 | # Text-to-image finetuning - {repo_id} 96 | 97 | This pipeline was finetuned from **{args.pretrained_model_name_or_path}** on the **{args.dataset_name}** dataset. 
Below are some example images generated with the finetuned pipeline using the following prompts: {args.validation_prompts}: \n 98 | {img_str} 99 | 100 | ## Pipeline usage 101 | 102 | You can use the pipeline like so: 103 | 104 | ```python 105 | from diffusers import DiffusionPipeline 106 | import torch 107 | 108 | pipeline = DiffusionPipeline.from_pretrained("{repo_id}", torch_dtype=torch.float16) 109 | prompt = "{args.validation_prompts[0]}" 110 | image = pipeline(prompt).images[0] 111 | image.save("my_image.png") 112 | ``` 113 | 114 | ## Training info 115 | 116 | These are the key hyperparameters used during training: 117 | 118 | * Epochs: {args.num_train_epochs} 119 | * Learning rate: {args.learning_rate} 120 | * Batch size: {args.train_batch_size} 121 | * Gradient accumulation steps: {args.gradient_accumulation_steps} 122 | * Image resolution: {args.resolution} 123 | * Mixed-precision: {args.mixed_precision} 124 | 125 | """ 126 | wandb_info = "" 127 | if is_wandb_available(): 128 | wandb_run_url = None 129 | if wandb.run is not None: 130 | wandb_run_url = wandb.run.url 131 | 132 | if wandb_run_url is not None: 133 | wandb_info = f""" 134 | More information on all the CLI arguments and the environment are available on your [`wandb` run page]({wandb_run_url}). 135 | """ 136 | 137 | model_card += wandb_info 138 | 139 | with open(os.path.join(repo_folder, "README.md"), "w") as f: 140 | f.write(yaml + model_card) 141 | 142 | 143 | def log_validation(vae, text_encoder, tokenizer, unet, args, accelerator, weight_dtype, epoch): 144 | logger.info("Running validation... ") 145 | 146 | pipeline = StableDiffusionPipeline.from_pretrained( 147 | args.pretrained_model_name_or_path, 148 | vae=accelerator.unwrap_model(vae), 149 | text_encoder=accelerator.unwrap_model(text_encoder), 150 | tokenizer=tokenizer, 151 | unet=accelerator.unwrap_model(unet), 152 | safety_checker=None, 153 | revision=args.revision, 154 | variant=args.variant, 155 | torch_dtype=weight_dtype, 156 | ) 157 | pipeline = pipeline.to(accelerator.device) 158 | pipeline.set_progress_bar_config(disable=True) 159 | 160 | if args.enable_xformers_memory_efficient_attention: 161 | pipeline.enable_xformers_memory_efficient_attention() 162 | 163 | if args.seed is None: 164 | generator = None 165 | else: 166 | generator = torch.Generator(device=accelerator.device).manual_seed(args.seed) 167 | 168 | images = [] 169 | for i in range(len(args.validation_prompts)): 170 | with torch.autocast("cuda"): 171 | image = pipeline(args.validation_prompts[i], num_inference_steps=20, generator=generator).images[0] 172 | 173 | images.append(image) 174 | 175 | for tracker in accelerator.trackers: 176 | if tracker.name == "tensorboard": 177 | np_images = np.stack([np.asarray(img) for img in images]) 178 | tracker.writer.add_images("validation", np_images, epoch, dataformats="NHWC") 179 | elif tracker.name == "wandb": 180 | tracker.log( 181 | { 182 | "validation": [ 183 | wandb.Image(image, caption=f"{i}: {args.validation_prompts[i]}") 184 | for i, image in enumerate(images) 185 | ] 186 | } 187 | ) 188 | else: 189 | logger.warn(f"image logging not implemented for {tracker.name}") 190 | 191 | del pipeline 192 | torch.cuda.empty_cache() 193 | 194 | return images 195 | 196 | 197 | def parse_args(): 198 | parser = argparse.ArgumentParser(description="Simple example of a training script.") 199 | parser.add_argument( 200 | "--input_perturbation", type=float, default=0, help="The scale of input perturbation. Recommended 0.1." 
201 | ) 202 | parser.add_argument( 203 | "--pretrained_model_name_or_path", 204 | type=str, 205 | default=None, 206 | required=True, 207 | help="Path to pretrained model or model identifier from huggingface.co/models.", 208 | ) 209 | parser.add_argument( 210 | "--clip_model_name_or_path", 211 | type=str, 212 | default=None, 213 | required=True, 214 | help="Path to clip model or model identifier from huggingface.co/models.", 215 | ) 216 | parser.add_argument( 217 | "--revision", 218 | type=str, 219 | default=None, 220 | required=False, 221 | help="Revision of pretrained model identifier from huggingface.co/models.", 222 | ) 223 | parser.add_argument( 224 | "--variant", 225 | type=str, 226 | default=None, 227 | help="Variant of the model files of the pretrained model identifier from huggingface.co/models, 'e.g.' fp16", 228 | ) 229 | parser.add_argument( 230 | "--dataset_name", 231 | type=str, 232 | default=None, 233 | help=( 234 | "The name of the Dataset (from the HuggingFace hub) to train on (could be your own, possibly private," 235 | " dataset). It can also be a path pointing to a local copy of a dataset in your filesystem," 236 | " or to a folder containing files that 🤗 Datasets can understand." 237 | ), 238 | ) 239 | parser.add_argument( 240 | "--dataset_config_name", 241 | type=str, 242 | default=None, 243 | help="The config of the Dataset, leave as None if there's only one config.", 244 | ) 245 | parser.add_argument( 246 | "--train_data_dir", 247 | type=str, 248 | default=None, 249 | help=( 250 | "A folder containing the training data. Folder contents must follow the structure described in" 251 | " https://huggingface.co/docs/datasets/image_dataset#imagefolder. In particular, a `metadata.jsonl` file" 252 | " must exist to provide the captions for the images. Ignored if `dataset_name` is specified." 253 | ), 254 | ) 255 | parser.add_argument( 256 | "--image_column", type=str, default="image", help="The column of the dataset containing an image." 257 | ) 258 | parser.add_argument( 259 | "--caption_column", 260 | type=str, 261 | default="text", 262 | help="The column of the dataset containing a caption or a list of captions.", 263 | ) 264 | parser.add_argument( 265 | "--max_train_samples", 266 | type=int, 267 | default=None, 268 | help=( 269 | "For debugging purposes or quicker training, truncate the number of training examples to this " 270 | "value if set." 
271 | ), 272 | ) 273 | parser.add_argument( 274 | "--validation_prompts", 275 | type=str, 276 | default=None, 277 | nargs="+", 278 | help=("A set of prompts evaluated every `--validation_epochs` and logged to `--report_to`."), 279 | ) 280 | parser.add_argument( 281 | "--validation_prompts_path", 282 | type=str, 283 | default=None 284 | ) 285 | parser.add_argument( 286 | "--output_dir", 287 | type=str, 288 | default="sd-model-finetuned", 289 | help="The output directory where the model predictions and checkpoints will be written.", 290 | ) 291 | parser.add_argument( 292 | "--cache_dir", 293 | type=str, 294 | default=None, 295 | help="The directory where the downloaded models and datasets will be stored.", 296 | ) 297 | parser.add_argument("--seed", type=int, default=None, help="A seed for reproducible training.") 298 | parser.add_argument( 299 | "--resolution", 300 | type=int, 301 | default=512, 302 | help=( 303 | "The resolution for input images, all the images in the train/validation dataset will be resized to this" 304 | " resolution" 305 | ), 306 | ) 307 | parser.add_argument( 308 | "--center_crop", 309 | default=False, 310 | action="store_true", 311 | help=( 312 | "Whether to center crop the input images to the resolution. If not set, the images will be randomly" 313 | " cropped. The images will be resized to the resolution first before cropping." 314 | ), 315 | ) 316 | parser.add_argument( 317 | "--random_flip", 318 | action="store_true", 319 | help="whether to randomly flip images horizontally", 320 | ) 321 | parser.add_argument( 322 | "--train_batch_size", type=int, default=16, help="Batch size (per device) for the training dataloader." 323 | ) 324 | parser.add_argument("--num_train_epochs", type=int, default=100) 325 | parser.add_argument( 326 | "--max_train_steps", 327 | type=int, 328 | default=None, 329 | help="Total number of training steps to perform. If provided, overrides num_train_epochs.", 330 | ) 331 | parser.add_argument( 332 | "--gradient_accumulation_steps", 333 | type=int, 334 | default=1, 335 | help="Number of updates steps to accumulate before performing a backward/update pass.", 336 | ) 337 | parser.add_argument( 338 | "--gradient_checkpointing", 339 | action="store_true", 340 | help="Whether or not to use gradient checkpointing to save memory at the expense of slower backward pass.", 341 | ) 342 | parser.add_argument( 343 | "--learning_rate", 344 | type=float, 345 | default=1e-4, 346 | help="Initial learning rate (after the potential warmup period) to use.", 347 | ) 348 | parser.add_argument( 349 | "--scale_lr", 350 | action="store_true", 351 | default=False, 352 | help="Scale the learning rate by the number of GPUs, gradient accumulation steps, and batch size.", 353 | ) 354 | parser.add_argument( 355 | "--lr_scheduler", 356 | type=str, 357 | default="constant", 358 | help=( 359 | 'The scheduler type to use. Choose between ["linear", "cosine", "cosine_with_restarts", "polynomial",' 360 | ' "constant", "constant_with_warmup"]' 361 | ), 362 | ) 363 | parser.add_argument( 364 | "--lr_warmup_steps", type=int, default=500, help="Number of steps for the warmup in the lr scheduler." 365 | ) 366 | parser.add_argument( 367 | "--snr_gamma", 368 | type=float, 369 | default=None, 370 | help="SNR weighting gamma to be used if rebalancing the loss. Recommended value is 5.0. " 371 | "More details here: https://arxiv.org/abs/2303.09556.", 372 | ) 373 | parser.add_argument( 374 | "--use_8bit_adam", action="store_true", help="Whether or not to use 8-bit Adam from bitsandbytes." 
375 | ) 376 | parser.add_argument( 377 | "--allow_tf32", 378 | action="store_true", 379 | help=( 380 | "Whether or not to allow TF32 on Ampere GPUs. Can be used to speed up training. For more information, see" 381 | " https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices" 382 | ), 383 | ) 384 | parser.add_argument("--use_ema", action="store_true", help="Whether to use EMA model.") 385 | parser.add_argument( 386 | "--non_ema_revision", 387 | type=str, 388 | default=None, 389 | required=False, 390 | help=( 391 | "Revision of pretrained non-ema model identifier. Must be a branch, tag or git identifier of the local or" 392 | " remote repository specified with --pretrained_model_name_or_path." 393 | ), 394 | ) 395 | parser.add_argument( 396 | "--dataloader_num_workers", 397 | type=int, 398 | default=0, 399 | help=( 400 | "Number of subprocesses to use for data loading. 0 means that the data will be loaded in the main process." 401 | ), 402 | ) 403 | parser.add_argument("--adam_beta1", type=float, default=0.9, help="The beta1 parameter for the Adam optimizer.") 404 | parser.add_argument("--adam_beta2", type=float, default=0.999, help="The beta2 parameter for the Adam optimizer.") 405 | parser.add_argument("--adam_weight_decay", type=float, default=1e-2, help="Weight decay to use.") 406 | parser.add_argument("--adam_epsilon", type=float, default=1e-08, help="Epsilon value for the Adam optimizer") 407 | parser.add_argument("--max_grad_norm", default=1.0, type=float, help="Max gradient norm.") 408 | parser.add_argument("--push_to_hub", action="store_true", help="Whether or not to push the model to the Hub.") 409 | parser.add_argument("--hub_token", type=str, default=None, help="The token to use to push to the Model Hub.") 410 | parser.add_argument( 411 | "--prediction_type", 412 | type=str, 413 | default=None, 414 | help="The prediction_type that shall be used for training. Choose between 'epsilon' or 'v_prediction' or leave `None`. If left to `None` the default prediction type of the scheduler: `noise_scheduler.config.prediciton_type` is chosen.", 415 | ) 416 | parser.add_argument( 417 | "--hub_model_id", 418 | type=str, 419 | default=None, 420 | help="The name of the repository to keep in sync with the local `output_dir`.", 421 | ) 422 | parser.add_argument( 423 | "--logging_dir", 424 | type=str, 425 | default="logs", 426 | help=( 427 | "[TensorBoard](https://www.tensorflow.org/tensorboard) log directory. Will default to" 428 | " *output_dir/runs/**CURRENT_DATETIME_HOSTNAME***." 429 | ), 430 | ) 431 | parser.add_argument( 432 | "--mixed_precision", 433 | type=str, 434 | default=None, 435 | choices=["no", "fp16", "bf16"], 436 | help=( 437 | "Whether to use mixed precision. Choose between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >=" 438 | " 1.10.and an Nvidia Ampere GPU. Default to the value of accelerate config of the current system or the" 439 | " flag passed with the `accelerate.launch` command. Use this argument to override the accelerate config." 440 | ), 441 | ) 442 | parser.add_argument( 443 | "--report_to", 444 | type=str, 445 | default="tensorboard", 446 | help=( 447 | 'The integration to report the results and logs to. Supported platforms are `"tensorboard"`' 448 | ' (default), `"wandb"` and `"comet_ml"`. Use `"all"` to report to all integrations.' 
449 | ), 450 | ) 451 | parser.add_argument("--local_rank", type=int, default=-1, help="For distributed training: local_rank") 452 | parser.add_argument( 453 | "--checkpointing_steps", 454 | type=int, 455 | default=500, 456 | help=( 457 | "Save a checkpoint of the training state every X updates. These checkpoints are only suitable for resuming" 458 | " training using `--resume_from_checkpoint`." 459 | ), 460 | ) 461 | parser.add_argument( 462 | "--checkpoints_total_limit", 463 | type=int, 464 | default=None, 465 | help=("Max number of checkpoints to store."), 466 | ) 467 | parser.add_argument( 468 | "--resume_from_checkpoint", 469 | type=str, 470 | default=None, 471 | help=( 472 | "Whether training should be resumed from a previous checkpoint. Use a path saved by" 473 | ' `--checkpointing_steps`, or `"latest"` to automatically select the last available checkpoint.' 474 | ), 475 | ) 476 | parser.add_argument( 477 | "--enable_xformers_memory_efficient_attention", action="store_true", help="Whether or not to use xformers." 478 | ) 479 | parser.add_argument("--noise_offset", type=float, default=0, help="The scale of noise offset.") 480 | parser.add_argument( 481 | "--validation_epochs", 482 | type=int, 483 | default=5, 484 | help="Run validation every X epochs.", 485 | ) 486 | parser.add_argument( 487 | "--validation_steps", 488 | type=int, 489 | default=50, 490 | help="Run validation every X steps.", 491 | ) 492 | parser.add_argument( 493 | "--validation_output_dir", 494 | type=str, 495 | default=None 496 | ) 497 | parser.add_argument( 498 | "--validation_batch_size", 499 | type=int, 500 | default=1 501 | ) 502 | parser.add_argument( 503 | "--tracker_project_name", 504 | type=str, 505 | default="text2image-fine-tune", 506 | help=( 507 | "The `project_name` argument passed to Accelerator.init_trackers for" 508 | " more information see https://huggingface.co/docs/accelerate/v0.17.0/en/package_reference/accelerator#accelerate.Accelerator" 509 | ), 510 | ) 511 | 512 | args = parser.parse_args() 513 | env_local_rank = int(os.environ.get("LOCAL_RANK", -1)) 514 | if env_local_rank != -1 and env_local_rank != args.local_rank: 515 | args.local_rank = env_local_rank 516 | 517 | # Sanity checks 518 | if args.dataset_name is None and args.train_data_dir is None: 519 | raise ValueError("Need either a dataset name or a training folder.") 520 | 521 | # default to using the same revision for the non-ema model if not specified 522 | if args.non_ema_revision is None: 523 | args.non_ema_revision = args.revision 524 | 525 | return args 526 | 527 | 528 | def main(): 529 | args = parse_args() 530 | 531 | if args.non_ema_revision is not None: 532 | deprecate( 533 | "non_ema_revision!=None", 534 | "0.15.0", 535 | message=( 536 | "Downloading 'non_ema' weights from revision branches of the Hub is deprecated. Please make sure to" 537 | " use `--variant=non_ema` instead." 538 | ), 539 | ) 540 | logging_dir = os.path.join(args.output_dir, args.logging_dir) 541 | 542 | accelerator_project_config = ProjectConfiguration(project_dir=args.output_dir, logging_dir=logging_dir) 543 | 544 | accelerator = Accelerator( 545 | gradient_accumulation_steps=args.gradient_accumulation_steps, 546 | mixed_precision=args.mixed_precision, 547 | log_with=args.report_to, 548 | project_config=accelerator_project_config, 549 | ) 550 | 551 | # Make one log on every process with the configuration for debugging. 
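# Note: logging is configured identically on every process; just below, only the local
# main process keeps datasets/transformers/diffusers at warning/info verbosity, while
# the remaining ranks are muted to errors to avoid duplicated console output.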
552 | logging.basicConfig( 553 | format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", 554 | datefmt="%m/%d/%Y %H:%M:%S", 555 | level=logging.INFO, 556 | ) 557 | logger.info(accelerator.state, main_process_only=False) 558 | if accelerator.is_local_main_process: 559 | datasets.utils.logging.set_verbosity_warning() 560 | transformers.utils.logging.set_verbosity_warning() 561 | diffusers.utils.logging.set_verbosity_info() 562 | else: 563 | datasets.utils.logging.set_verbosity_error() 564 | transformers.utils.logging.set_verbosity_error() 565 | diffusers.utils.logging.set_verbosity_error() 566 | 567 | # If passed along, set the training seed now. 568 | if args.seed is not None: 569 | set_seed(args.seed) 570 | 571 | # Handle the repository creation 572 | if accelerator.is_main_process: 573 | if args.output_dir is not None: 574 | os.makedirs(args.output_dir, exist_ok=True) 575 | 576 | if args.push_to_hub: 577 | repo_id = create_repo( 578 | repo_id=args.hub_model_id or Path(args.output_dir).name, exist_ok=True, token=args.hub_token 579 | ).repo_id 580 | 581 | # Load scheduler, tokenizer and models. 582 | noise_scheduler = DDPMScheduler.from_pretrained(args.pretrained_model_name_or_path, subfolder="scheduler") 583 | tokenizer = CLIPTokenizer.from_pretrained( 584 | args.pretrained_model_name_or_path, subfolder="tokenizer", revision=args.revision 585 | ) 586 | 587 | if args.clip_model_name_or_path: 588 | clip_model = CLIPModel.from_pretrained(args.clip_model_name_or_path) 589 | clip_processor = CLIPProcessor.from_pretrained(args.clip_model_name_or_path) 590 | 591 | def deepspeed_zero_init_disabled_context_manager(): 592 | """ 593 | returns either a context list that includes one that will disable zero.Init or an empty context list 594 | """ 595 | deepspeed_plugin = AcceleratorState().deepspeed_plugin if accelerate.state.is_initialized() else None 596 | if deepspeed_plugin is None: 597 | return [] 598 | 599 | return [deepspeed_plugin.zero3_init_context_manager(enable=False)] 600 | 601 | # Currently Accelerate doesn't know how to handle multiple models under Deepspeed ZeRO stage 3. 602 | # For this to work properly all models must be run through `accelerate.prepare`. But accelerate 603 | # will try to assign the same optimizer with the same weights to all models during 604 | # `deepspeed.initialize`, which of course doesn't work. 605 | # 606 | # For now the following workaround will partially support Deepspeed ZeRO-3, by excluding the 2 607 | # frozen models from being partitioned during `zero.Init` which gets called during 608 | # `from_pretrained` So CLIPTextModel and AutoencoderKL will not enjoy the parameter sharding 609 | # across multiple gpus and only UNet2DConditionModel will get ZeRO sharded. 610 | with ContextManagers(deepspeed_zero_init_disabled_context_manager()): 611 | text_encoder = CLIPTextModel.from_pretrained( 612 | args.pretrained_model_name_or_path, subfolder="text_encoder", revision=args.revision, variant=args.variant 613 | ) 614 | vae = AutoencoderKL.from_pretrained( 615 | args.pretrained_model_name_or_path, subfolder="vae", revision=args.revision, variant=args.variant 616 | ) 617 | 618 | unet = UNet2DConditionModel.from_pretrained( 619 | args.pretrained_model_name_or_path, subfolder="unet", revision=args.non_ema_revision 620 | ) 621 | 622 | # Freeze vae and text_encoder and set unet to trainable 623 | vae.requires_grad_(False) 624 | text_encoder.requires_grad_(False) 625 | unet.train() 626 | 627 | # Create EMA for the unet. 
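# When --use_ema is set, a shadow copy of the UNet is kept and nudged toward the current
# weights after each optimizer step (see `ema_unet.step(...)` in the training loop); this
# smoothed copy is swapped in for validation and for the final save.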
628 | if args.use_ema: 629 | ema_unet = UNet2DConditionModel.from_pretrained( 630 | args.pretrained_model_name_or_path, subfolder="unet", revision=args.revision, variant=args.variant 631 | ) 632 | ema_unet = EMAModel(ema_unet.parameters(), model_cls=UNet2DConditionModel, model_config=ema_unet.config) 633 | 634 | if args.enable_xformers_memory_efficient_attention: 635 | if is_xformers_available(): 636 | import xformers 637 | 638 | xformers_version = version.parse(xformers.__version__) 639 | if xformers_version == version.parse("0.0.16"): 640 | logger.warn( 641 | "xFormers 0.0.16 cannot be used for training in some GPUs. If you observe problems during training, please update xFormers to at least 0.0.17. See https://huggingface.co/docs/diffusers/main/en/optimization/xformers for more details." 642 | ) 643 | unet.enable_xformers_memory_efficient_attention() 644 | else: 645 | raise ValueError("xformers is not available. Make sure it is installed correctly") 646 | 647 | # `accelerate` 0.16.0 will have better support for customized saving 648 | if version.parse(accelerate.__version__) >= version.parse("0.16.0"): 649 | # create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format 650 | def save_model_hook(models, weights, output_dir): 651 | if accelerator.is_main_process: 652 | if args.use_ema: 653 | ema_unet.save_pretrained(os.path.join(output_dir, "unet_ema")) 654 | 655 | for i, model in enumerate(models): 656 | model.save_pretrained(os.path.join(output_dir, "unet")) 657 | 658 | # make sure to pop weight so that corresponding model is not saved again 659 | weights.pop() 660 | 661 | def load_model_hook(models, input_dir): 662 | if args.use_ema: 663 | load_model = EMAModel.from_pretrained(os.path.join(input_dir, "unet_ema"), UNet2DConditionModel) 664 | ema_unet.load_state_dict(load_model.state_dict()) 665 | ema_unet.to(accelerator.device) 666 | del load_model 667 | 668 | for i in range(len(models)): 669 | # pop models so that they are not loaded again 670 | model = models.pop() 671 | 672 | # load diffusers style into model 673 | load_model = UNet2DConditionModel.from_pretrained(input_dir, subfolder="unet") 674 | model.register_to_config(**load_model.config) 675 | 676 | model.load_state_dict(load_model.state_dict()) 677 | del load_model 678 | 679 | accelerator.register_save_state_pre_hook(save_model_hook) 680 | accelerator.register_load_state_pre_hook(load_model_hook) 681 | 682 | if args.gradient_checkpointing: 683 | unet.enable_gradient_checkpointing() 684 | 685 | # Enable TF32 for faster training on Ampere GPUs, 686 | # cf https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices 687 | if args.allow_tf32: 688 | torch.backends.cuda.matmul.allow_tf32 = True 689 | 690 | if args.scale_lr: 691 | args.learning_rate = ( 692 | args.learning_rate * args.gradient_accumulation_steps * args.train_batch_size * accelerator.num_processes 693 | ) 694 | 695 | # Initialize the optimizer 696 | if args.use_8bit_adam: 697 | try: 698 | import bitsandbytes as bnb 699 | except ImportError: 700 | raise ImportError( 701 | "Please install bitsandbytes to use 8-bit Adam. 
You can do so by running `pip install bitsandbytes`" 702 | ) 703 | 704 | optimizer_cls = bnb.optim.AdamW8bit 705 | else: 706 | optimizer_cls = torch.optim.AdamW 707 | 708 | optimizer = optimizer_cls( 709 | unet.parameters(), 710 | lr=args.learning_rate, 711 | betas=(args.adam_beta1, args.adam_beta2), 712 | weight_decay=args.adam_weight_decay, 713 | eps=args.adam_epsilon, 714 | ) 715 | 716 | # Get the datasets: you can either provide your own training and evaluation files (see below) 717 | # or specify a Dataset from the hub (the dataset will be downloaded automatically from the datasets Hub). 718 | 719 | # In distributed training, the load_dataset function guarantees that only one local process can concurrently 720 | # download the dataset. 721 | if args.dataset_name is not None: 722 | # Downloading and loading a dataset from the hub. 723 | dataset = load_dataset( 724 | args.dataset_name, 725 | args.dataset_config_name, 726 | cache_dir=args.cache_dir, 727 | data_dir=args.train_data_dir, 728 | ) 729 | else: 730 | data_files = {} 731 | if args.train_data_dir is not None: 732 | data_files["train"] = os.path.join(args.train_data_dir, "**") 733 | dataset = load_dataset( 734 | "imagefolder", 735 | data_files=data_files, 736 | cache_dir=args.cache_dir, 737 | ) 738 | # See more about loading custom images at 739 | # https://huggingface.co/docs/datasets/v2.4.0/en/image_load#imagefolder 740 | 741 | # Preprocessing the datasets. 742 | # We need to tokenize inputs and targets. 743 | column_names = dataset["train"].column_names 744 | 745 | # 6. Get the column names for input/target. 746 | dataset_columns = DATASET_NAME_MAPPING.get(args.dataset_name, None) 747 | if args.image_column is None: 748 | image_column = dataset_columns[0] if dataset_columns is not None else column_names[0] 749 | else: 750 | image_column = args.image_column 751 | if image_column not in column_names: 752 | raise ValueError( 753 | f"--image_column' value '{args.image_column}' needs to be one of: {', '.join(column_names)}" 754 | ) 755 | if args.caption_column is None: 756 | caption_column = dataset_columns[1] if dataset_columns is not None else column_names[1] 757 | else: 758 | caption_column = args.caption_column 759 | if caption_column not in column_names: 760 | raise ValueError( 761 | f"--caption_column' value '{args.caption_column}' needs to be one of: {', '.join(column_names)}" 762 | ) 763 | 764 | # Preprocessing the datasets. 765 | # We need to tokenize input captions and transform the images. 766 | def tokenize_captions(examples, is_train=True): 767 | captions = [] 768 | for caption in examples[caption_column]: 769 | if isinstance(caption, str): 770 | captions.append(caption) 771 | elif isinstance(caption, (list, np.ndarray)): 772 | # take a random caption if there are multiple 773 | captions.append(random.choice(caption) if is_train else caption[0]) 774 | else: 775 | raise ValueError( 776 | f"Caption column `{caption_column}` should contain either strings or lists of strings." 777 | ) 778 | inputs = tokenizer( 779 | captions, max_length=tokenizer.model_max_length, padding="max_length", truncation=True, return_tensors="pt" 780 | ) 781 | return inputs.input_ids 782 | 783 | # Preprocessing the datasets. 
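# The transforms below resize the shorter side to `--resolution`, crop to a square
# (center or random), optionally flip horizontally, and map pixel values from [0, 1]
# to [-1, 1] via Normalize([0.5], [0.5]) before the images are encoded by the VAE.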
784 | train_transforms = transforms.Compose( 785 | [ 786 | transforms.Resize(args.resolution, interpolation=transforms.InterpolationMode.BILINEAR), 787 | transforms.CenterCrop(args.resolution) if args.center_crop else transforms.RandomCrop(args.resolution), 788 | transforms.RandomHorizontalFlip() if args.random_flip else transforms.Lambda(lambda x: x), 789 | transforms.ToTensor(), 790 | transforms.Normalize([0.5], [0.5]), 791 | ] 792 | ) 793 | 794 | def preprocess_train(examples): 795 | images = [image.convert("RGB") for image in examples[image_column]] 796 | examples["pixel_values"] = [train_transforms(image) for image in images] 797 | examples["input_ids"] = tokenize_captions(examples) 798 | return examples 799 | 800 | with accelerator.main_process_first(): 801 | if args.max_train_samples is not None: 802 | dataset["train"] = dataset["train"].shuffle(seed=args.seed).select(range(args.max_train_samples)) 803 | # Set the training transforms 804 | train_dataset = dataset["train"].with_transform(preprocess_train) 805 | 806 | def collate_fn(examples): 807 | pixel_values = torch.stack([example["pixel_values"] for example in examples]) 808 | pixel_values = pixel_values.to(memory_format=torch.contiguous_format).float() 809 | input_ids = torch.stack([example["input_ids"] for example in examples]) 810 | return {"pixel_values": pixel_values, "input_ids": input_ids} 811 | 812 | # DataLoaders creation: 813 | train_dataloader = torch.utils.data.DataLoader( 814 | train_dataset, 815 | shuffle=True, 816 | collate_fn=collate_fn, 817 | batch_size=args.train_batch_size, 818 | num_workers=args.dataloader_num_workers, 819 | ) 820 | 821 | # Scheduler and math around the number of training steps. 822 | overrode_max_train_steps = False 823 | num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps) 824 | if args.max_train_steps is None: 825 | args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch 826 | overrode_max_train_steps = True 827 | 828 | lr_scheduler = get_scheduler( 829 | args.lr_scheduler, 830 | optimizer=optimizer, 831 | num_warmup_steps=args.lr_warmup_steps * accelerator.num_processes, 832 | num_training_steps=args.max_train_steps * accelerator.num_processes, 833 | ) 834 | 835 | # Prepare everything with our `accelerator`. 836 | unet, optimizer, train_dataloader, lr_scheduler = accelerator.prepare( 837 | unet, optimizer, train_dataloader, lr_scheduler 838 | ) 839 | 840 | if args.use_ema: 841 | ema_unet.to(accelerator.device) 842 | 843 | # For mixed precision training we cast all non-trainable weigths (vae, non-lora text_encoder and non-lora unet) to half-precision 844 | # as these weights are only used for inference, keeping weights in full precision is not required. 845 | weight_dtype = torch.float32 846 | if accelerator.mixed_precision == "fp16": 847 | weight_dtype = torch.float16 848 | args.mixed_precision = accelerator.mixed_precision 849 | elif accelerator.mixed_precision == "bf16": 850 | weight_dtype = torch.bfloat16 851 | args.mixed_precision = accelerator.mixed_precision 852 | 853 | # Move text_encode and vae to gpu and cast to weight_dtype 854 | text_encoder.to(accelerator.device, dtype=weight_dtype) 855 | vae.to(accelerator.device, dtype=weight_dtype) 856 | 857 | # We need to recalculate our total training steps as the size of the training dataloader may have changed. 
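# `accelerator.prepare` shards the dataloader across processes, so its length (and hence
# the number of update steps per epoch) can shrink; the step/epoch math is recomputed here.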
858 | num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps) 859 | if overrode_max_train_steps: 860 | args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch 861 | # Afterwards we recalculate our number of training epochs 862 | args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch) 863 | 864 | # We need to initialize the trackers we use, and also store our configuration. 865 | # The trackers initializes automatically on the main process. 866 | if accelerator.is_main_process: 867 | tracker_config = dict(vars(args)) 868 | tracker_config.pop("validation_prompts") 869 | accelerator.init_trackers(args.tracker_project_name, tracker_config) 870 | 871 | # Train! 872 | total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps 873 | 874 | logger.info("***** Running training *****") 875 | logger.info(f" Num examples = {len(train_dataset)}") 876 | logger.info(f" Num Epochs = {args.num_train_epochs}") 877 | logger.info(f" Instantaneous batch size per device = {args.train_batch_size}") 878 | logger.info(f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}") 879 | logger.info(f" Gradient Accumulation steps = {args.gradient_accumulation_steps}") 880 | logger.info(f" Total optimization steps = {args.max_train_steps}") 881 | global_step = 0 882 | first_epoch = 0 883 | 884 | # Potentially load in the weights and states from a previous save 885 | if args.resume_from_checkpoint: 886 | if args.resume_from_checkpoint != "latest": 887 | path = os.path.basename(args.resume_from_checkpoint) 888 | else: 889 | # Get the most recent checkpoint 890 | dirs = os.listdir(args.output_dir) 891 | dirs = [d for d in dirs if d.startswith("checkpoint")] 892 | dirs = sorted(dirs, key=lambda x: int(x.split("-")[1])) 893 | path = dirs[-1] if len(dirs) > 0 else None 894 | 895 | if path is None: 896 | accelerator.print( 897 | f"Checkpoint '{args.resume_from_checkpoint}' does not exist. Starting a new training run." 898 | ) 899 | args.resume_from_checkpoint = None 900 | initial_global_step = 0 901 | else: 902 | accelerator.print(f"Resuming from checkpoint {path}") 903 | accelerator.load_state(os.path.join(args.output_dir, path)) 904 | global_step = int(path.split("-")[1]) 905 | 906 | initial_global_step = global_step 907 | first_epoch = global_step // num_update_steps_per_epoch 908 | 909 | else: 910 | initial_global_step = 0 911 | 912 | progress_bar = tqdm( 913 | range(0, args.max_train_steps), 914 | initial=initial_global_step, 915 | desc="Steps", 916 | # Only show the progress bar once on each machine. 
917 | disable=not accelerator.is_local_main_process, 918 | ) 919 | 920 | for epoch in range(first_epoch, args.num_train_epochs): 921 | train_loss = 0.0 922 | for step, batch in enumerate(train_dataloader): 923 | with accelerator.accumulate(unet): 924 | # Convert images to latent space 925 | latents = vae.encode(batch["pixel_values"].to(weight_dtype)).latent_dist.sample() 926 | latents = latents * vae.config.scaling_factor 927 | 928 | # Sample noise that we'll add to the latents 929 | noise = torch.randn_like(latents) 930 | if args.noise_offset: 931 | # https://www.crosslabs.org//blog/diffusion-with-offset-noise 932 | noise += args.noise_offset * torch.randn( 933 | (latents.shape[0], latents.shape[1], 1, 1), device=latents.device 934 | ) 935 | if args.input_perturbation: 936 | new_noise = noise + args.input_perturbation * torch.randn_like(noise) 937 | bsz = latents.shape[0] 938 | # Sample a random timestep for each image 939 | timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, (bsz,), device=latents.device) 940 | timesteps = timesteps.long() 941 | 942 | # Add noise to the latents according to the noise magnitude at each timestep 943 | # (this is the forward diffusion process) 944 | if args.input_perturbation: 945 | noisy_latents = noise_scheduler.add_noise(latents, new_noise, timesteps) 946 | else: 947 | noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps) 948 | 949 | # Get the text embedding for conditioning 950 | encoder_hidden_states = text_encoder(batch["input_ids"])[0] 951 | 952 | # Get the target for loss depending on the prediction type 953 | if args.prediction_type is not None: 954 | # set prediction_type of scheduler if defined 955 | noise_scheduler.register_to_config(prediction_type=args.prediction_type) 956 | 957 | if noise_scheduler.config.prediction_type == "epsilon": 958 | target = noise 959 | elif noise_scheduler.config.prediction_type == "v_prediction": 960 | target = noise_scheduler.get_velocity(latents, noise, timesteps) 961 | else: 962 | raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}") 963 | 964 | # Predict the noise residual and compute loss 965 | model_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample 966 | 967 | if args.snr_gamma is None: 968 | loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean") 969 | else: 970 | # Compute loss-weights as per Section 3.4 of https://arxiv.org/abs/2303.09556. 971 | # Since we predict the noise instead of x_0, the original formulation is slightly changed. 972 | # This is discussed in Section 4.2 of the same paper. 973 | snr = compute_snr(noise_scheduler, timesteps) 974 | if noise_scheduler.config.prediction_type == "v_prediction": 975 | # Velocity objective requires that we add one to SNR values before we divide by them. 976 | snr = snr + 1 977 | mse_loss_weights = ( 978 | torch.stack([snr, args.snr_gamma * torch.ones_like(timesteps)], dim=1).min(dim=1)[0] / snr 979 | ) 980 | 981 | loss = F.mse_loss(model_pred.float(), target.float(), reduction="none") 982 | loss = loss.mean(dim=list(range(1, len(loss.shape)))) * mse_loss_weights 983 | loss = loss.mean() 984 | 985 | # Gather the losses across all processes for logging (if we use distributed training). 
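# `loss.repeat(train_batch_size)` expands the scalar loss so that `accelerator.gather`
# returns one entry per example on every process; the resulting mean is the cross-process
# average accumulated into `train_loss` purely for logging.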
986 | avg_loss = accelerator.gather(loss.repeat(args.train_batch_size)).mean() 987 | train_loss += avg_loss.item() / args.gradient_accumulation_steps 988 | 989 | # Backpropagate 990 | accelerator.backward(loss) 991 | if accelerator.sync_gradients: 992 | accelerator.clip_grad_norm_(unet.parameters(), args.max_grad_norm) 993 | optimizer.step() 994 | lr_scheduler.step() 995 | optimizer.zero_grad() 996 | 997 | # Checks if the accelerator has performed an optimization step behind the scenes 998 | if accelerator.sync_gradients: 999 | if args.use_ema: 1000 | ema_unet.step(unet.parameters()) 1001 | progress_bar.update(1) 1002 | global_step += 1 1003 | accelerator.log({"train_loss": train_loss}, step=global_step) 1004 | train_loss = 0.0 1005 | 1006 | if global_step % args.checkpointing_steps == 0: 1007 | if accelerator.is_main_process: 1008 | # _before_ saving state, check if this save would set us over the `checkpoints_total_limit` 1009 | if args.checkpoints_total_limit is not None: 1010 | checkpoints = os.listdir(args.output_dir) 1011 | checkpoints = [d for d in checkpoints if d.startswith("checkpoint")] 1012 | checkpoints = sorted(checkpoints, key=lambda x: int(x.split("-")[1])) 1013 | 1014 | # before we save the new checkpoint, we need to have at _most_ `checkpoints_total_limit - 1` checkpoints 1015 | if len(checkpoints) >= args.checkpoints_total_limit: 1016 | num_to_remove = len(checkpoints) - args.checkpoints_total_limit + 1 1017 | removing_checkpoints = checkpoints[0:num_to_remove] 1018 | 1019 | logger.info( 1020 | f"{len(checkpoints)} checkpoints already exist, removing {len(removing_checkpoints)} checkpoints" 1021 | ) 1022 | logger.info(f"removing checkpoints: {', '.join(removing_checkpoints)}") 1023 | 1024 | for removing_checkpoint in removing_checkpoints: 1025 | removing_checkpoint = os.path.join(args.output_dir, removing_checkpoint) 1026 | shutil.rmtree(removing_checkpoint) 1027 | 1028 | save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}") 1029 | accelerator.save_state(save_path) 1030 | logger.info(f"Saved state to {save_path}") 1031 | 1032 | logs = {"step_loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0]} 1033 | progress_bar.set_postfix(**logs) 1034 | 1035 | if global_step >= args.max_train_steps: 1036 | break 1037 | 1038 | if accelerator.is_main_process: 1039 | if args.validation_prompts is not None and epoch % args.validation_epochs == 0: 1040 | if args.use_ema: 1041 | # Store the UNet parameters temporarily and load the EMA parameters to perform inference. 1042 | ema_unet.store(unet.parameters()) 1043 | ema_unet.copy_to(unet.parameters()) 1044 | log_validation( 1045 | vae, 1046 | text_encoder, 1047 | tokenizer, 1048 | unet, 1049 | args, 1050 | accelerator, 1051 | weight_dtype, 1052 | global_step, 1053 | ) 1054 | if args.use_ema: 1055 | # Switch back to the original UNet parameters. 
1056 | ema_unet.restore(unet.parameters()) 1057 | 1058 | if args.validation_prompts_path is not None and global_step % args.validation_steps == 0: 1059 | pipeline = StableDiffusionPipeline.from_pretrained( 1060 | args.pretrained_model_name_or_path, 1061 | vae=accelerator.unwrap_model(vae), 1062 | text_encoder=accelerator.unwrap_model(text_encoder), 1063 | tokenizer=tokenizer, 1064 | unet=accelerator.unwrap_model(unet), 1065 | safety_checker=None, 1066 | revision=args.revision, 1067 | variant=args.variant, 1068 | torch_dtype=weight_dtype, 1069 | ) 1070 | pipeline = pipeline.to(accelerator.device) 1071 | pipeline.set_progress_bar_config(disable=True) 1072 | if args.enable_xformers_memory_efficient_attention: 1073 | pipeline.enable_xformers_memory_efficient_attention() 1074 | 1075 | with open(args.validation_prompts_path,'r',encoding='utf-8') as f: 1076 | prompts = json.load(f)[:25] 1077 | 1078 | score = get_score(pipeline, clip_model, clip_processor, prompts, args.validation_batch_size, f'{args.validation_output_dir}/{global_step}') 1079 | print(f'{global_step}: {score}') 1080 | 1081 | # Create the pipeline using the trained modules and save it. 1082 | accelerator.wait_for_everyone() 1083 | if accelerator.is_main_process: 1084 | unet = accelerator.unwrap_model(unet) 1085 | if args.use_ema: 1086 | ema_unet.copy_to(unet.parameters()) 1087 | 1088 | pipeline = StableDiffusionPipeline.from_pretrained( 1089 | args.pretrained_model_name_or_path, 1090 | text_encoder=text_encoder, 1091 | vae=vae, 1092 | unet=unet, 1093 | revision=args.revision, 1094 | variant=args.variant, 1095 | ) 1096 | pipeline.save_pretrained(args.output_dir) 1097 | 1098 | # Run a final round of inference. 1099 | images = [] 1100 | if args.validation_prompts is not None: 1101 | logger.info("Running inference for collecting generated images...") 1102 | pipeline = pipeline.to(accelerator.device) 1103 | pipeline.torch_dtype = weight_dtype 1104 | pipeline.set_progress_bar_config(disable=True) 1105 | 1106 | if args.enable_xformers_memory_efficient_attention: 1107 | pipeline.enable_xformers_memory_efficient_attention() 1108 | 1109 | if args.seed is None: 1110 | generator = None 1111 | else: 1112 | generator = torch.Generator(device=accelerator.device).manual_seed(args.seed) 1113 | 1114 | for i in range(len(args.validation_prompts)): 1115 | with torch.autocast("cuda"): 1116 | image = pipeline(args.validation_prompts[i], num_inference_steps=20, generator=generator).images[0] 1117 | images.append(image) 1118 | 1119 | if args.push_to_hub: 1120 | save_model_card(args, repo_id, images, repo_folder=args.output_dir) 1121 | upload_folder( 1122 | repo_id=repo_id, 1123 | folder_path=args.output_dir, 1124 | commit_message="End of training", 1125 | ignore_patterns=["step_*", "epoch_*"], 1126 | ) 1127 | 1128 | accelerator.end_training() 1129 | 1130 | 1131 | if __name__ == "__main__": 1132 | main() 1133 | -------------------------------------------------------------------------------- /code/others/train_text_to_image_sdxl.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # coding=utf-8 3 | # Copyright 2023 The HuggingFace Inc. team. All rights reserved. 4 | # 5 | # Licensed under the Apache License, Version 2.0 (the "License"); 6 | # you may not use this file except in compliance with the License. 
7 | # You may obtain a copy of the License at 8 | # 9 | # http://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | # See the License for the specific language governing permissions and 15 | # limitations under the License. 16 | """Fine-tuning script for Stable Diffusion XL for text2image.""" 17 | 18 | import argparse 19 | import functools 20 | import gc 21 | import logging 22 | import math 23 | import os 24 | import random 25 | import shutil 26 | from pathlib import Path 27 | 28 | import accelerate 29 | import datasets 30 | import numpy as np 31 | import torch 32 | import torch.nn.functional as F 33 | import torch.utils.checkpoint 34 | import transformers 35 | from accelerate import Accelerator 36 | from accelerate.logging import get_logger 37 | from accelerate.utils import ProjectConfiguration, set_seed 38 | from datasets import load_dataset 39 | from huggingface_hub import create_repo, upload_folder 40 | from packaging import version 41 | from torchvision import transforms 42 | from torchvision.transforms.functional import crop 43 | from tqdm.auto import tqdm 44 | from transformers import AutoTokenizer, PretrainedConfig 45 | 46 | import diffusers 47 | from diffusers import ( 48 | AutoencoderKL, 49 | DDPMScheduler, 50 | StableDiffusionXLPipeline, 51 | UNet2DConditionModel, 52 | ) 53 | from diffusers.optimization import get_scheduler 54 | from diffusers.training_utils import EMAModel, compute_snr 55 | from diffusers.utils import check_min_version, is_wandb_available 56 | from diffusers.utils.import_utils import is_xformers_available 57 | 58 | 59 | # Will error if the minimal version of diffusers is not installed. Remove at your own risks. 60 | check_min_version("0.25.0.dev0") 61 | 62 | logger = get_logger(__name__) 63 | 64 | 65 | DATASET_NAME_MAPPING = { 66 | "lambdalabs/pokemon-blip-captions": ("image", "text"), 67 | } 68 | 69 | 70 | def save_model_card( 71 | repo_id: str, 72 | images=None, 73 | validation_prompt=None, 74 | base_model=str, 75 | dataset_name=str, 76 | repo_folder=None, 77 | vae_path=None, 78 | ): 79 | img_str = "" 80 | for i, image in enumerate(images): 81 | image.save(os.path.join(repo_folder, f"image_{i}.png")) 82 | img_str += f"![img_{i}](./image_{i}.png)\n" 83 | 84 | yaml = f""" 85 | --- 86 | license: creativeml-openrail-m 87 | base_model: {base_model} 88 | dataset: {dataset_name} 89 | tags: 90 | - stable-diffusion-xl 91 | - stable-diffusion-xl-diffusers 92 | - text-to-image 93 | - diffusers 94 | inference: true 95 | --- 96 | """ 97 | model_card = f""" 98 | # Text-to-image finetuning - {repo_id} 99 | 100 | This pipeline was finetuned from **{base_model}** on the **{args.dataset_name}** dataset. Below are some example images generated with the finetuned pipeline using the following prompt: {validation_prompt}: \n 101 | {img_str} 102 | 103 | Special VAE used for training: {vae_path}. 
104 | """ 105 | with open(os.path.join(repo_folder, "README.md"), "w") as f: 106 | f.write(yaml + model_card) 107 | 108 | 109 | def import_model_class_from_model_name_or_path( 110 | pretrained_model_name_or_path: str, revision: str, subfolder: str = "text_encoder" 111 | ): 112 | text_encoder_config = PretrainedConfig.from_pretrained( 113 | pretrained_model_name_or_path, subfolder=subfolder, revision=revision 114 | ) 115 | model_class = text_encoder_config.architectures[0] 116 | 117 | if model_class == "CLIPTextModel": 118 | from transformers import CLIPTextModel 119 | 120 | return CLIPTextModel 121 | elif model_class == "CLIPTextModelWithProjection": 122 | from transformers import CLIPTextModelWithProjection 123 | 124 | return CLIPTextModelWithProjection 125 | else: 126 | raise ValueError(f"{model_class} is not supported.") 127 | 128 | 129 | def parse_args(input_args=None): 130 | parser = argparse.ArgumentParser(description="Simple example of a training script.") 131 | parser.add_argument( 132 | "--pretrained_model_name_or_path", 133 | type=str, 134 | default=None, 135 | required=True, 136 | help="Path to pretrained model or model identifier from huggingface.co/models.", 137 | ) 138 | parser.add_argument( 139 | "--pretrained_vae_model_name_or_path", 140 | type=str, 141 | default=None, 142 | help="Path to pretrained VAE model with better numerical stability. More details: https://github.com/huggingface/diffusers/pull/4038.", 143 | ) 144 | parser.add_argument( 145 | "--revision", 146 | type=str, 147 | default=None, 148 | required=False, 149 | help="Revision of pretrained model identifier from huggingface.co/models.", 150 | ) 151 | parser.add_argument( 152 | "--variant", 153 | type=str, 154 | default=None, 155 | help="Variant of the model files of the pretrained model identifier from huggingface.co/models, 'e.g.' fp16", 156 | ) 157 | parser.add_argument( 158 | "--dataset_name", 159 | type=str, 160 | default=None, 161 | help=( 162 | "The name of the Dataset (from the HuggingFace hub) to train on (could be your own, possibly private," 163 | " dataset). It can also be a path pointing to a local copy of a dataset in your filesystem," 164 | " or to a folder containing files that 🤗 Datasets can understand." 165 | ), 166 | ) 167 | parser.add_argument( 168 | "--dataset_config_name", 169 | type=str, 170 | default=None, 171 | help="The config of the Dataset, leave as None if there's only one config.", 172 | ) 173 | parser.add_argument( 174 | "--train_data_dir", 175 | type=str, 176 | default=None, 177 | help=( 178 | "A folder containing the training data. Folder contents must follow the structure described in" 179 | " https://huggingface.co/docs/datasets/image_dataset#imagefolder. In particular, a `metadata.jsonl` file" 180 | " must exist to provide the captions for the images. Ignored if `dataset_name` is specified." 181 | ), 182 | ) 183 | parser.add_argument( 184 | "--image_column", type=str, default="image", help="The column of the dataset containing an image." 
185 | ) 186 | parser.add_argument( 187 | "--caption_column", 188 | type=str, 189 | default="text", 190 | help="The column of the dataset containing a caption or a list of captions.", 191 | ) 192 | parser.add_argument( 193 | "--validation_prompt", 194 | type=str, 195 | default=None, 196 | help="A prompt that is used during validation to verify that the model is learning.", 197 | ) 198 | parser.add_argument( 199 | "--num_validation_images", 200 | type=int, 201 | default=4, 202 | help="Number of images that should be generated during validation with `validation_prompt`.", 203 | ) 204 | parser.add_argument( 205 | "--validation_epochs", 206 | type=int, 207 | default=1, 208 | help=( 209 | "Run fine-tuning validation every X epochs. The validation process consists of running the prompt" 210 | " `args.validation_prompt` multiple times: `args.num_validation_images`." 211 | ), 212 | ) 213 | parser.add_argument( 214 | "--max_train_samples", 215 | type=int, 216 | default=None, 217 | help=( 218 | "For debugging purposes or quicker training, truncate the number of training examples to this " 219 | "value if set." 220 | ), 221 | ) 222 | parser.add_argument( 223 | "--proportion_empty_prompts", 224 | type=float, 225 | default=0, 226 | help="Proportion of image prompts to be replaced with empty strings. Defaults to 0 (no prompt replacement).", 227 | ) 228 | parser.add_argument( 229 | "--output_dir", 230 | type=str, 231 | default="sdxl-model-finetuned", 232 | help="The output directory where the model predictions and checkpoints will be written.", 233 | ) 234 | parser.add_argument( 235 | "--cache_dir", 236 | type=str, 237 | default=None, 238 | help="The directory where the downloaded models and datasets will be stored.", 239 | ) 240 | parser.add_argument("--seed", type=int, default=None, help="A seed for reproducible training.") 241 | parser.add_argument( 242 | "--resolution", 243 | type=int, 244 | default=1024, 245 | help=( 246 | "The resolution for input images, all the images in the train/validation dataset will be resized to this" 247 | " resolution" 248 | ), 249 | ) 250 | parser.add_argument( 251 | "--center_crop", 252 | default=False, 253 | action="store_true", 254 | help=( 255 | "Whether to center crop the input images to the resolution. If not set, the images will be randomly" 256 | " cropped. The images will be resized to the resolution first before cropping." 257 | ), 258 | ) 259 | parser.add_argument( 260 | "--random_flip", 261 | action="store_true", 262 | help="whether to randomly flip images horizontally", 263 | ) 264 | parser.add_argument( 265 | "--train_batch_size", type=int, default=16, help="Batch size (per device) for the training dataloader." 266 | ) 267 | parser.add_argument("--num_train_epochs", type=int, default=100) 268 | parser.add_argument( 269 | "--max_train_steps", 270 | type=int, 271 | default=None, 272 | help="Total number of training steps to perform. If provided, overrides num_train_epochs.", 273 | ) 274 | parser.add_argument( 275 | "--checkpointing_steps", 276 | type=int, 277 | default=500, 278 | help=( 279 | "Save a checkpoint of the training state every X updates. These checkpoints can be used both as final" 280 | " checkpoints in case they are better than the last checkpoint, and are also suitable for resuming" 281 | " training using `--resume_from_checkpoint`." 
282 | ), 283 | ) 284 | parser.add_argument( 285 | "--checkpoints_total_limit", 286 | type=int, 287 | default=None, 288 | help=("Max number of checkpoints to store."), 289 | ) 290 | parser.add_argument( 291 | "--resume_from_checkpoint", 292 | type=str, 293 | default=None, 294 | help=( 295 | "Whether training should be resumed from a previous checkpoint. Use a path saved by" 296 | ' `--checkpointing_steps`, or `"latest"` to automatically select the last available checkpoint.' 297 | ), 298 | ) 299 | parser.add_argument( 300 | "--gradient_accumulation_steps", 301 | type=int, 302 | default=1, 303 | help="Number of updates steps to accumulate before performing a backward/update pass.", 304 | ) 305 | parser.add_argument( 306 | "--gradient_checkpointing", 307 | action="store_true", 308 | help="Whether or not to use gradient checkpointing to save memory at the expense of slower backward pass.", 309 | ) 310 | parser.add_argument( 311 | "--learning_rate", 312 | type=float, 313 | default=1e-4, 314 | help="Initial learning rate (after the potential warmup period) to use.", 315 | ) 316 | parser.add_argument( 317 | "--scale_lr", 318 | action="store_true", 319 | default=False, 320 | help="Scale the learning rate by the number of GPUs, gradient accumulation steps, and batch size.", 321 | ) 322 | parser.add_argument( 323 | "--lr_scheduler", 324 | type=str, 325 | default="constant", 326 | help=( 327 | 'The scheduler type to use. Choose between ["linear", "cosine", "cosine_with_restarts", "polynomial",' 328 | ' "constant", "constant_with_warmup"]' 329 | ), 330 | ) 331 | parser.add_argument( 332 | "--lr_warmup_steps", type=int, default=500, help="Number of steps for the warmup in the lr scheduler." 333 | ) 334 | parser.add_argument( 335 | "--timestep_bias_strategy", 336 | type=str, 337 | default="none", 338 | choices=["earlier", "later", "range", "none"], 339 | help=( 340 | "The timestep bias strategy, which may help direct the model toward learning low or high frequency details." 341 | " Choices: ['earlier', 'later', 'range', 'none']." 342 | " The default is 'none', which means no bias is applied, and training proceeds normally." 343 | " The value of 'later' will increase the frequency of the model's final training timesteps." 344 | ), 345 | ) 346 | parser.add_argument( 347 | "--timestep_bias_multiplier", 348 | type=float, 349 | default=1.0, 350 | help=( 351 | "The multiplier for the bias. Defaults to 1.0, which means no bias is applied." 352 | " A value of 2.0 will double the weight of the bias, and a value of 0.5 will halve it." 353 | ), 354 | ) 355 | parser.add_argument( 356 | "--timestep_bias_begin", 357 | type=int, 358 | default=0, 359 | help=( 360 | "When using `--timestep_bias_strategy=range`, the beginning (inclusive) timestep to bias." 361 | " Defaults to zero, which equates to having no specific bias." 362 | ), 363 | ) 364 | parser.add_argument( 365 | "--timestep_bias_end", 366 | type=int, 367 | default=1000, 368 | help=( 369 | "When using `--timestep_bias_strategy=range`, the final timestep (inclusive) to bias." 370 | " Defaults to 1000, which is the number of timesteps that Stable Diffusion is trained on." 371 | ), 372 | ) 373 | parser.add_argument( 374 | "--timestep_bias_portion", 375 | type=float, 376 | default=0.25, 377 | help=( 378 | "The portion of timesteps to bias. Defaults to 0.25, which 25% of timesteps will be biased." 379 | " A value of 0.5 will bias one half of the timesteps. 
The value provided for `--timestep_bias_strategy` determines" 380 | " whether the biased portions are in the earlier or later timesteps." 381 | ), 382 | ) 383 | parser.add_argument( 384 | "--snr_gamma", 385 | type=float, 386 | default=None, 387 | help="SNR weighting gamma to be used if rebalancing the loss. Recommended value is 5.0. " 388 | "More details here: https://arxiv.org/abs/2303.09556.", 389 | ) 390 | parser.add_argument("--use_ema", action="store_true", help="Whether to use EMA model.") 391 | parser.add_argument( 392 | "--allow_tf32", 393 | action="store_true", 394 | help=( 395 | "Whether or not to allow TF32 on Ampere GPUs. Can be used to speed up training. For more information, see" 396 | " https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices" 397 | ), 398 | ) 399 | parser.add_argument( 400 | "--dataloader_num_workers", 401 | type=int, 402 | default=0, 403 | help=( 404 | "Number of subprocesses to use for data loading. 0 means that the data will be loaded in the main process." 405 | ), 406 | ) 407 | parser.add_argument( 408 | "--use_8bit_adam", action="store_true", help="Whether or not to use 8-bit Adam from bitsandbytes." 409 | ) 410 | parser.add_argument("--adam_beta1", type=float, default=0.9, help="The beta1 parameter for the Adam optimizer.") 411 | parser.add_argument("--adam_beta2", type=float, default=0.999, help="The beta2 parameter for the Adam optimizer.") 412 | parser.add_argument("--adam_weight_decay", type=float, default=1e-2, help="Weight decay to use.") 413 | parser.add_argument("--adam_epsilon", type=float, default=1e-08, help="Epsilon value for the Adam optimizer") 414 | parser.add_argument("--max_grad_norm", default=1.0, type=float, help="Max gradient norm.") 415 | parser.add_argument("--push_to_hub", action="store_true", help="Whether or not to push the model to the Hub.") 416 | parser.add_argument("--hub_token", type=str, default=None, help="The token to use to push to the Model Hub.") 417 | parser.add_argument( 418 | "--prediction_type", 419 | type=str, 420 | default=None, 421 | help="The prediction_type that shall be used for training. Choose between 'epsilon' or 'v_prediction' or leave `None`. If left to `None` the default prediction type of the scheduler: `noise_scheduler.config.prediciton_type` is chosen.", 422 | ) 423 | parser.add_argument( 424 | "--hub_model_id", 425 | type=str, 426 | default=None, 427 | help="The name of the repository to keep in sync with the local `output_dir`.", 428 | ) 429 | parser.add_argument( 430 | "--logging_dir", 431 | type=str, 432 | default="logs", 433 | help=( 434 | "[TensorBoard](https://www.tensorflow.org/tensorboard) log directory. Will default to" 435 | " *output_dir/runs/**CURRENT_DATETIME_HOSTNAME***." 436 | ), 437 | ) 438 | parser.add_argument( 439 | "--report_to", 440 | type=str, 441 | default="tensorboard", 442 | help=( 443 | 'The integration to report the results and logs to. Supported platforms are `"tensorboard"`' 444 | ' (default), `"wandb"` and `"comet_ml"`. Use `"all"` to report to all integrations.' 445 | ), 446 | ) 447 | parser.add_argument( 448 | "--mixed_precision", 449 | type=str, 450 | default=None, 451 | choices=["no", "fp16", "bf16"], 452 | help=( 453 | "Whether to use mixed precision. Choose between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >=" 454 | " 1.10.and an Nvidia Ampere GPU. Default to the value of accelerate config of the current system or the" 455 | " flag passed with the `accelerate.launch` command. 
Use this argument to override the accelerate config." 456 | ), 457 | ) 458 | parser.add_argument("--local_rank", type=int, default=-1, help="For distributed training: local_rank") 459 | parser.add_argument( 460 | "--enable_xformers_memory_efficient_attention", action="store_true", help="Whether or not to use xformers." 461 | ) 462 | parser.add_argument("--noise_offset", type=float, default=0, help="The scale of noise offset.") 463 | 464 | if input_args is not None: 465 | args = parser.parse_args(input_args) 466 | else: 467 | args = parser.parse_args() 468 | 469 | env_local_rank = int(os.environ.get("LOCAL_RANK", -1)) 470 | if env_local_rank != -1 and env_local_rank != args.local_rank: 471 | args.local_rank = env_local_rank 472 | 473 | # Sanity checks 474 | if args.dataset_name is None and args.train_data_dir is None: 475 | raise ValueError("Need either a dataset name or a training folder.") 476 | 477 | if args.proportion_empty_prompts < 0 or args.proportion_empty_prompts > 1: 478 | raise ValueError("`--proportion_empty_prompts` must be in the range [0, 1].") 479 | 480 | return args 481 | 482 | 483 | # Adapted from pipelines.StableDiffusionXLPipeline.encode_prompt 484 | def encode_prompt(batch, text_encoders, tokenizers, proportion_empty_prompts, caption_column, is_train=True): 485 | prompt_embeds_list = [] 486 | prompt_batch = batch[caption_column] 487 | 488 | captions = [] 489 | for caption in prompt_batch: 490 | if random.random() < proportion_empty_prompts: 491 | captions.append("") 492 | elif isinstance(caption, str): 493 | captions.append(caption) 494 | elif isinstance(caption, (list, np.ndarray)): 495 | # take a random caption if there are multiple 496 | captions.append(random.choice(caption) if is_train else caption[0]) 497 | 498 | with torch.no_grad(): 499 | for tokenizer, text_encoder in zip(tokenizers, text_encoders): 500 | text_inputs = tokenizer( 501 | captions, 502 | padding="max_length", 503 | max_length=tokenizer.model_max_length, 504 | truncation=True, 505 | return_tensors="pt", 506 | ) 507 | text_input_ids = text_inputs.input_ids 508 | prompt_embeds = text_encoder( 509 | text_input_ids.to(text_encoder.device), 510 | output_hidden_states=True, 511 | ) 512 | 513 | # We are only ALWAYS interested in the pooled output of the final text encoder 514 | pooled_prompt_embeds = prompt_embeds[0] 515 | prompt_embeds = prompt_embeds.hidden_states[-2] 516 | bs_embed, seq_len, _ = prompt_embeds.shape 517 | prompt_embeds = prompt_embeds.view(bs_embed, seq_len, -1) 518 | prompt_embeds_list.append(prompt_embeds) 519 | 520 | prompt_embeds = torch.concat(prompt_embeds_list, dim=-1) 521 | pooled_prompt_embeds = pooled_prompt_embeds.view(bs_embed, -1) 522 | return {"prompt_embeds": prompt_embeds.cpu(), "pooled_prompt_embeds": pooled_prompt_embeds.cpu()} 523 | 524 | 525 | def compute_vae_encodings(batch, vae): 526 | images = batch.pop("pixel_values") 527 | pixel_values = torch.stack(list(images)) 528 | pixel_values = pixel_values.to(memory_format=torch.contiguous_format).float() 529 | pixel_values = pixel_values.to(vae.device, dtype=vae.dtype) 530 | 531 | with torch.no_grad(): 532 | model_input = vae.encode(pixel_values).latent_dist.sample() 533 | model_input = model_input * vae.config.scaling_factor 534 | return {"model_input": model_input.cpu()} 535 | 536 | 537 | def generate_timestep_weights(args, num_timesteps): 538 | weights = torch.ones(num_timesteps) 539 | 540 | # Determine the indices to bias 541 | num_to_bias = int(args.timestep_bias_portion * num_timesteps) 542 | 543 | if 
args.timestep_bias_strategy == "later": 544 | bias_indices = slice(-num_to_bias, None) 545 | elif args.timestep_bias_strategy == "earlier": 546 | bias_indices = slice(0, num_to_bias) 547 | elif args.timestep_bias_strategy == "range": 548 | # Out of the possible 1000 timesteps, we might want to focus on e.g. 200-500. 549 | range_begin = args.timestep_bias_begin 550 | range_end = args.timestep_bias_end 551 | if range_begin < 0: 552 | raise ValueError( 553 | "When using the range strategy for timestep bias, you must provide a beginning timestep greater or equal to zero." 554 | ) 555 | if range_end > num_timesteps: 556 | raise ValueError( 557 | "When using the range strategy for timestep bias, you must provide an ending timestep smaller than the number of timesteps." 558 | ) 559 | bias_indices = slice(range_begin, range_end) 560 | else: # 'none' or any other string 561 | return weights 562 | if args.timestep_bias_multiplier <= 0: 563 | raise ValueError( 564 | "The parameter --timestep_bias_multiplier is not intended to be used to disable the training of specific timesteps." 565 | " If it was intended to disable timestep bias, use `--timestep_bias_strategy none` instead." 566 | " A timestep bias multiplier less than or equal to 0 is not allowed." 567 | ) 568 | 569 | # Apply the bias 570 | weights[bias_indices] *= args.timestep_bias_multiplier 571 | 572 | # Normalize 573 | weights /= weights.sum() 574 | 575 | return weights 576 | 577 | 578 | def main(args): 579 | logging_dir = Path(args.output_dir, args.logging_dir) 580 | 581 | accelerator_project_config = ProjectConfiguration(project_dir=args.output_dir, logging_dir=logging_dir) 582 | 583 | accelerator = Accelerator( 584 | gradient_accumulation_steps=args.gradient_accumulation_steps, 585 | mixed_precision=args.mixed_precision, 586 | log_with=args.report_to, 587 | project_config=accelerator_project_config, 588 | ) 589 | 590 | if args.report_to == "wandb": 591 | if not is_wandb_available(): 592 | raise ImportError("Make sure to install wandb if you want to use it for logging during training.") 593 | import wandb 594 | 595 | # Make one log on every process with the configuration for debugging. 596 | logging.basicConfig( 597 | format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", 598 | datefmt="%m/%d/%Y %H:%M:%S", 599 | level=logging.INFO, 600 | ) 601 | logger.info(accelerator.state, main_process_only=False) 602 | if accelerator.is_local_main_process: 603 | datasets.utils.logging.set_verbosity_warning() 604 | transformers.utils.logging.set_verbosity_warning() 605 | diffusers.utils.logging.set_verbosity_info() 606 | else: 607 | datasets.utils.logging.set_verbosity_error() 608 | transformers.utils.logging.set_verbosity_error() 609 | diffusers.utils.logging.set_verbosity_error() 610 | 611 | # If passed along, set the training seed now.
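# `set_seed` (from accelerate.utils) seeds Python's `random`, NumPy and PyTorch on every
# process, so data shuffling and noise sampling are reproducible across runs.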
612 | if args.seed is not None: 613 | set_seed(args.seed) 614 | 615 | # Handle the repository creation 616 | if accelerator.is_main_process: 617 | if args.output_dir is not None: 618 | os.makedirs(args.output_dir, exist_ok=True) 619 | 620 | if args.push_to_hub: 621 | repo_id = create_repo( 622 | repo_id=args.hub_model_id or Path(args.output_dir).name, exist_ok=True, token=args.hub_token 623 | ).repo_id 624 | 625 | # Load the tokenizers 626 | tokenizer_one = AutoTokenizer.from_pretrained( 627 | args.pretrained_model_name_or_path, 628 | subfolder="tokenizer", 629 | revision=args.revision, 630 | use_fast=False, 631 | ) 632 | tokenizer_two = AutoTokenizer.from_pretrained( 633 | args.pretrained_model_name_or_path, 634 | subfolder="tokenizer_2", 635 | revision=args.revision, 636 | use_fast=False, 637 | ) 638 | 639 | # import correct text encoder classes 640 | text_encoder_cls_one = import_model_class_from_model_name_or_path( 641 | args.pretrained_model_name_or_path, args.revision 642 | ) 643 | text_encoder_cls_two = import_model_class_from_model_name_or_path( 644 | args.pretrained_model_name_or_path, args.revision, subfolder="text_encoder_2" 645 | ) 646 | 647 | # Load scheduler and models 648 | noise_scheduler = DDPMScheduler.from_pretrained(args.pretrained_model_name_or_path, subfolder="scheduler") 649 | # Check for terminal SNR in combination with SNR Gamma 650 | text_encoder_one = text_encoder_cls_one.from_pretrained( 651 | args.pretrained_model_name_or_path, subfolder="text_encoder", revision=args.revision, variant=args.variant 652 | ) 653 | text_encoder_two = text_encoder_cls_two.from_pretrained( 654 | args.pretrained_model_name_or_path, subfolder="text_encoder_2", revision=args.revision, variant=args.variant 655 | ) 656 | vae_path = ( 657 | args.pretrained_model_name_or_path 658 | if args.pretrained_vae_model_name_or_path is None 659 | else args.pretrained_vae_model_name_or_path 660 | ) 661 | vae = AutoencoderKL.from_pretrained( 662 | vae_path, 663 | subfolder="vae" if args.pretrained_vae_model_name_or_path is None else None, 664 | revision=args.revision, 665 | variant=args.variant, 666 | ) 667 | unet = UNet2DConditionModel.from_pretrained( 668 | args.pretrained_model_name_or_path, subfolder="unet", revision=args.revision, variant=args.variant 669 | ) 670 | 671 | # Freeze vae and text encoders. 672 | vae.requires_grad_(False) 673 | text_encoder_one.requires_grad_(False) 674 | text_encoder_two.requires_grad_(False) 675 | # Set unet as trainable. 676 | unet.train() 677 | 678 | # For mixed precision training we cast all non-trainable weigths to half-precision 679 | # as these weights are only used for inference, keeping weights in full precision is not required. 680 | weight_dtype = torch.float32 681 | if accelerator.mixed_precision == "fp16": 682 | weight_dtype = torch.float16 683 | elif accelerator.mixed_precision == "bf16": 684 | weight_dtype = torch.bfloat16 685 | 686 | # Move unet, vae and text_encoder to device and cast to weight_dtype 687 | # The VAE is in float32 to avoid NaN losses. 688 | vae.to(accelerator.device, dtype=torch.float32) 689 | text_encoder_one.to(accelerator.device, dtype=weight_dtype) 690 | text_encoder_two.to(accelerator.device, dtype=weight_dtype) 691 | 692 | # Create EMA for the unet. 
693 | if args.use_ema: 694 | ema_unet = UNet2DConditionModel.from_pretrained( 695 | args.pretrained_model_name_or_path, subfolder="unet", revision=args.revision, variant=args.variant 696 | ) 697 | ema_unet = EMAModel(ema_unet.parameters(), model_cls=UNet2DConditionModel, model_config=ema_unet.config) 698 | 699 | if args.enable_xformers_memory_efficient_attention: 700 | if is_xformers_available(): 701 | import xformers 702 | 703 | xformers_version = version.parse(xformers.__version__) 704 | if xformers_version == version.parse("0.0.16"): 705 | logger.warn( 706 | "xFormers 0.0.16 cannot be used for training in some GPUs. If you observe problems during training, please update xFormers to at least 0.0.17. See https://huggingface.co/docs/diffusers/main/en/optimization/xformers for more details." 707 | ) 708 | unet.enable_xformers_memory_efficient_attention() 709 | else: 710 | raise ValueError("xformers is not available. Make sure it is installed correctly") 711 | 712 | # `accelerate` 0.16.0 will have better support for customized saving 713 | if version.parse(accelerate.__version__) >= version.parse("0.16.0"): 714 | # create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format 715 | def save_model_hook(models, weights, output_dir): 716 | if accelerator.is_main_process: 717 | if args.use_ema: 718 | ema_unet.save_pretrained(os.path.join(output_dir, "unet_ema")) 719 | 720 | for i, model in enumerate(models): 721 | model.save_pretrained(os.path.join(output_dir, "unet")) 722 | 723 | # make sure to pop weight so that corresponding model is not saved again 724 | weights.pop() 725 | 726 | def load_model_hook(models, input_dir): 727 | if args.use_ema: 728 | load_model = EMAModel.from_pretrained(os.path.join(input_dir, "unet_ema"), UNet2DConditionModel) 729 | ema_unet.load_state_dict(load_model.state_dict()) 730 | ema_unet.to(accelerator.device) 731 | del load_model 732 | 733 | for i in range(len(models)): 734 | # pop models so that they are not loaded again 735 | model = models.pop() 736 | 737 | # load diffusers style into model 738 | load_model = UNet2DConditionModel.from_pretrained(input_dir, subfolder="unet") 739 | model.register_to_config(**load_model.config) 740 | 741 | model.load_state_dict(load_model.state_dict()) 742 | del load_model 743 | 744 | accelerator.register_save_state_pre_hook(save_model_hook) 745 | accelerator.register_load_state_pre_hook(load_model_hook) 746 | 747 | if args.gradient_checkpointing: 748 | unet.enable_gradient_checkpointing() 749 | 750 | # Enable TF32 for faster training on Ampere GPUs, 751 | # cf https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices 752 | if args.allow_tf32: 753 | torch.backends.cuda.matmul.allow_tf32 = True 754 | 755 | if args.scale_lr: 756 | args.learning_rate = ( 757 | args.learning_rate * args.gradient_accumulation_steps * args.train_batch_size * accelerator.num_processes 758 | ) 759 | 760 | # Use 8-bit Adam for lower memory usage or to fine-tune the model in 16GB GPUs 761 | if args.use_8bit_adam: 762 | try: 763 | import bitsandbytes as bnb 764 | except ImportError: 765 | raise ImportError( 766 | "To use 8-bit Adam, please install the bitsandbytes library: `pip install bitsandbytes`." 
767 | ) 768 | 769 | optimizer_class = bnb.optim.AdamW8bit 770 | else: 771 | optimizer_class = torch.optim.AdamW 772 | 773 | # Optimizer creation 774 | params_to_optimize = unet.parameters() 775 | optimizer = optimizer_class( 776 | params_to_optimize, 777 | lr=args.learning_rate, 778 | betas=(args.adam_beta1, args.adam_beta2), 779 | weight_decay=args.adam_weight_decay, 780 | eps=args.adam_epsilon, 781 | ) 782 | 783 | # Get the datasets: you can either provide your own training and evaluation files (see below) 784 | # or specify a Dataset from the hub (the dataset will be downloaded automatically from the datasets Hub). 785 | 786 | # In distributed training, the load_dataset function guarantees that only one local process can concurrently 787 | # download the dataset. 788 | if args.dataset_name is not None: 789 | # Downloading and loading a dataset from the hub. 790 | dataset = load_dataset( 791 | args.dataset_name, 792 | args.dataset_config_name, 793 | cache_dir=args.cache_dir, 794 | ) 795 | else: 796 | data_files = {} 797 | if args.train_data_dir is not None: 798 | data_files["train"] = os.path.join(args.train_data_dir, "**") 799 | dataset = load_dataset( 800 | "imagefolder", 801 | data_files=data_files, 802 | cache_dir=args.cache_dir, 803 | ) 804 | # See more about loading custom images at 805 | # https://huggingface.co/docs/datasets/v2.4.0/en/image_load#imagefolder 806 | 807 | # Preprocessing the datasets. 808 | # We need to tokenize inputs and targets. 809 | column_names = dataset["train"].column_names 810 | 811 | # 6. Get the column names for input/target. 812 | dataset_columns = DATASET_NAME_MAPPING.get(args.dataset_name, None) 813 | if args.image_column is None: 814 | image_column = dataset_columns[0] if dataset_columns is not None else column_names[0] 815 | else: 816 | image_column = args.image_column 817 | if image_column not in column_names: 818 | raise ValueError( 819 | f"--image_column' value '{args.image_column}' needs to be one of: {', '.join(column_names)}" 820 | ) 821 | if args.caption_column is None: 822 | caption_column = dataset_columns[1] if dataset_columns is not None else column_names[1] 823 | else: 824 | caption_column = args.caption_column 825 | if caption_column not in column_names: 826 | raise ValueError( 827 | f"--caption_column' value '{args.caption_column}' needs to be one of: {', '.join(column_names)}" 828 | ) 829 | 830 | # Preprocessing the datasets. 
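    # Besides the usual resize/crop/flip/normalize pipeline, SDXL training records each
    # image's original size and crop offset, which are later packed into `add_time_ids`
    # so the UNet can be conditioned on them.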
831 | train_resize = transforms.Resize(args.resolution, interpolation=transforms.InterpolationMode.BILINEAR) 832 | train_crop = transforms.CenterCrop(args.resolution) if args.center_crop else transforms.RandomCrop(args.resolution) 833 | train_flip = transforms.RandomHorizontalFlip(p=1.0) 834 | train_transforms = transforms.Compose([transforms.ToTensor(), transforms.Normalize([0.5], [0.5])]) 835 | 836 | def preprocess_train(examples): 837 | images = [image.convert("RGB") for image in examples[image_column]] 838 | # image aug 839 | original_sizes = [] 840 | all_images = [] 841 | crop_top_lefts = [] 842 | for image in images: 843 | original_sizes.append((image.height, image.width)) 844 | image = train_resize(image) 845 | if args.center_crop: 846 | y1 = max(0, int(round((image.height - args.resolution) / 2.0))) 847 | x1 = max(0, int(round((image.width - args.resolution) / 2.0))) 848 | image = train_crop(image) 849 | else: 850 | y1, x1, h, w = train_crop.get_params(image, (args.resolution, args.resolution)) 851 | image = crop(image, y1, x1, h, w) 852 | if args.random_flip and random.random() < 0.5: 853 | # flip 854 | x1 = image.width - x1 855 | image = train_flip(image) 856 | crop_top_left = (y1, x1) 857 | crop_top_lefts.append(crop_top_left) 858 | image = train_transforms(image) 859 | all_images.append(image) 860 | 861 | examples["original_sizes"] = original_sizes 862 | examples["crop_top_lefts"] = crop_top_lefts 863 | examples["pixel_values"] = all_images 864 | return examples 865 | 866 | with accelerator.main_process_first(): 867 | if args.max_train_samples is not None: 868 | dataset["train"] = dataset["train"].shuffle(seed=args.seed).select(range(args.max_train_samples)) 869 | # Set the training transforms 870 | train_dataset = dataset["train"].with_transform(preprocess_train) 871 | 872 | # Let's first compute all the embeddings so that we can free up the text encoders 873 | # from memory. We will pre-compute the VAE encodings too. 
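    # Prompt embeddings from both text encoders and the VAE latents are cached into the
    # dataset via `dataset.map`, after which the text encoders, tokenizers and VAE are
    # deleted to free GPU memory for the UNet.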
874 | text_encoders = [text_encoder_one, text_encoder_two] 875 | tokenizers = [tokenizer_one, tokenizer_two] 876 | compute_embeddings_fn = functools.partial( 877 | encode_prompt, 878 | text_encoders=text_encoders, 879 | tokenizers=tokenizers, 880 | proportion_empty_prompts=args.proportion_empty_prompts, 881 | caption_column=args.caption_column, 882 | ) 883 | compute_vae_encodings_fn = functools.partial(compute_vae_encodings, vae=vae) 884 | with accelerator.main_process_first(): 885 | from datasets.fingerprint import Hasher 886 | 887 | # fingerprint used by the cache for the other processes to load the result 888 | # details: https://github.com/huggingface/diffusers/pull/4038#discussion_r1266078401 889 | new_fingerprint = Hasher.hash(args) 890 | new_fingerprint_for_vae = Hasher.hash("vae") 891 | train_dataset = train_dataset.map(compute_embeddings_fn, batched=True, new_fingerprint=new_fingerprint) 892 | train_dataset = train_dataset.map( 893 | compute_vae_encodings_fn, 894 | batched=True, 895 | batch_size=args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps, 896 | new_fingerprint=new_fingerprint_for_vae, 897 | ) 898 | 899 | del text_encoders, tokenizers, vae 900 | gc.collect() 901 | torch.cuda.empty_cache() 902 | 903 | def collate_fn(examples): 904 | model_input = torch.stack([torch.tensor(example["model_input"]) for example in examples]) 905 | original_sizes = [example["original_sizes"] for example in examples] 906 | crop_top_lefts = [example["crop_top_lefts"] for example in examples] 907 | prompt_embeds = torch.stack([torch.tensor(example["prompt_embeds"]) for example in examples]) 908 | pooled_prompt_embeds = torch.stack([torch.tensor(example["pooled_prompt_embeds"]) for example in examples]) 909 | 910 | return { 911 | "model_input": model_input, 912 | "prompt_embeds": prompt_embeds, 913 | "pooled_prompt_embeds": pooled_prompt_embeds, 914 | "original_sizes": original_sizes, 915 | "crop_top_lefts": crop_top_lefts, 916 | } 917 | 918 | # DataLoaders creation: 919 | train_dataloader = torch.utils.data.DataLoader( 920 | train_dataset, 921 | shuffle=True, 922 | collate_fn=collate_fn, 923 | batch_size=args.train_batch_size, 924 | num_workers=args.dataloader_num_workers, 925 | ) 926 | 927 | # Scheduler and math around the number of training steps. 928 | overrode_max_train_steps = False 929 | num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps) 930 | if args.max_train_steps is None: 931 | args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch 932 | overrode_max_train_steps = True 933 | 934 | lr_scheduler = get_scheduler( 935 | args.lr_scheduler, 936 | optimizer=optimizer, 937 | num_warmup_steps=args.lr_warmup_steps * args.gradient_accumulation_steps, 938 | num_training_steps=args.max_train_steps * args.gradient_accumulation_steps, 939 | ) 940 | 941 | # Prepare everything with our `accelerator`. 942 | unet, optimizer, train_dataloader, lr_scheduler = accelerator.prepare( 943 | unet, optimizer, train_dataloader, lr_scheduler 944 | ) 945 | 946 | # We need to recalculate our total training steps as the size of the training dataloader may have changed. 
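    # `accelerator.prepare` shards the dataloader across processes, so its length (and
    # therefore the number of update steps per epoch) may have shrunk under
    # distributed training.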
947 | num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps) 948 | if overrode_max_train_steps: 949 | args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch 950 | # Afterwards we recalculate our number of training epochs 951 | args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch) 952 | 953 | # We need to initialize the trackers we use, and also store our configuration. 954 | # The trackers initializes automatically on the main process. 955 | if accelerator.is_main_process: 956 | accelerator.init_trackers("text2image-fine-tune-sdxl", config=vars(args)) 957 | 958 | # Train! 959 | total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps 960 | 961 | logger.info("***** Running training *****") 962 | logger.info(f" Num examples = {len(train_dataset)}") 963 | logger.info(f" Num Epochs = {args.num_train_epochs}") 964 | logger.info(f" Instantaneous batch size per device = {args.train_batch_size}") 965 | logger.info(f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}") 966 | logger.info(f" Gradient Accumulation steps = {args.gradient_accumulation_steps}") 967 | logger.info(f" Total optimization steps = {args.max_train_steps}") 968 | global_step = 0 969 | first_epoch = 0 970 | 971 | # Potentially load in the weights and states from a previous save 972 | if args.resume_from_checkpoint: 973 | if args.resume_from_checkpoint != "latest": 974 | path = os.path.basename(args.resume_from_checkpoint) 975 | else: 976 | # Get the most recent checkpoint 977 | dirs = os.listdir(args.output_dir) 978 | dirs = [d for d in dirs if d.startswith("checkpoint")] 979 | dirs = sorted(dirs, key=lambda x: int(x.split("-")[1])) 980 | path = dirs[-1] if len(dirs) > 0 else None 981 | 982 | if path is None: 983 | accelerator.print( 984 | f"Checkpoint '{args.resume_from_checkpoint}' does not exist. Starting a new training run." 985 | ) 986 | args.resume_from_checkpoint = None 987 | initial_global_step = 0 988 | else: 989 | accelerator.print(f"Resuming from checkpoint {path}") 990 | accelerator.load_state(os.path.join(args.output_dir, path)) 991 | global_step = int(path.split("-")[1]) 992 | 993 | initial_global_step = global_step 994 | first_epoch = global_step // num_update_steps_per_epoch 995 | 996 | else: 997 | initial_global_step = 0 998 | 999 | progress_bar = tqdm( 1000 | range(0, args.max_train_steps), 1001 | initial=initial_global_step, 1002 | desc="Steps", 1003 | # Only show the progress bar once on each machine. 1004 | disable=not accelerator.is_local_main_process, 1005 | ) 1006 | 1007 | for epoch in range(first_epoch, args.num_train_epochs): 1008 | train_loss = 0.0 1009 | for step, batch in enumerate(train_dataloader): 1010 | with accelerator.accumulate(unet): 1011 | # Sample noise that we'll add to the latents 1012 | model_input = batch["model_input"].to(accelerator.device) 1013 | noise = torch.randn_like(model_input) 1014 | if args.noise_offset: 1015 | # https://www.crosslabs.org//blog/diffusion-with-offset-noise 1016 | noise += args.noise_offset * torch.randn( 1017 | (model_input.shape[0], model_input.shape[1], 1, 1), device=model_input.device 1018 | ) 1019 | 1020 | bsz = model_input.shape[0] 1021 | if args.timestep_bias_strategy == "none": 1022 | # Sample a random timestep for each image without bias. 
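                    # Uniform draw over [0, num_train_timesteps); the biased strategies
                    # below instead sample from `generate_timestep_weights` with
                    # `torch.multinomial`.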
1023 | timesteps = torch.randint( 1024 | 0, noise_scheduler.config.num_train_timesteps, (bsz,), device=model_input.device 1025 | ) 1026 | else: 1027 | # Sample a random timestep for each image, potentially biased by the timestep weights. 1028 | # Biasing the timestep weights allows us to spend less time training irrelevant timesteps. 1029 | weights = generate_timestep_weights(args, noise_scheduler.config.num_train_timesteps).to( 1030 | model_input.device 1031 | ) 1032 | timesteps = torch.multinomial(weights, bsz, replacement=True).long() 1033 | 1034 | # Add noise to the model input according to the noise magnitude at each timestep 1035 | # (this is the forward diffusion process) 1036 | noisy_model_input = noise_scheduler.add_noise(model_input, noise, timesteps) 1037 | 1038 | # time ids 1039 | def compute_time_ids(original_size, crops_coords_top_left): 1040 | # Adapted from pipeline.StableDiffusionXLPipeline._get_add_time_ids 1041 | target_size = (args.resolution, args.resolution) 1042 | add_time_ids = list(original_size + crops_coords_top_left + target_size) 1043 | add_time_ids = torch.tensor([add_time_ids]) 1044 | add_time_ids = add_time_ids.to(accelerator.device, dtype=weight_dtype) 1045 | return add_time_ids 1046 | 1047 | add_time_ids = torch.cat( 1048 | [compute_time_ids(s, c) for s, c in zip(batch["original_sizes"], batch["crop_top_lefts"])] 1049 | ) 1050 | 1051 | # Predict the noise residual 1052 | unet_added_conditions = {"time_ids": add_time_ids} 1053 | prompt_embeds = batch["prompt_embeds"].to(accelerator.device) 1054 | pooled_prompt_embeds = batch["pooled_prompt_embeds"].to(accelerator.device) 1055 | unet_added_conditions.update({"text_embeds": pooled_prompt_embeds}) 1056 | model_pred = unet( 1057 | noisy_model_input, timesteps, prompt_embeds, added_cond_kwargs=unet_added_conditions 1058 | ).sample 1059 | 1060 | # Get the target for loss depending on the prediction type 1061 | if args.prediction_type is not None: 1062 | # set prediction_type of scheduler if defined 1063 | noise_scheduler.register_to_config(prediction_type=args.prediction_type) 1064 | 1065 | if noise_scheduler.config.prediction_type == "epsilon": 1066 | target = noise 1067 | elif noise_scheduler.config.prediction_type == "v_prediction": 1068 | target = noise_scheduler.get_velocity(model_input, noise, timesteps) 1069 | elif noise_scheduler.config.prediction_type == "sample": 1070 | # We set the target to latents here, but the model_pred will return the noise sample prediction. 1071 | target = model_input 1072 | # We will have to subtract the noise residual from the prediction to get the target sample. 1073 | model_pred = model_pred - noise 1074 | else: 1075 | raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}") 1076 | 1077 | if args.snr_gamma is None: 1078 | loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean") 1079 | else: 1080 | # Compute loss-weights as per Section 3.4 of https://arxiv.org/abs/2303.09556. 1081 | # Since we predict the noise instead of x_0, the original formulation is slightly changed. 1082 | # This is discussed in Section 4.2 of the same paper. 1083 | snr = compute_snr(noise_scheduler, timesteps) 1084 | if noise_scheduler.config.prediction_type == "v_prediction": 1085 | # Velocity objective requires that we add one to SNR values before we divide by them. 
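                        # With this shift the weight computed below becomes
                        # min(SNR + 1, gamma) / (SNR + 1); for epsilon prediction it stays
                        # min(SNR, gamma) / SNR, e.g. gamma = 5 maps SNR values of
                        # 0.1, 1 and 20 to weights 1.0, 1.0 and 0.25.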
1086 | snr = snr + 1 1087 | mse_loss_weights = ( 1088 | torch.stack([snr, args.snr_gamma * torch.ones_like(timesteps)], dim=1).min(dim=1)[0] / snr 1089 | ) 1090 | 1091 | loss = F.mse_loss(model_pred.float(), target.float(), reduction="none") 1092 | loss = loss.mean(dim=list(range(1, len(loss.shape)))) * mse_loss_weights 1093 | loss = loss.mean() 1094 | 1095 | # Gather the losses across all processes for logging (if we use distributed training). 1096 | avg_loss = accelerator.gather(loss.repeat(args.train_batch_size)).mean() 1097 | train_loss += avg_loss.item() / args.gradient_accumulation_steps 1098 | 1099 | # Backpropagate 1100 | accelerator.backward(loss) 1101 | if accelerator.sync_gradients: 1102 | params_to_clip = unet.parameters() 1103 | accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm) 1104 | optimizer.step() 1105 | lr_scheduler.step() 1106 | optimizer.zero_grad() 1107 | 1108 | # Checks if the accelerator has performed an optimization step behind the scenes 1109 | if accelerator.sync_gradients: 1110 | progress_bar.update(1) 1111 | global_step += 1 1112 | accelerator.log({"train_loss": train_loss}, step=global_step) 1113 | train_loss = 0.0 1114 | 1115 | if accelerator.is_main_process: 1116 | if global_step % args.checkpointing_steps == 0: 1117 | # _before_ saving state, check if this save would set us over the `checkpoints_total_limit` 1118 | if args.checkpoints_total_limit is not None: 1119 | checkpoints = os.listdir(args.output_dir) 1120 | checkpoints = [d for d in checkpoints if d.startswith("checkpoint")] 1121 | checkpoints = sorted(checkpoints, key=lambda x: int(x.split("-")[1])) 1122 | 1123 | # before we save the new checkpoint, we need to have at _most_ `checkpoints_total_limit - 1` checkpoints 1124 | if len(checkpoints) >= args.checkpoints_total_limit: 1125 | num_to_remove = len(checkpoints) - args.checkpoints_total_limit + 1 1126 | removing_checkpoints = checkpoints[0:num_to_remove] 1127 | 1128 | logger.info( 1129 | f"{len(checkpoints)} checkpoints already exist, removing {len(removing_checkpoints)} checkpoints" 1130 | ) 1131 | logger.info(f"removing checkpoints: {', '.join(removing_checkpoints)}") 1132 | 1133 | for removing_checkpoint in removing_checkpoints: 1134 | removing_checkpoint = os.path.join(args.output_dir, removing_checkpoint) 1135 | shutil.rmtree(removing_checkpoint) 1136 | 1137 | save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}") 1138 | accelerator.save_state(save_path) 1139 | logger.info(f"Saved state to {save_path}") 1140 | 1141 | logs = {"step_loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0]} 1142 | progress_bar.set_postfix(**logs) 1143 | 1144 | if global_step >= args.max_train_steps: 1145 | break 1146 | 1147 | if accelerator.is_main_process: 1148 | if args.validation_prompt is not None and epoch % args.validation_epochs == 0: 1149 | logger.info( 1150 | f"Running validation... \n Generating {args.num_validation_images} images with prompt:" 1151 | f" {args.validation_prompt}." 1152 | ) 1153 | if args.use_ema: 1154 | # Store the UNet parameters temporarily and load the EMA parameters to perform inference. 
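                    # `store` keeps a copy of the current (online) UNet weights so they
                    # can be restored afterwards, while `copy_to` temporarily loads the
                    # EMA weights into the UNet for inference.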
1155 | ema_unet.store(unet.parameters()) 1156 | ema_unet.copy_to(unet.parameters()) 1157 | 1158 | # create pipeline 1159 | vae = AutoencoderKL.from_pretrained( 1160 | vae_path, 1161 | subfolder="vae" if args.pretrained_vae_model_name_or_path is None else None, 1162 | revision=args.revision, 1163 | variant=args.variant, 1164 | ) 1165 | pipeline = StableDiffusionXLPipeline.from_pretrained( 1166 | args.pretrained_model_name_or_path, 1167 | vae=vae, 1168 | unet=accelerator.unwrap_model(unet), 1169 | revision=args.revision, 1170 | variant=args.variant, 1171 | torch_dtype=weight_dtype, 1172 | ) 1173 | if args.prediction_type is not None: 1174 | scheduler_args = {"prediction_type": args.prediction_type} 1175 | pipeline.scheduler = pipeline.scheduler.from_config(pipeline.scheduler.config, **scheduler_args) 1176 | 1177 | pipeline = pipeline.to(accelerator.device) 1178 | pipeline.set_progress_bar_config(disable=True) 1179 | 1180 | # run inference 1181 | generator = torch.Generator(device=accelerator.device).manual_seed(args.seed) if args.seed else None 1182 | pipeline_args = {"prompt": args.validation_prompt} 1183 | 1184 | with torch.cuda.amp.autocast(): 1185 | images = [ 1186 | pipeline(**pipeline_args, generator=generator, num_inference_steps=25).images[0] 1187 | for _ in range(args.num_validation_images) 1188 | ] 1189 | 1190 | for tracker in accelerator.trackers: 1191 | if tracker.name == "tensorboard": 1192 | np_images = np.stack([np.asarray(img) for img in images]) 1193 | tracker.writer.add_images("validation", np_images, epoch, dataformats="NHWC") 1194 | if tracker.name == "wandb": 1195 | tracker.log( 1196 | { 1197 | "validation": [ 1198 | wandb.Image(image, caption=f"{i}: {args.validation_prompt}") 1199 | for i, image in enumerate(images) 1200 | ] 1201 | } 1202 | ) 1203 | 1204 | del pipeline 1205 | torch.cuda.empty_cache() 1206 | 1207 | accelerator.wait_for_everyone() 1208 | if accelerator.is_main_process: 1209 | unet = accelerator.unwrap_model(unet) 1210 | if args.use_ema: 1211 | ema_unet.copy_to(unet.parameters()) 1212 | 1213 | # Serialize pipeline. 
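        # The exported pipeline is rebuilt from the base checkpoint with the fine-tuned
        # (or EMA) UNet and a freshly loaded VAE, then written to `args.output_dir` via
        # `save_pretrained` so it can be loaded back as a StableDiffusionXLPipeline.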
1214 | vae = AutoencoderKL.from_pretrained( 1215 | vae_path, 1216 | subfolder="vae" if args.pretrained_vae_model_name_or_path is None else None, 1217 | revision=args.revision, 1218 | variant=args.variant, 1219 | torch_dtype=weight_dtype, 1220 | ) 1221 | pipeline = StableDiffusionXLPipeline.from_pretrained( 1222 | args.pretrained_model_name_or_path, 1223 | unet=unet, 1224 | vae=vae, 1225 | revision=args.revision, 1226 | variant=args.variant, 1227 | torch_dtype=weight_dtype, 1228 | ) 1229 | if args.prediction_type is not None: 1230 | scheduler_args = {"prediction_type": args.prediction_type} 1231 | pipeline.scheduler = pipeline.scheduler.from_config(pipeline.scheduler.config, **scheduler_args) 1232 | pipeline.save_pretrained(args.output_dir) 1233 | 1234 | # run inference 1235 | images = [] 1236 | if args.validation_prompt and args.num_validation_images > 0: 1237 | pipeline = pipeline.to(accelerator.device) 1238 | generator = torch.Generator(device=accelerator.device).manual_seed(args.seed) if args.seed else None 1239 | with torch.cuda.amp.autocast(): 1240 | images = [ 1241 | pipeline(args.validation_prompt, num_inference_steps=25, generator=generator).images[0] 1242 | for _ in range(args.num_validation_images) 1243 | ] 1244 | 1245 | for tracker in accelerator.trackers: 1246 | if tracker.name == "tensorboard": 1247 | np_images = np.stack([np.asarray(img) for img in images]) 1248 | tracker.writer.add_images("test", np_images, epoch, dataformats="NHWC") 1249 | if tracker.name == "wandb": 1250 | tracker.log( 1251 | { 1252 | "test": [ 1253 | wandb.Image(image, caption=f"{i}: {args.validation_prompt}") 1254 | for i, image in enumerate(images) 1255 | ] 1256 | } 1257 | ) 1258 | 1259 | if args.push_to_hub: 1260 | save_model_card( 1261 | repo_id=repo_id, 1262 | images=images, 1263 | validation_prompt=args.validation_prompt, 1264 | base_model=args.pretrained_model_name_or_path, 1265 | dataset_name=args.dataset_name, 1266 | repo_folder=args.output_dir, 1267 | vae_path=args.pretrained_vae_model_name_or_path, 1268 | ) 1269 | upload_folder( 1270 | repo_id=repo_id, 1271 | folder_path=args.output_dir, 1272 | commit_message="End of training", 1273 | ignore_patterns=["step_*", "epoch_*"], 1274 | ) 1275 | 1276 | accelerator.end_training() 1277 | 1278 | 1279 | if __name__ == "__main__": 1280 | args = parse_args() 1281 | main(args) 1282 | -------------------------------------------------------------------------------- /code/others/train_dreambooth.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # coding=utf-8 3 | # Copyright 2023 The HuggingFace Inc. team. All rights reserved. 4 | # 5 | # Licensed under the Apache License, Version 2.0 (the "License"); 6 | # you may not use this file except in compliance with the License. 7 | # You may obtain a copy of the License at 8 | # 9 | # http://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
14 | # See the License for the specific language governing permissions and 15 | 16 | import argparse 17 | import copy 18 | import gc 19 | import importlib 20 | import itertools 21 | import logging 22 | import math 23 | import os 24 | import shutil 25 | import warnings 26 | from pathlib import Path 27 | import json 28 | import re 29 | 30 | import numpy as np 31 | import torch 32 | import torch.nn.functional as F 33 | import torch.utils.checkpoint 34 | import transformers 35 | from accelerate import Accelerator 36 | from accelerate.logging import get_logger 37 | from accelerate.utils import ProjectConfiguration, set_seed 38 | from huggingface_hub import create_repo, model_info, upload_folder 39 | from huggingface_hub.utils import insecure_hashlib 40 | from packaging import version 41 | from PIL import Image 42 | from PIL.ImageOps import exif_transpose 43 | from torch.utils.data import Dataset 44 | from torchvision import transforms 45 | from tqdm.auto import tqdm 46 | from transformers import AutoTokenizer, PretrainedConfig, CLIPProcessor, CLIPModel 47 | 48 | import diffusers 49 | from diffusers import ( 50 | AutoencoderKL, 51 | DDPMScheduler, 52 | DiffusionPipeline, 53 | StableDiffusionPipeline, 54 | UNet2DConditionModel, 55 | ) 56 | from diffusers.optimization import get_scheduler 57 | from diffusers.training_utils import compute_snr 58 | from diffusers.utils import check_min_version, is_wandb_available 59 | from diffusers.utils.import_utils import is_xformers_available 60 | 61 | from clip_score import get_score 62 | 63 | 64 | if is_wandb_available(): 65 | import wandb 66 | 67 | # Will error if the minimal version of diffusers is not installed. Remove at your own risks. 68 | check_min_version("0.25.0.dev0") 69 | 70 | logger = get_logger(__name__) 71 | 72 | 73 | def save_model_card( 74 | repo_id: str, 75 | images=None, 76 | base_model=str, 77 | train_text_encoder=False, 78 | prompt=str, 79 | repo_folder=None, 80 | pipeline: DiffusionPipeline = None, 81 | ): 82 | img_str = "" 83 | for i, image in enumerate(images): 84 | image.save(os.path.join(repo_folder, f"image_{i}.png")) 85 | img_str += f"![img_{i}](./image_{i}.png)\n" 86 | 87 | yaml = f""" 88 | --- 89 | license: creativeml-openrail-m 90 | base_model: {base_model} 91 | instance_prompt: {prompt} 92 | tags: 93 | - {'stable-diffusion' if isinstance(pipeline, StableDiffusionPipeline) else 'if'} 94 | - {'stable-diffusion-diffusers' if isinstance(pipeline, StableDiffusionPipeline) else 'if-diffusers'} 95 | - text-to-image 96 | - diffusers 97 | - dreambooth 98 | inference: true 99 | --- 100 | """ 101 | model_card = f""" 102 | # DreamBooth - {repo_id} 103 | 104 | This is a dreambooth model derived from {base_model}. The weights were trained on {prompt} using [DreamBooth](https://dreambooth.github.io/). 105 | You can find some example images in the following. \n 106 | {img_str} 107 | 108 | DreamBooth for the text encoder was enabled: {train_text_encoder}. 109 | """ 110 | with open(os.path.join(repo_folder, "README.md"), "w") as f: 111 | f.write(yaml + model_card) 112 | 113 | 114 | def log_validation( 115 | text_encoder, 116 | tokenizer, 117 | unet, 118 | vae, 119 | args, 120 | accelerator, 121 | weight_dtype, 122 | global_step, 123 | prompt_embeds, 124 | negative_prompt_embeds, 125 | ): 126 | logger.info( 127 | f"Running validation... \n Generating {args.num_validation_images} images with prompt:" 128 | f" {args.validation_prompt}." 
129 | ) 130 | 131 | pipeline_args = {} 132 | 133 | if vae is not None: 134 | pipeline_args["vae"] = vae 135 | 136 | if text_encoder is not None: 137 | text_encoder = accelerator.unwrap_model(text_encoder) 138 | 139 | # create pipeline (note: unet and vae are loaded again in float32) 140 | pipeline = DiffusionPipeline.from_pretrained( 141 | args.pretrained_model_name_or_path, 142 | tokenizer=tokenizer, 143 | text_encoder=text_encoder, 144 | unet=accelerator.unwrap_model(unet), 145 | revision=args.revision, 146 | variant=args.variant, 147 | torch_dtype=weight_dtype, 148 | **pipeline_args, 149 | ) 150 | 151 | # We train on the simplified learning objective. If we were previously predicting a variance, we need the scheduler to ignore it 152 | scheduler_args = {} 153 | 154 | if "variance_type" in pipeline.scheduler.config: 155 | variance_type = pipeline.scheduler.config.variance_type 156 | 157 | if variance_type in ["learned", "learned_range"]: 158 | variance_type = "fixed_small" 159 | 160 | scheduler_args["variance_type"] = variance_type 161 | 162 | module = importlib.import_module("diffusers") 163 | scheduler_class = getattr(module, args.validation_scheduler) 164 | pipeline.scheduler = scheduler_class.from_config(pipeline.scheduler.config, **scheduler_args) 165 | pipeline = pipeline.to(accelerator.device) 166 | pipeline.set_progress_bar_config(disable=True) 167 | 168 | if args.pre_compute_text_embeddings: 169 | pipeline_args = { 170 | "prompt_embeds": prompt_embeds, 171 | "negative_prompt_embeds": negative_prompt_embeds, 172 | } 173 | else: 174 | pipeline_args = {"prompt": args.validation_prompt} 175 | 176 | # run inference 177 | generator = None if args.seed is None else torch.Generator(device=accelerator.device).manual_seed(args.seed) 178 | images = [] 179 | if args.validation_images is None: 180 | for _ in range(args.num_validation_images): 181 | with torch.autocast("cuda"): 182 | image = pipeline(**pipeline_args, num_inference_steps=25, generator=generator).images[0] 183 | images.append(image) 184 | else: 185 | for image in args.validation_images: 186 | image = Image.open(image) 187 | image = pipeline(**pipeline_args, image=image, generator=generator).images[0] 188 | images.append(image) 189 | 190 | for tracker in accelerator.trackers: 191 | if tracker.name == "tensorboard": 192 | np_images = np.stack([np.asarray(img) for img in images]) 193 | tracker.writer.add_images("validation", np_images, global_step, dataformats="NHWC") 194 | if tracker.name == "wandb": 195 | tracker.log( 196 | { 197 | "validation": [ 198 | wandb.Image(image, caption=f"{i}: {args.validation_prompt}") for i, image in enumerate(images) 199 | ] 200 | } 201 | ) 202 | 203 | del pipeline 204 | torch.cuda.empty_cache() 205 | 206 | return images 207 | 208 | 209 | def import_model_class_from_model_name_or_path(pretrained_model_name_or_path: str, revision: str): 210 | text_encoder_config = PretrainedConfig.from_pretrained( 211 | pretrained_model_name_or_path, 212 | subfolder="text_encoder", 213 | revision=revision, 214 | ) 215 | model_class = text_encoder_config.architectures[0] 216 | 217 | if model_class == "CLIPTextModel": 218 | from transformers import CLIPTextModel 219 | 220 | return CLIPTextModel 221 | elif model_class == "RobertaSeriesModelWithTransformation": 222 | from diffusers.pipelines.alt_diffusion.modeling_roberta_series import RobertaSeriesModelWithTransformation 223 | 224 | return RobertaSeriesModelWithTransformation 225 | elif model_class == "T5EncoderModel": 226 | from transformers import T5EncoderModel 227 | 
228 | return T5EncoderModel 229 | else: 230 | raise ValueError(f"{model_class} is not supported.") 231 | 232 | 233 | def parse_args(input_args=None): 234 | parser = argparse.ArgumentParser(description="Simple example of a training script.") 235 | parser.add_argument( 236 | "--pretrained_model_name_or_path", 237 | type=str, 238 | default=None, 239 | required=True, 240 | help="Path to pretrained model or model identifier from huggingface.co/models.", 241 | ) 242 | parser.add_argument( 243 | "--clip_model_name_or_path", 244 | type=str, 245 | default=None, 246 | required=True, 247 | help="Path to clip model or model identifier from huggingface.co/models.", 248 | ) 249 | parser.add_argument( 250 | "--revision", 251 | type=str, 252 | default=None, 253 | required=False, 254 | help="Revision of pretrained model identifier from huggingface.co/models.", 255 | ) 256 | parser.add_argument( 257 | "--variant", 258 | type=str, 259 | default=None, 260 | help="Variant of the model files of the pretrained model identifier from huggingface.co/models, 'e.g.' fp16", 261 | ) 262 | parser.add_argument( 263 | "--tokenizer_name", 264 | type=str, 265 | default=None, 266 | help="Pretrained tokenizer name or path if not the same as model_name", 267 | ) 268 | parser.add_argument( 269 | "--instance_data_dir", 270 | type=str, 271 | default=None, 272 | required=True, 273 | help="A folder containing the training data of instance images.", 274 | ) 275 | parser.add_argument( 276 | "--class_data_dir", 277 | type=str, 278 | default=None, 279 | required=False, 280 | help="A folder containing the training data of class images.", 281 | ) 282 | parser.add_argument( 283 | "--instance_prompt", 284 | type=str, 285 | default=None, 286 | required=True, 287 | help="The prompt with identifier specifying the instance", 288 | ) 289 | parser.add_argument( 290 | "--class_prompt", 291 | type=str, 292 | default=None, 293 | help="The prompt to specify images in the same class as provided instance images.", 294 | ) 295 | parser.add_argument( 296 | "--class_prompts_json_path", 297 | type=str, 298 | default=None, 299 | ) 300 | parser.add_argument( 301 | "--with_prior_preservation", 302 | default=False, 303 | action="store_true", 304 | help="Flag to add prior preservation loss.", 305 | ) 306 | parser.add_argument("--prior_loss_weight", type=float, default=1.0, help="The weight of prior preservation loss.") 307 | parser.add_argument( 308 | "--num_class_images", 309 | type=int, 310 | default=100, 311 | help=( 312 | "Minimal class images for prior preservation loss. If there are not enough images already present in" 313 | " class_data_dir, additional images will be sampled with class_prompt." 314 | ), 315 | ) 316 | parser.add_argument( 317 | "--output_dir", 318 | type=str, 319 | default="dreambooth-model", 320 | help="The output directory where the model predictions and checkpoints will be written.", 321 | ) 322 | parser.add_argument("--seed", type=int, default=None, help="A seed for reproducible training.") 323 | parser.add_argument( 324 | "--resolution", 325 | type=int, 326 | default=512, 327 | help=( 328 | "The resolution for input images, all the images in the train/validation dataset will be resized to this" 329 | " resolution" 330 | ), 331 | ) 332 | parser.add_argument( 333 | "--center_crop", 334 | default=False, 335 | action="store_true", 336 | help=( 337 | "Whether to center crop the input images to the resolution. If not set, the images will be randomly" 338 | " cropped. The images will be resized to the resolution first before cropping." 
339 | ), 340 | ) 341 | parser.add_argument( 342 | "--train_text_encoder", 343 | action="store_true", 344 | help="Whether to train the text encoder. If set, the text encoder should be float32 precision.", 345 | ) 346 | parser.add_argument( 347 | "--train_batch_size", type=int, default=4, help="Batch size (per device) for the training dataloader." 348 | ) 349 | parser.add_argument( 350 | "--sample_batch_size", type=int, default=4, help="Batch size (per device) for sampling images." 351 | ) 352 | parser.add_argument("--num_train_epochs", type=int, default=1) 353 | parser.add_argument( 354 | "--max_train_steps", 355 | type=int, 356 | default=None, 357 | help="Total number of training steps to perform. If provided, overrides num_train_epochs.", 358 | ) 359 | parser.add_argument( 360 | "--checkpointing_steps", 361 | type=int, 362 | default=500, 363 | help=( 364 | "Save a checkpoint of the training state every X updates. Checkpoints can be used for resuming training via `--resume_from_checkpoint`. " 365 | "In the case that the checkpoint is better than the final trained model, the checkpoint can also be used for inference." 366 | "Using a checkpoint for inference requires separate loading of the original pipeline and the individual checkpointed model components." 367 | "See https://huggingface.co/docs/diffusers/main/en/training/dreambooth#performing-inference-using-a-saved-checkpoint for step by step" 368 | "instructions." 369 | ), 370 | ) 371 | parser.add_argument( 372 | "--checkpoints_total_limit", 373 | type=int, 374 | default=None, 375 | help=( 376 | "Max number of checkpoints to store. Passed as `total_limit` to the `Accelerator` `ProjectConfiguration`." 377 | " See Accelerator::save_state https://huggingface.co/docs/accelerate/package_reference/accelerator#accelerate.Accelerator.save_state" 378 | " for more details" 379 | ), 380 | ) 381 | parser.add_argument( 382 | "--resume_from_checkpoint", 383 | type=str, 384 | default=None, 385 | help=( 386 | "Whether training should be resumed from a previous checkpoint. Use a path saved by" 387 | ' `--checkpointing_steps`, or `"latest"` to automatically select the last available checkpoint.' 388 | ), 389 | ) 390 | parser.add_argument( 391 | "--gradient_accumulation_steps", 392 | type=int, 393 | default=1, 394 | help="Number of updates steps to accumulate before performing a backward/update pass.", 395 | ) 396 | parser.add_argument( 397 | "--gradient_checkpointing", 398 | action="store_true", 399 | help="Whether or not to use gradient checkpointing to save memory at the expense of slower backward pass.", 400 | ) 401 | parser.add_argument( 402 | "--learning_rate", 403 | type=float, 404 | default=5e-6, 405 | help="Initial learning rate (after the potential warmup period) to use.", 406 | ) 407 | parser.add_argument( 408 | "--scale_lr", 409 | action="store_true", 410 | default=False, 411 | help="Scale the learning rate by the number of GPUs, gradient accumulation steps, and batch size.", 412 | ) 413 | parser.add_argument( 414 | "--lr_scheduler", 415 | type=str, 416 | default="constant", 417 | help=( 418 | 'The scheduler type to use. Choose between ["linear", "cosine", "cosine_with_restarts", "polynomial",' 419 | ' "constant", "constant_with_warmup"]' 420 | ), 421 | ) 422 | parser.add_argument( 423 | "--lr_warmup_steps", type=int, default=500, help="Number of steps for the warmup in the lr scheduler." 
424 | ) 425 | parser.add_argument( 426 | "--lr_num_cycles", 427 | type=int, 428 | default=1, 429 | help="Number of hard resets of the lr in cosine_with_restarts scheduler.", 430 | ) 431 | parser.add_argument("--lr_power", type=float, default=1.0, help="Power factor of the polynomial scheduler.") 432 | parser.add_argument( 433 | "--use_8bit_adam", action="store_true", help="Whether or not to use 8-bit Adam from bitsandbytes." 434 | ) 435 | parser.add_argument( 436 | "--dataloader_num_workers", 437 | type=int, 438 | default=0, 439 | help=( 440 | "Number of subprocesses to use for data loading. 0 means that the data will be loaded in the main process." 441 | ), 442 | ) 443 | parser.add_argument("--adam_beta1", type=float, default=0.9, help="The beta1 parameter for the Adam optimizer.") 444 | parser.add_argument("--adam_beta2", type=float, default=0.999, help="The beta2 parameter for the Adam optimizer.") 445 | parser.add_argument("--adam_weight_decay", type=float, default=1e-2, help="Weight decay to use.") 446 | parser.add_argument("--adam_epsilon", type=float, default=1e-08, help="Epsilon value for the Adam optimizer") 447 | parser.add_argument("--max_grad_norm", default=1.0, type=float, help="Max gradient norm.") 448 | parser.add_argument("--push_to_hub", action="store_true", help="Whether or not to push the model to the Hub.") 449 | parser.add_argument("--hub_token", type=str, default=None, help="The token to use to push to the Model Hub.") 450 | parser.add_argument( 451 | "--hub_model_id", 452 | type=str, 453 | default=None, 454 | help="The name of the repository to keep in sync with the local `output_dir`.", 455 | ) 456 | parser.add_argument( 457 | "--logging_dir", 458 | type=str, 459 | default="logs", 460 | help=( 461 | "[TensorBoard](https://www.tensorflow.org/tensorboard) log directory. Will default to" 462 | " *output_dir/runs/**CURRENT_DATETIME_HOSTNAME***." 463 | ), 464 | ) 465 | parser.add_argument( 466 | "--allow_tf32", 467 | action="store_true", 468 | help=( 469 | "Whether or not to allow TF32 on Ampere GPUs. Can be used to speed up training. For more information, see" 470 | " https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices" 471 | ), 472 | ) 473 | parser.add_argument( 474 | "--report_to", 475 | type=str, 476 | default="tensorboard", 477 | help=( 478 | 'The integration to report the results and logs to. Supported platforms are `"tensorboard"`' 479 | ' (default), `"wandb"` and `"comet_ml"`. Use `"all"` to report to all integrations.' 
480 | ), 481 | ) 482 | parser.add_argument( 483 | "--validation_prompt", 484 | type=str, 485 | default=None, 486 | help="A prompt that is used during validation to verify that the model is learning.", 487 | ) 488 | parser.add_argument( 489 | "--validation_prompts_path", 490 | type=str, 491 | default=None 492 | ) 493 | parser.add_argument( 494 | "--validation_batch_size", 495 | type=int 496 | ) 497 | parser.add_argument( 498 | "--validation_output_dir", 499 | type=str, 500 | default=None, 501 | help="The output directory where images will be written.", 502 | ) 503 | parser.add_argument( 504 | "--num_prompts_per_group", 505 | type=int 506 | ) 507 | parser.add_argument( 508 | "--num_groups", 509 | type=int 510 | ) 511 | parser.add_argument( 512 | "--num_images_per_prompt", 513 | type=int 514 | ) 515 | parser.add_argument( 516 | "--num_validation_images", 517 | type=int, 518 | default=4, 519 | help="Number of images that should be generated during validation with `validation_prompt`.", 520 | ) 521 | parser.add_argument( 522 | "--validation_steps", 523 | type=int, 524 | default=100, 525 | help=( 526 | "Run validation every X steps. Validation consists of running the prompt" 527 | " `args.validation_prompt` multiple times: `args.num_validation_images`" 528 | " and logging the images." 529 | ), 530 | ) 531 | parser.add_argument( 532 | "--mixed_precision", 533 | type=str, 534 | default=None, 535 | choices=["no", "fp16", "bf16"], 536 | help=( 537 | "Whether to use mixed precision. Choose between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >=" 538 | " 1.10.and an Nvidia Ampere GPU. Default to the value of accelerate config of the current system or the" 539 | " flag passed with the `accelerate.launch` command. Use this argument to override the accelerate config." 540 | ), 541 | ) 542 | parser.add_argument( 543 | "--prior_generation_precision", 544 | type=str, 545 | default=None, 546 | choices=["no", "fp32", "fp16", "bf16"], 547 | help=( 548 | "Choose prior generation precision between fp32, fp16 and bf16 (bfloat16). Bf16 requires PyTorch >=" 549 | " 1.10.and an Nvidia Ampere GPU. Default to fp16 if a GPU is available else fp32." 550 | ), 551 | ) 552 | parser.add_argument("--local_rank", type=int, default=-1, help="For distributed training: local_rank") 553 | parser.add_argument( 554 | "--enable_xformers_memory_efficient_attention", action="store_true", help="Whether or not to use xformers." 555 | ) 556 | parser.add_argument( 557 | "--set_grads_to_none", 558 | action="store_true", 559 | help=( 560 | "Save more memory by using setting grads to None instead of zero. Be aware, that this changes certain" 561 | " behaviors, so disable this argument if it causes any problems. More info:" 562 | " https://pytorch.org/docs/stable/generated/torch.optim.Optimizer.zero_grad.html" 563 | ), 564 | ) 565 | 566 | parser.add_argument( 567 | "--offset_noise", 568 | action="store_true", 569 | default=False, 570 | help=( 571 | "Fine-tuning against a modified noise" 572 | " See: https://www.crosslabs.org//blog/diffusion-with-offset-noise for more information." 573 | ), 574 | ) 575 | parser.add_argument( 576 | "--snr_gamma", 577 | type=float, 578 | default=None, 579 | help="SNR weighting gamma to be used if rebalancing the loss. Recommended value is 5.0. " 580 | "More details here: https://arxiv.org/abs/2303.09556.", 581 | ) 582 | parser.add_argument( 583 | "--pre_compute_text_embeddings", 584 | action="store_true", 585 | help="Whether or not to pre-compute text embeddings. 
If text embeddings are pre-computed, the text encoder will not be kept in memory during training and will leave more GPU memory available for training the rest of the model. This is not compatible with `--train_text_encoder`.", 586 | ) 587 | parser.add_argument( 588 | "--tokenizer_max_length", 589 | type=int, 590 | default=None, 591 | required=False, 592 | help="The maximum length of the tokenizer. If not set, will default to the tokenizer's max length.", 593 | ) 594 | parser.add_argument( 595 | "--text_encoder_use_attention_mask", 596 | action="store_true", 597 | required=False, 598 | help="Whether to use attention mask for the text encoder", 599 | ) 600 | parser.add_argument( 601 | "--skip_save_text_encoder", action="store_true", required=False, help="Set to not save text encoder" 602 | ) 603 | parser.add_argument( 604 | "--validation_images", 605 | required=False, 606 | default=None, 607 | nargs="+", 608 | help="Optional set of images to use for validation. Used when the target pipeline takes an initial image as input such as when training image variation or superresolution.", 609 | ) 610 | parser.add_argument( 611 | "--class_labels_conditioning", 612 | required=False, 613 | default=None, 614 | help="The optional `class_label` conditioning to pass to the unet, available values are `timesteps`.", 615 | ) 616 | parser.add_argument( 617 | "--validation_scheduler", 618 | type=str, 619 | default="DPMSolverMultistepScheduler", 620 | choices=["DPMSolverMultistepScheduler", "DDPMScheduler"], 621 | help="Select which scheduler to use for validation. DDPMScheduler is recommended for DeepFloyd IF.", 622 | ) 623 | 624 | if input_args is not None: 625 | args = parser.parse_args(input_args) 626 | else: 627 | args = parser.parse_args() 628 | 629 | env_local_rank = int(os.environ.get("LOCAL_RANK", -1)) 630 | if env_local_rank != -1 and env_local_rank != args.local_rank: 631 | args.local_rank = env_local_rank 632 | 633 | if args.with_prior_preservation: 634 | if args.class_data_dir is None: 635 | raise ValueError("You must specify a data directory for class images.") 636 | if args.class_prompt is None and args.class_prompts_json_path is None: 637 | raise ValueError("You must specify prompt for class images.") 638 | else: 639 | # logger is not available yet 640 | if args.class_data_dir is not None: 641 | warnings.warn("You need not use --class_data_dir without --with_prior_preservation.") 642 | if args.class_prompt is not None: 643 | warnings.warn("You need not use --class_prompt without --with_prior_preservation.") 644 | 645 | if args.train_text_encoder and args.pre_compute_text_embeddings: 646 | raise ValueError("`--train_text_encoder` cannot be used with `--pre_compute_text_embeddings`") 647 | 648 | return args 649 | 650 | 651 | class DreamBoothDataset(Dataset): 652 | """ 653 | A dataset to prepare the instance and class images with the prompts for fine-tuning the model. 654 | It pre-processes the images and the tokenizes prompts. 
655 | """ 656 | 657 | def __init__( 658 | self, 659 | instance_data_root, 660 | instance_prompt, 661 | tokenizer, 662 | class_data_root=None, 663 | class_prompt=None, 664 | class_num=None, 665 | size=512, 666 | center_crop=False, 667 | encoder_hidden_states=None, 668 | class_prompt_encoder_hidden_states=None, 669 | tokenizer_max_length=None, 670 | class_prompts_json_path=None 671 | ): 672 | self.size = size 673 | self.center_crop = center_crop 674 | self.tokenizer = tokenizer 675 | self.encoder_hidden_states = encoder_hidden_states 676 | self.class_prompt_encoder_hidden_states = class_prompt_encoder_hidden_states 677 | self.tokenizer_max_length = tokenizer_max_length 678 | 679 | self.instance_data_root = Path(instance_data_root) 680 | if not self.instance_data_root.exists(): 681 | raise ValueError(f"Instance {self.instance_data_root} images root doesn't exists.") 682 | 683 | self.instance_images_path = [p for p in Path(instance_data_root).iterdir() if p.suffix.lower() in {'.jpg', '.jpeg', '.png'}] 684 | self.num_instance_images = len(self.instance_images_path) 685 | self.instance_prompt = instance_prompt 686 | self._length = self.num_instance_images 687 | 688 | if class_data_root is not None: 689 | self.class_data_root = Path(class_data_root) 690 | self.class_data_root.mkdir(parents=True, exist_ok=True) 691 | self.class_images_path = [p for p in self.class_data_root.iterdir() if p.suffix.lower() in {'.jpg', '.jpeg', '.png'}] 692 | if class_num is not None: 693 | self.num_class_images = min(len(self.class_images_path), class_num) 694 | else: 695 | self.num_class_images = len(self.class_images_path) 696 | self._length = max(self.num_class_images, self.num_instance_images) 697 | self.class_prompts = [class_prompt] 698 | if class_prompts_json_path: 699 | self.class_images_path.sort(key=lambda x:int(re.findall(r'\d+',x.name)[-1])) 700 | self.class_prompts = json.load(open(class_prompts_json_path,'r',encoding='utf-8')) 701 | self.num_class_prompts = len(self.class_prompts) 702 | else: 703 | self.class_data_root = None 704 | 705 | self.image_transforms = transforms.Compose( 706 | [ 707 | transforms.Resize(size, interpolation=transforms.InterpolationMode.BILINEAR), 708 | transforms.CenterCrop(size) if center_crop else transforms.RandomCrop(size), 709 | transforms.ToTensor(), 710 | transforms.Normalize([0.5], [0.5]), 711 | ] 712 | ) 713 | 714 | def __len__(self): 715 | return self._length 716 | 717 | def __getitem__(self, index): 718 | example = {} 719 | instance_image = Image.open(self.instance_images_path[index % self.num_instance_images]) 720 | instance_image = exif_transpose(instance_image) 721 | 722 | if not instance_image.mode == "RGB": 723 | instance_image = instance_image.convert("RGB") 724 | example["instance_images"] = self.image_transforms(instance_image) 725 | 726 | if self.encoder_hidden_states is not None: 727 | example["instance_prompt_ids"] = self.encoder_hidden_states 728 | else: 729 | text_inputs = tokenize_prompt( 730 | self.tokenizer, self.instance_prompt, tokenizer_max_length=self.tokenizer_max_length 731 | ) 732 | example["instance_prompt_ids"] = text_inputs.input_ids 733 | example["instance_attention_mask"] = text_inputs.attention_mask 734 | 735 | if self.class_data_root: 736 | class_image = Image.open(self.class_images_path[index % self.num_class_images]) 737 | class_image = exif_transpose(class_image) 738 | class_prompt = self.class_prompts[index % self.num_class_images % self.num_class_prompts] 739 | 740 | if not class_image.mode == "RGB": 741 | class_image = 
class_image.convert("RGB") 742 | example["class_images"] = self.image_transforms(class_image) 743 | 744 | if self.class_prompt_encoder_hidden_states is not None: 745 | example["class_prompt_ids"] = self.class_prompt_encoder_hidden_states 746 | else: 747 | class_text_inputs = tokenize_prompt( 748 | self.tokenizer, class_prompt, tokenizer_max_length=self.tokenizer_max_length 749 | ) 750 | example["class_prompt_ids"] = class_text_inputs.input_ids 751 | example["class_attention_mask"] = class_text_inputs.attention_mask 752 | 753 | return example 754 | 755 | 756 | def collate_fn(examples, with_prior_preservation=False): 757 | has_attention_mask = "instance_attention_mask" in examples[0] 758 | 759 | input_ids = [example["instance_prompt_ids"] for example in examples] 760 | pixel_values = [example["instance_images"] for example in examples] 761 | 762 | if has_attention_mask: 763 | attention_mask = [example["instance_attention_mask"] for example in examples] 764 | 765 | # Concat class and instance examples for prior preservation. 766 | # We do this to avoid doing two forward passes. 767 | if with_prior_preservation: 768 | input_ids += [example["class_prompt_ids"] for example in examples] 769 | pixel_values += [example["class_images"] for example in examples] 770 | 771 | if has_attention_mask: 772 | attention_mask += [example["class_attention_mask"] for example in examples] 773 | 774 | pixel_values = torch.stack(pixel_values) 775 | pixel_values = pixel_values.to(memory_format=torch.contiguous_format).float() 776 | 777 | input_ids = torch.cat(input_ids, dim=0) 778 | 779 | batch = { 780 | "input_ids": input_ids, 781 | "pixel_values": pixel_values, 782 | } 783 | 784 | if has_attention_mask: 785 | attention_mask = torch.cat(attention_mask, dim=0) 786 | batch["attention_mask"] = attention_mask 787 | 788 | return batch 789 | 790 | 791 | class PromptDataset(Dataset): 792 | "A simple dataset to prepare the prompts to generate class images on multiple GPUs." 
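    # Each item simply pairs the fixed class prompt with its index; the index is later
    # used when naming the generated class images on disk.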
793 | 794 | def __init__(self, prompt, num_samples): 795 | self.prompt = prompt 796 | self.num_samples = num_samples 797 | 798 | def __len__(self): 799 | return self.num_samples 800 | 801 | def __getitem__(self, index): 802 | example = {} 803 | example["prompt"] = self.prompt 804 | example["index"] = index 805 | return example 806 | 807 | 808 | def model_has_vae(args): 809 | config_file_name = os.path.join("vae", AutoencoderKL.config_name) 810 | if os.path.isdir(args.pretrained_model_name_or_path): 811 | config_file_name = os.path.join(args.pretrained_model_name_or_path, config_file_name) 812 | return os.path.isfile(config_file_name) 813 | else: 814 | files_in_repo = model_info(args.pretrained_model_name_or_path, revision=args.revision).siblings 815 | return any(file.rfilename == config_file_name for file in files_in_repo) 816 | 817 | 818 | def tokenize_prompt(tokenizer, prompt, tokenizer_max_length=None): 819 | if tokenizer_max_length is not None: 820 | max_length = tokenizer_max_length 821 | else: 822 | max_length = tokenizer.model_max_length 823 | 824 | text_inputs = tokenizer( 825 | prompt, 826 | truncation=True, 827 | padding="max_length", 828 | max_length=max_length, 829 | return_tensors="pt", 830 | ) 831 | 832 | return text_inputs 833 | 834 | 835 | def encode_prompt(text_encoder, input_ids, attention_mask, text_encoder_use_attention_mask=None): 836 | text_input_ids = input_ids.to(text_encoder.device) 837 | 838 | if text_encoder_use_attention_mask: 839 | attention_mask = attention_mask.to(text_encoder.device) 840 | else: 841 | attention_mask = None 842 | 843 | prompt_embeds = text_encoder( 844 | text_input_ids, 845 | attention_mask=attention_mask, 846 | ) 847 | prompt_embeds = prompt_embeds[0] 848 | 849 | return prompt_embeds 850 | 851 | 852 | def main(args): 853 | logging_dir = Path(args.output_dir, args.logging_dir) 854 | 855 | accelerator_project_config = ProjectConfiguration(project_dir=args.output_dir, logging_dir=logging_dir) 856 | 857 | accelerator = Accelerator( 858 | gradient_accumulation_steps=args.gradient_accumulation_steps, 859 | mixed_precision=args.mixed_precision, 860 | log_with=args.report_to, 861 | project_config=accelerator_project_config, 862 | ) 863 | 864 | if args.report_to == "wandb": 865 | if not is_wandb_available(): 866 | raise ImportError("Make sure to install wandb if you want to use it for logging during training.") 867 | 868 | # Currently, it's not possible to do gradient accumulation when training two models with accelerate.accumulate 869 | # This will be enabled soon in accelerate. For now, we don't allow gradient accumulation when training two models. 870 | # TODO (patil-suraj): Remove this check when gradient accumulation with two models is enabled in accelerate. 871 | if args.train_text_encoder and args.gradient_accumulation_steps > 1 and accelerator.num_processes > 1: 872 | raise ValueError( 873 | "Gradient accumulation is not supported when training the text encoder in distributed training. " 874 | "Please set gradient_accumulation_steps to 1. This feature will be supported in the future." 875 | ) 876 | 877 | # Make one log on every process with the configuration for debugging. 
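    # The main process logs transformers at warning and diffusers at info verbosity,
    # while the remaining ranks are silenced to error level to avoid duplicated output
    # in multi-GPU runs.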
878 | logging.basicConfig( 879 | format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", 880 | datefmt="%m/%d/%Y %H:%M:%S", 881 | level=logging.INFO, 882 | ) 883 | logger.info(accelerator.state, main_process_only=False) 884 | if accelerator.is_local_main_process: 885 | transformers.utils.logging.set_verbosity_warning() 886 | diffusers.utils.logging.set_verbosity_info() 887 | else: 888 | transformers.utils.logging.set_verbosity_error() 889 | diffusers.utils.logging.set_verbosity_error() 890 | 891 | # If passed along, set the training seed now. 892 | if args.seed is not None: 893 | set_seed(args.seed) 894 | 895 | # Generate class images if prior preservation is enabled. 896 | if args.with_prior_preservation: 897 | class_images_dir = Path(args.class_data_dir) 898 | if not class_images_dir.exists(): 899 | class_images_dir.mkdir(parents=True) 900 | cur_class_images = len(list(class_images_dir.iterdir())) 901 | 902 | if cur_class_images < args.num_class_images: 903 | torch_dtype = torch.float16 if accelerator.device.type == "cuda" else torch.float32 904 | if args.prior_generation_precision == "fp32": 905 | torch_dtype = torch.float32 906 | elif args.prior_generation_precision == "fp16": 907 | torch_dtype = torch.float16 908 | elif args.prior_generation_precision == "bf16": 909 | torch_dtype = torch.bfloat16 910 | pipeline = DiffusionPipeline.from_pretrained( 911 | args.pretrained_model_name_or_path, 912 | torch_dtype=torch_dtype, 913 | safety_checker=None, 914 | revision=args.revision, 915 | variant=args.variant, 916 | ) 917 | pipeline.set_progress_bar_config(disable=True) 918 | 919 | num_new_images = args.num_class_images - cur_class_images 920 | logger.info(f"Number of class images to sample: {num_new_images}.") 921 | 922 | sample_dataset = PromptDataset(args.class_prompt, num_new_images) 923 | sample_dataloader = torch.utils.data.DataLoader(sample_dataset, batch_size=args.sample_batch_size) 924 | 925 | sample_dataloader = accelerator.prepare(sample_dataloader) 926 | pipeline.to(accelerator.device) 927 | 928 | for example in tqdm( 929 | sample_dataloader, desc="Generating class images", disable=not accelerator.is_local_main_process 930 | ): 931 | images = pipeline(example["prompt"]).images 932 | 933 | for i, image in enumerate(images): 934 | hash_image = insecure_hashlib.sha1(image.tobytes()).hexdigest() 935 | image_filename = class_images_dir / f"{example['index'][i] + cur_class_images}-{hash_image}.jpg" 936 | image.save(image_filename) 937 | 938 | del pipeline 939 | if torch.cuda.is_available(): 940 | torch.cuda.empty_cache() 941 | 942 | # Handle the repository creation 943 | if accelerator.is_main_process: 944 | if args.output_dir is not None: 945 | os.makedirs(args.output_dir, exist_ok=True) 946 | 947 | if args.push_to_hub: 948 | repo_id = create_repo( 949 | repo_id=args.hub_model_id or Path(args.output_dir).name, exist_ok=True, token=args.hub_token 950 | ).repo_id 951 | 952 | # Load the tokenizer 953 | if args.tokenizer_name: 954 | tokenizer = AutoTokenizer.from_pretrained(args.tokenizer_name, revision=args.revision, use_fast=False) 955 | elif args.pretrained_model_name_or_path: 956 | tokenizer = AutoTokenizer.from_pretrained( 957 | args.pretrained_model_name_or_path, 958 | subfolder="tokenizer", 959 | revision=args.revision, 960 | use_fast=False, 961 | ) 962 | 963 | # import correct text encoder class 964 | text_encoder_cls = import_model_class_from_model_name_or_path(args.pretrained_model_name_or_path, args.revision) 965 | 966 | # Load scheduler and models 967 | noise_scheduler = 
DDPMScheduler.from_pretrained(args.pretrained_model_name_or_path, subfolder="scheduler")
968 | text_encoder = text_encoder_cls.from_pretrained(
969 | args.pretrained_model_name_or_path, subfolder="text_encoder", revision=args.revision, variant=args.variant
970 | )
971 |
972 | if args.clip_model_name_or_path:  # CLIP model/processor used by get_score() to score validation images
973 | clip_model = CLIPModel.from_pretrained(args.clip_model_name_or_path)
974 | clip_processor = CLIPProcessor.from_pretrained(args.clip_model_name_or_path)
975 |
976 | if model_has_vae(args):
977 | vae = AutoencoderKL.from_pretrained(
978 | args.pretrained_model_name_or_path, subfolder="vae", revision=args.revision, variant=args.variant
979 | )
980 | else:
981 | vae = None
982 |
983 | unet = UNet2DConditionModel.from_pretrained(
984 | args.pretrained_model_name_or_path, subfolder="unet", revision=args.revision, variant=args.variant
985 | )
986 |
987 | # create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format
988 | def save_model_hook(models, weights, output_dir):
989 | if accelerator.is_main_process:
990 | for model in models:
991 | sub_dir = "unet" if isinstance(model, type(accelerator.unwrap_model(unet))) else "text_encoder"
992 | model.save_pretrained(os.path.join(output_dir, sub_dir))
993 |
994 | # make sure to pop weight so that corresponding model is not saved again
995 | weights.pop()
996 |
997 | def load_model_hook(models, input_dir):
998 | while len(models) > 0:
999 | # pop models so that they are not loaded again
1000 | model = models.pop()
1001 |
1002 | if isinstance(model, type(accelerator.unwrap_model(text_encoder))):
1003 | # load transformers style into model
1004 | load_model = text_encoder_cls.from_pretrained(input_dir, subfolder="text_encoder")
1005 | model.config = load_model.config
1006 | else:
1007 | # load diffusers style into model
1008 | load_model = UNet2DConditionModel.from_pretrained(input_dir, subfolder="unet")
1009 | model.register_to_config(**load_model.config)
1010 |
1011 | model.load_state_dict(load_model.state_dict())
1012 | del load_model
1013 |
1014 | accelerator.register_save_state_pre_hook(save_model_hook)
1015 | accelerator.register_load_state_pre_hook(load_model_hook)
1016 |
1017 | if vae is not None:
1018 | vae.requires_grad_(False)
1019 |
1020 | if not args.train_text_encoder:
1021 | text_encoder.requires_grad_(False)
1022 |
1023 | if args.enable_xformers_memory_efficient_attention:
1024 | if is_xformers_available():
1025 | import xformers
1026 |
1027 | xformers_version = version.parse(xformers.__version__)
1028 | if xformers_version == version.parse("0.0.16"):
1029 | logger.warning(
1030 | "xFormers 0.0.16 cannot be used for training on some GPUs. If you observe problems during training, please update xFormers to at least 0.0.17. See https://huggingface.co/docs/diffusers/main/en/optimization/xformers for more details."
1031 | )
1032 | unet.enable_xformers_memory_efficient_attention()
1033 | else:
1034 | raise ValueError("xformers is not available. Make sure it is installed correctly")
1035 |
1036 | if args.gradient_checkpointing:
1037 | unet.enable_gradient_checkpointing()
1038 | if args.train_text_encoder:
1039 | text_encoder.gradient_checkpointing_enable()
1040 |
1041 | # Check that all trainable models are in full precision
1042 | low_precision_error_string = (
1043 | "Please make sure to always have all model weights in full float32 precision when starting training - even if"
1044 | " doing mixed precision training. A copy of the weights should still be float32."
1045 | ) 1046 | 1047 | if accelerator.unwrap_model(unet).dtype != torch.float32: 1048 | raise ValueError( 1049 | f"Unet loaded as datatype {accelerator.unwrap_model(unet).dtype}. {low_precision_error_string}" 1050 | ) 1051 | 1052 | if args.train_text_encoder and accelerator.unwrap_model(text_encoder).dtype != torch.float32: 1053 | raise ValueError( 1054 | f"Text encoder loaded as datatype {accelerator.unwrap_model(text_encoder).dtype}." 1055 | f" {low_precision_error_string}" 1056 | ) 1057 | 1058 | # Enable TF32 for faster training on Ampere GPUs, 1059 | # cf https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices 1060 | if args.allow_tf32: 1061 | torch.backends.cuda.matmul.allow_tf32 = True 1062 | 1063 | if args.scale_lr: 1064 | args.learning_rate = ( 1065 | args.learning_rate * args.gradient_accumulation_steps * args.train_batch_size * accelerator.num_processes 1066 | ) 1067 | 1068 | # Use 8-bit Adam for lower memory usage or to fine-tune the model in 16GB GPUs 1069 | if args.use_8bit_adam: 1070 | try: 1071 | import bitsandbytes as bnb 1072 | except ImportError: 1073 | raise ImportError( 1074 | "To use 8-bit Adam, please install the bitsandbytes library: `pip install bitsandbytes`." 1075 | ) 1076 | 1077 | optimizer_class = bnb.optim.AdamW8bit 1078 | else: 1079 | optimizer_class = torch.optim.AdamW 1080 | 1081 | # Optimizer creation 1082 | params_to_optimize = ( 1083 | itertools.chain(unet.parameters(), text_encoder.parameters()) if args.train_text_encoder else unet.parameters() 1084 | ) 1085 | optimizer = optimizer_class( 1086 | params_to_optimize, 1087 | lr=args.learning_rate, 1088 | betas=(args.adam_beta1, args.adam_beta2), 1089 | weight_decay=args.adam_weight_decay, 1090 | eps=args.adam_epsilon, 1091 | ) 1092 | 1093 | if args.pre_compute_text_embeddings: 1094 | 1095 | def compute_text_embeddings(prompt): 1096 | with torch.no_grad(): 1097 | text_inputs = tokenize_prompt(tokenizer, prompt, tokenizer_max_length=args.tokenizer_max_length) 1098 | prompt_embeds = encode_prompt( 1099 | text_encoder, 1100 | text_inputs.input_ids, 1101 | text_inputs.attention_mask, 1102 | text_encoder_use_attention_mask=args.text_encoder_use_attention_mask, 1103 | ) 1104 | 1105 | return prompt_embeds 1106 | 1107 | pre_computed_encoder_hidden_states = compute_text_embeddings(args.instance_prompt) 1108 | validation_prompt_negative_prompt_embeds = compute_text_embeddings("") 1109 | 1110 | if args.validation_prompt is not None: 1111 | validation_prompt_encoder_hidden_states = compute_text_embeddings(args.validation_prompt) 1112 | else: 1113 | validation_prompt_encoder_hidden_states = None 1114 | 1115 | if args.class_prompt is not None: 1116 | pre_computed_class_prompt_encoder_hidden_states = compute_text_embeddings(args.class_prompt) 1117 | else: 1118 | pre_computed_class_prompt_encoder_hidden_states = None 1119 | 1120 | text_encoder = None 1121 | tokenizer = None 1122 | 1123 | gc.collect() 1124 | torch.cuda.empty_cache() 1125 | else: 1126 | pre_computed_encoder_hidden_states = None 1127 | validation_prompt_encoder_hidden_states = None 1128 | validation_prompt_negative_prompt_embeds = None 1129 | pre_computed_class_prompt_encoder_hidden_states = None 1130 | 1131 | # Dataset and DataLoaders creation: 1132 | train_dataset = DreamBoothDataset( 1133 | instance_data_root=args.instance_data_dir, 1134 | instance_prompt=args.instance_prompt, 1135 | class_data_root=args.class_data_dir if args.with_prior_preservation else None, 1136 | class_prompt=args.class_prompt, 1137 | 
class_num=args.num_class_images,
1138 | tokenizer=tokenizer,
1139 | size=args.resolution,
1140 | center_crop=args.center_crop,
1141 | encoder_hidden_states=pre_computed_encoder_hidden_states,
1142 | class_prompt_encoder_hidden_states=pre_computed_class_prompt_encoder_hidden_states,
1143 | tokenizer_max_length=args.tokenizer_max_length,
1144 | class_prompts_json_path=args.class_prompts_json_path
1145 | )
1146 |
1147 | train_dataloader = torch.utils.data.DataLoader(
1148 | train_dataset,
1149 | batch_size=args.train_batch_size,
1150 | shuffle=True,
1151 | collate_fn=lambda examples: collate_fn(examples, args.with_prior_preservation),
1152 | num_workers=args.dataloader_num_workers,
1153 | )
1154 |
1155 | # Scheduler and math around the number of training steps.
1156 | overrode_max_train_steps = False
1157 | num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
1158 | if args.max_train_steps is None:
1159 | args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
1160 | overrode_max_train_steps = True
1161 |
1162 | lr_scheduler = get_scheduler(
1163 | args.lr_scheduler,
1164 | optimizer=optimizer,
1165 | num_warmup_steps=args.lr_warmup_steps * accelerator.num_processes,
1166 | num_training_steps=args.max_train_steps * accelerator.num_processes,
1167 | num_cycles=args.lr_num_cycles,
1168 | power=args.lr_power,
1169 | )
1170 |
1171 | # Prepare everything with our `accelerator`.
1172 | if args.train_text_encoder:
1173 | unet, text_encoder, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
1174 | unet, text_encoder, optimizer, train_dataloader, lr_scheduler
1175 | )
1176 | else:
1177 | unet, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
1178 | unet, optimizer, train_dataloader, lr_scheduler
1179 | )
1180 |
1181 | # For mixed precision training we cast the non-trainable weights (the vae and, when it is not being trained, the text_encoder) to half-precision,
1182 | # as these weights are only used for inference; keeping them in full precision is not required.
1183 | weight_dtype = torch.float32
1184 | if accelerator.mixed_precision == "fp16":
1185 | weight_dtype = torch.float16
1186 | elif accelerator.mixed_precision == "bf16":
1187 | weight_dtype = torch.bfloat16
1188 |
1189 | # Move vae and text_encoder to device and cast to weight_dtype
1190 | if vae is not None:
1191 | vae.to(accelerator.device, dtype=weight_dtype)
1192 |
1193 | if not args.train_text_encoder and text_encoder is not None:
1194 | text_encoder.to(accelerator.device, dtype=weight_dtype)
1195 |
1196 | # We need to recalculate our total training steps as the size of the training dataloader may have changed.
1197 | num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
1198 | if overrode_max_train_steps:
1199 | args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
1200 | # Afterwards we recalculate our number of training epochs
1201 | args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)
1202 |
1203 | # We need to initialize the trackers we use, and also store our configuration.
1204 | # The trackers initialize automatically on the main process.
1205 | if accelerator.is_main_process:
1206 | tracker_config = vars(copy.deepcopy(args))
1207 | tracker_config.pop("validation_images")
1208 | accelerator.init_trackers("dreambooth", config=tracker_config)
1209 |
1210 | # Train!
1211 | total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps 1212 | 1213 | logger.info("***** Running training *****") 1214 | logger.info(f" Num examples = {len(train_dataset)}") 1215 | logger.info(f" Num batches each epoch = {len(train_dataloader)}") 1216 | logger.info(f" Num Epochs = {args.num_train_epochs}") 1217 | logger.info(f" Instantaneous batch size per device = {args.train_batch_size}") 1218 | logger.info(f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}") 1219 | logger.info(f" Gradient Accumulation steps = {args.gradient_accumulation_steps}") 1220 | logger.info(f" Total optimization steps = {args.max_train_steps}") 1221 | global_step = 0 1222 | first_epoch = 0 1223 | 1224 | # Potentially load in the weights and states from a previous save 1225 | if args.resume_from_checkpoint: 1226 | if args.resume_from_checkpoint != "latest": 1227 | path = os.path.basename(args.resume_from_checkpoint) 1228 | else: 1229 | # Get the most recent checkpoint 1230 | dirs = os.listdir(args.output_dir) 1231 | dirs = [d for d in dirs if d.startswith("checkpoint")] 1232 | dirs = sorted(dirs, key=lambda x: int(x.split("-")[1])) 1233 | path = dirs[-1] if len(dirs) > 0 else None 1234 | 1235 | if path is None: 1236 | accelerator.print( 1237 | f"Checkpoint '{args.resume_from_checkpoint}' does not exist. Starting a new training run." 1238 | ) 1239 | args.resume_from_checkpoint = None 1240 | initial_global_step = 0 1241 | else: 1242 | accelerator.print(f"Resuming from checkpoint {path}") 1243 | accelerator.load_state(os.path.join(args.output_dir, path)) 1244 | global_step = int(path.split("-")[1]) 1245 | 1246 | initial_global_step = global_step 1247 | first_epoch = global_step // num_update_steps_per_epoch 1248 | else: 1249 | initial_global_step = 0 1250 | 1251 | progress_bar = tqdm( 1252 | range(0, args.max_train_steps), 1253 | initial=initial_global_step, 1254 | desc="Steps", 1255 | # Only show the progress bar once on each machine. 
1256 | disable=not accelerator.is_local_main_process, 1257 | ) 1258 | 1259 | for epoch in range(first_epoch, args.num_train_epochs): 1260 | unet.train() 1261 | if args.train_text_encoder: 1262 | text_encoder.train() 1263 | for step, batch in enumerate(train_dataloader): 1264 | with accelerator.accumulate(unet): 1265 | pixel_values = batch["pixel_values"].to(dtype=weight_dtype) 1266 | 1267 | if vae is not None: 1268 | # Convert images to latent space 1269 | model_input = vae.encode(batch["pixel_values"].to(dtype=weight_dtype)).latent_dist.sample() 1270 | model_input = model_input * vae.config.scaling_factor 1271 | else: 1272 | model_input = pixel_values 1273 | 1274 | # Sample noise that we'll add to the model input 1275 | if args.offset_noise: 1276 | noise = torch.randn_like(model_input) + 0.1 * torch.randn( 1277 | model_input.shape[0], model_input.shape[1], 1, 1, device=model_input.device 1278 | ) 1279 | else: 1280 | noise = torch.randn_like(model_input) 1281 | bsz, channels, height, width = model_input.shape 1282 | # Sample a random timestep for each image 1283 | timesteps = torch.randint( 1284 | 0, noise_scheduler.config.num_train_timesteps, (bsz,), device=model_input.device 1285 | ) 1286 | timesteps = timesteps.long() 1287 | 1288 | # Add noise to the model input according to the noise magnitude at each timestep 1289 | # (this is the forward diffusion process) 1290 | noisy_model_input = noise_scheduler.add_noise(model_input, noise, timesteps) 1291 | 1292 | # Get the text embedding for conditioning 1293 | if args.pre_compute_text_embeddings: 1294 | encoder_hidden_states = batch["input_ids"] 1295 | else: 1296 | encoder_hidden_states = encode_prompt( 1297 | text_encoder, 1298 | batch["input_ids"], 1299 | batch["attention_mask"], 1300 | text_encoder_use_attention_mask=args.text_encoder_use_attention_mask, 1301 | ) 1302 | 1303 | if accelerator.unwrap_model(unet).config.in_channels == channels * 2: 1304 | noisy_model_input = torch.cat([noisy_model_input, noisy_model_input], dim=1) 1305 | 1306 | if args.class_labels_conditioning == "timesteps": 1307 | class_labels = timesteps 1308 | else: 1309 | class_labels = None 1310 | 1311 | # Predict the noise residual 1312 | model_pred = unet( 1313 | noisy_model_input, timesteps, encoder_hidden_states, class_labels=class_labels 1314 | ).sample 1315 | 1316 | if model_pred.shape[1] == 6: 1317 | model_pred, _ = torch.chunk(model_pred, 2, dim=1) 1318 | 1319 | # Get the target for loss depending on the prediction type 1320 | if noise_scheduler.config.prediction_type == "epsilon": 1321 | target = noise 1322 | elif noise_scheduler.config.prediction_type == "v_prediction": 1323 | target = noise_scheduler.get_velocity(model_input, noise, timesteps) 1324 | else: 1325 | raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}") 1326 | 1327 | if args.with_prior_preservation: 1328 | # Chunk the noise and model_pred into two parts and compute the loss on each part separately. 1329 | model_pred, model_pred_prior = torch.chunk(model_pred, 2, dim=0) 1330 | target, target_prior = torch.chunk(target, 2, dim=0) 1331 | # Compute prior loss 1332 | prior_loss = F.mse_loss(model_pred_prior.float(), target_prior.float(), reduction="mean") 1333 | 1334 | # Compute instance loss 1335 | if args.snr_gamma is None: 1336 | loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean") 1337 | else: 1338 | # Compute loss-weights as per Section 3.4 of https://arxiv.org/abs/2303.09556. 
1339 | # Since we predict the noise instead of x_0, the original formulation is slightly changed. 1340 | # This is discussed in Section 4.2 of the same paper. 1341 | snr = compute_snr(noise_scheduler, timesteps) 1342 | base_weight = ( 1343 | torch.stack([snr, args.snr_gamma * torch.ones_like(timesteps)], dim=1).min(dim=1)[0] / snr 1344 | ) 1345 | 1346 | if noise_scheduler.config.prediction_type == "v_prediction": 1347 | # Velocity objective needs to be floored to an SNR weight of one. 1348 | mse_loss_weights = base_weight + 1 1349 | else: 1350 | # Epsilon and sample both use the same loss weights. 1351 | mse_loss_weights = base_weight 1352 | loss = F.mse_loss(model_pred.float(), target.float(), reduction="none") 1353 | loss = loss.mean(dim=list(range(1, len(loss.shape)))) * mse_loss_weights 1354 | loss = loss.mean() 1355 | 1356 | if args.with_prior_preservation: 1357 | # Add the prior loss to the instance loss. 1358 | loss = loss + args.prior_loss_weight * prior_loss 1359 | 1360 | accelerator.backward(loss) 1361 | if accelerator.sync_gradients: 1362 | params_to_clip = ( 1363 | itertools.chain(unet.parameters(), text_encoder.parameters()) 1364 | if args.train_text_encoder 1365 | else unet.parameters() 1366 | ) 1367 | accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm) 1368 | optimizer.step() 1369 | lr_scheduler.step() 1370 | optimizer.zero_grad(set_to_none=args.set_grads_to_none) 1371 | 1372 | # Checks if the accelerator has performed an optimization step behind the scenes 1373 | if accelerator.sync_gradients: 1374 | progress_bar.update(1) 1375 | global_step += 1 1376 | 1377 | if accelerator.is_main_process: 1378 | if global_step % args.checkpointing_steps == 0: 1379 | # _before_ saving state, check if this save would set us over the `checkpoints_total_limit` 1380 | if args.checkpoints_total_limit is not None: 1381 | checkpoints = os.listdir(args.output_dir) 1382 | checkpoints = [d for d in checkpoints if d.startswith("checkpoint")] 1383 | checkpoints = sorted(checkpoints, key=lambda x: int(x.split("-")[1])) 1384 | 1385 | # before we save the new checkpoint, we need to have at _most_ `checkpoints_total_limit - 1` checkpoints 1386 | if len(checkpoints) >= args.checkpoints_total_limit: 1387 | num_to_remove = len(checkpoints) - args.checkpoints_total_limit + 1 1388 | removing_checkpoints = checkpoints[0:num_to_remove] 1389 | 1390 | logger.info( 1391 | f"{len(checkpoints)} checkpoints already exist, removing {len(removing_checkpoints)} checkpoints" 1392 | ) 1393 | logger.info(f"removing checkpoints: {', '.join(removing_checkpoints)}") 1394 | 1395 | for removing_checkpoint in removing_checkpoints: 1396 | removing_checkpoint = os.path.join(args.output_dir, removing_checkpoint) 1397 | shutil.rmtree(removing_checkpoint) 1398 | 1399 | save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}") 1400 | accelerator.save_state(save_path) 1401 | logger.info(f"Saved state to {save_path}") 1402 | 1403 | images = [] 1404 | 1405 | if args.validation_prompt is not None and global_step % args.validation_steps == 0: 1406 | images = log_validation( 1407 | text_encoder, 1408 | tokenizer, 1409 | unet, 1410 | vae, 1411 | args, 1412 | accelerator, 1413 | weight_dtype, 1414 | global_step, 1415 | validation_prompt_encoder_hidden_states, 1416 | validation_prompt_negative_prompt_embeds, 1417 | ) 1418 | 1419 | if args.validation_prompts_path is not None and global_step % args.validation_steps == 0: 1420 | pipeline = DiffusionPipeline.from_pretrained( 1421 | args.pretrained_model_name_or_path, 1422 
| tokenizer=tokenizer, 1423 | text_encoder=accelerator.unwrap_model(text_encoder), 1424 | unet=accelerator.unwrap_model(unet), 1425 | revision=args.revision, 1426 | variant=args.variant, 1427 | torch_dtype=weight_dtype, 1428 | vae=accelerator.unwrap_model(vae), 1429 | safety_checker=None, 1430 | ) 1431 | pipeline = pipeline.to(accelerator.device) 1432 | pipeline.set_progress_bar_config(disable=True) 1433 | if args.enable_xformers_memory_efficient_attention: 1434 | pipeline.enable_xformers_memory_efficient_attention() 1435 | 1436 | with open(args.validation_prompts_path,'r',encoding='utf-8') as f: 1437 | prompts = json.load(f)[:25] 1438 | 1439 | score = get_score(pipeline, clip_model, clip_processor, prompts, args.validation_batch_size, f'{args.validation_output_dir}/{global_step}') 1440 | print(f'{global_step}: {score}') 1441 | 1442 | logs = {"loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0]} 1443 | progress_bar.set_postfix(**logs) 1444 | accelerator.log(logs, step=global_step) 1445 | 1446 | if global_step >= args.max_train_steps: 1447 | break 1448 | 1449 | # Create the pipeline using the trained modules and save it. 1450 | accelerator.wait_for_everyone() 1451 | if accelerator.is_main_process: 1452 | pipeline_args = {} 1453 | 1454 | if text_encoder is not None: 1455 | pipeline_args["text_encoder"] = accelerator.unwrap_model(text_encoder) 1456 | 1457 | if args.skip_save_text_encoder: 1458 | pipeline_args["text_encoder"] = None 1459 | 1460 | pipeline = DiffusionPipeline.from_pretrained( 1461 | args.pretrained_model_name_or_path, 1462 | unet=accelerator.unwrap_model(unet), 1463 | revision=args.revision, 1464 | variant=args.variant, 1465 | **pipeline_args, 1466 | ) 1467 | 1468 | # We train on the simplified learning objective. If we were previously predicting a variance, we need the scheduler to ignore it 1469 | scheduler_args = {} 1470 | 1471 | if "variance_type" in pipeline.scheduler.config: 1472 | variance_type = pipeline.scheduler.config.variance_type 1473 | 1474 | if variance_type in ["learned", "learned_range"]: 1475 | variance_type = "fixed_small" 1476 | 1477 | scheduler_args["variance_type"] = variance_type 1478 | 1479 | pipeline.scheduler = pipeline.scheduler.from_config(pipeline.scheduler.config, **scheduler_args) 1480 | 1481 | pipeline.save_pretrained(args.output_dir) 1482 | 1483 | if args.push_to_hub: 1484 | save_model_card( 1485 | repo_id, 1486 | images=images, 1487 | base_model=args.pretrained_model_name_or_path, 1488 | train_text_encoder=args.train_text_encoder, 1489 | prompt=args.instance_prompt, 1490 | repo_folder=args.output_dir, 1491 | pipeline=pipeline, 1492 | ) 1493 | upload_folder( 1494 | repo_id=repo_id, 1495 | folder_path=args.output_dir, 1496 | commit_message="End of training", 1497 | ignore_patterns=["step_*", "epoch_*"], 1498 | ) 1499 | 1500 | accelerator.end_training() 1501 | 1502 | 1503 | if __name__ == "__main__": 1504 | args = parse_args() 1505 | main(args) 1506 | --------------------------------------------------------------------------------