├── LICENSE
├── code
│   ├── Visual Text Sampling
│   │   ├── text2img_sd.py
│   │   ├── text2img_sdxl.py
│   │   ├── sample_prompts_gpt.py
│   │   └── sample_prompts_llama.py
│   ├── others
│   │   ├── blip_caption.py
│   │   ├── train_text_to_image.py
│   │   ├── train_text_to_image_sdxl.py
│   │   └── train_dreambooth.py
│   ├── VLEU Calculation
│   │   ├── cal_vleu_openclip.py
│   │   └── cal_vleu_clip.py
│   └── Text-Image Scoring
│       └── clip_score.py
├── README.md
├── .gitignore
└── data
    └── Constrained Subject
        └── teddy_bear.json
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2024 nameless
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/code/Visual Text Sampling/text2img_sd.py:
--------------------------------------------------------------------------------
1 | from diffusers import StableDiffusionPipeline
2 | import argparse
3 | import torch
4 | import json
5 | import os
6 |
7 |
8 | if __name__ == '__main__':
9 | parser = argparse.ArgumentParser()
10 | parser.add_argument('--model', type=str)
11 | parser.add_argument('--prompt', type=str)
12 | parser.add_argument('--prompt_json_path', type=str)
13 | parser.add_argument('--output_dir', type=str)
14 | parser.add_argument('--num', type=int)
15 | parser.add_argument('--batch_size', type=int, default=1)
16 | args = parser.parse_args()
17 |
18 | pipe = StableDiffusionPipeline.from_pretrained(args.model, torch_dtype=torch.float16, safety_checker=None)
19 | pipe = pipe.to("cuda")
20 |
21 | if args.prompt_json_path:
22 | prompts = json.load(open(args.prompt_json_path,'r',encoding='utf-8'))
23 |
24 | os.makedirs(args.output_dir, exist_ok=True)
25 |
26 | for i in range(0, args.num, args.batch_size):
27 | size = min(args.batch_size, args.num - i)
28 | if args.prompt_json_path:
29 | images = pipe(prompts[i:i+size]).images
30 | else:
31 | images = pipe([args.prompt] * size).images
32 | for j, image in enumerate(images):
33 | image.save(f'{args.output_dir}/{i+j}.jpg')
34 |
--------------------------------------------------------------------------------
/code/Visual Text Sampling/text2img_sdxl.py:
--------------------------------------------------------------------------------
1 | from diffusers import DiffusionPipeline
2 | import argparse
3 | import torch
4 | import json, os
5 |
6 |
7 | if __name__ == '__main__':
8 | parser = argparse.ArgumentParser()
9 | parser.add_argument('--model', type=str)
10 | parser.add_argument('--prompt', type=str)
11 | parser.add_argument('--prompt_json_path', type=str)
12 | parser.add_argument('--output_dir', type=str)
13 | parser.add_argument('--start', type=int, default=0)
14 | parser.add_argument('--num', type=int)
15 | parser.add_argument('--batch_size', type=int, default=1)
16 | args = parser.parse_args()
17 |
18 | pipe = DiffusionPipeline.from_pretrained(args.model, torch_dtype=torch.float16, safety_checker=None)
19 | pipe = pipe.to("cuda")
20 |
21 | if args.prompt_json_path:
22 | prompts = json.load(open(args.prompt_json_path,'r',encoding='utf-8'))
23 |
24 | os.makedirs(args.output_dir, exist_ok=True)
25 |
26 | end = args.start + args.num
27 | for i in range(args.start, end, args.batch_size):
28 | size = min(args.batch_size, end - i)
29 | if args.prompt_json_path:
30 | images = pipe(prompts[i:i+size]).images
31 | else:
32 | images = pipe([args.prompt] * size).images
33 | for j, image in enumerate(images):
34 | image.save(f'{args.output_dir}/{i+j}.jpg')
35 |
--------------------------------------------------------------------------------
/code/others/blip_caption.py:
--------------------------------------------------------------------------------
1 | import os
2 | import re
3 | import random
4 | import fnmatch
5 | import argparse
6 | import torch
7 | from PIL import Image, ImageOps
8 | from transformers import BlipProcessor, BlipForConditionalGeneration
9 | import threading
10 |
11 |
12 | def blip_caption(model_path, image_paths, device):
13 | # Load the model and processor
14 | processor = BlipProcessor.from_pretrained(model_path)
15 | model = BlipForConditionalGeneration.from_pretrained(model_path, torch_dtype=torch.float16).to(device)
16 |
17 | for img_path in image_paths:
18 | if os.path.exists(img_path + '.txt'):
19 | continue
20 | # Open and process the image
21 | raw_image = ImageOps.exif_transpose(Image.open(img_path).convert('RGB'))
22 |
23 | # Conditional image captioning
24 | inputs = processor(raw_image, text='this is', return_tensors="pt").to(device, torch.float16)
25 |
26 | out = model.generate(**inputs)
27 | caption = processor.decode(out[0], skip_special_tokens=True)
28 |
29 | caption = re.sub(r'^this is\s*','',caption).capitalize()
30 | print(caption)
31 |
32 | # Write the caption to a .txt file
33 | with open(img_path + '.txt', 'w') as f:
34 | f.write(caption)
35 |
36 | def random_partition(input_list, num_partitions):
37 | random.shuffle(input_list)
38 | list_length = len(input_list)
39 | partition_size = list_length // num_partitions
40 | random_partitioned_lists = [input_list[i * partition_size: (i + 1) * partition_size] for i in range(num_partitions)]
41 | return random_partitioned_lists
42 |
43 | def find_all_images(root_dir):
44 | img_files = []
45 | for foldername, subfolders, filenames in os.walk(root_dir):
46 | for extension in ['*.jpg', '*.jpeg', '*.png']:
47 | for filename in fnmatch.filter(filenames, extension):
48 | file_path = os.path.join(foldername, filename)
49 | img_files.append(file_path)
50 | return img_files
51 |
52 |
53 | if __name__ == '__main__':
54 | parser = argparse.ArgumentParser()
55 | parser.add_argument('model_path', type=str, nargs='?', default='Salesforce/blip-image-captioning-large')
56 | parser.add_argument('data_dir', type=str)
57 | args = parser.parse_args()
58 |
59 | image_paths = find_all_images(args.data_dir)
60 | threads = []
61 | random.shuffle(image_paths)
62 | threads.append(threading.Thread(target=blip_caption, args=(args.model_path, image_paths, 'cuda')))
63 |
64 | for thread in threads:
65 | thread.start()
66 | for thread in threads:
67 | thread.join()
68 |
--------------------------------------------------------------------------------
/code/VLEU Calculation/cal_vleu_openclip.py:
--------------------------------------------------------------------------------
1 | import os
2 | import json
3 | import argparse
4 | import torch
5 | import torch.nn.functional as F
6 | import numpy as np
7 | from PIL import Image
8 | from tqdm import tqdm
9 | import open_clip
10 |
11 |
12 | device = 'cuda'
13 |
14 | def calculate_vleu(image_dir, prompts, temperature=0.01):
15 | with torch.no_grad():
16 | text_embs = []
17 | for prompt in tqdm(prompts):
18 | inputs = tokenizer([prompt]).to(device)
19 | outputs = model.encode_text(inputs)
20 | outputs /= outputs.norm(dim=-1, keepdim=True)
21 | text_embs.append(outputs)
22 |
23 | img_embs = []
24 | img_paths = [os.path.join(image_dir, f) for f in os.listdir(image_dir)]
25 | for img_path in tqdm(img_paths):
26 | image = Image.open(img_path)
27 | inputs = preprocess(image).unsqueeze(0).to(device)
28 | outputs = model.encode_image(inputs)
29 | outputs /= outputs.norm(dim=-1, keepdim=True)
30 | img_embs.append(outputs)
31 |
32 | prob_matrix = []
33 | for i in range(len(img_embs)):
34 | cosine_sim = []
35 | for j in range(len(text_embs)):
36 | cosine_sim.append(img_embs[i] @ text_embs[j].T)
37 | prob = F.softmax(torch.tensor(cosine_sim) / temperature, dim=0)
38 | prob_matrix.append(prob)
39 |
40 | prob_matrix = torch.stack(prob_matrix)
41 |
42 | # marginal distribution for text embeddings
43 | text_emb_marginal_distribution = prob_matrix.sum(axis=0) / prob_matrix.shape[0]
44 |
45 | # KL divergence for each image
46 | image_kl_divergences = []
47 | for i in range(prob_matrix.shape[0]):
48 | kl_divergence = (prob_matrix[i, :] * torch.log(prob_matrix[i, :] / text_emb_marginal_distribution)).sum().item()
49 | image_kl_divergences.append(kl_divergence)
50 |
51 | vleu_score = np.exp(sum(image_kl_divergences) / prob_matrix.shape[0])
52 | return vleu_score
53 |
54 |
55 | if __name__ == '__main__':
56 | parser = argparse.ArgumentParser()
57 | parser.add_argument('--model', type=str)
58 | parser.add_argument('--pretrained', type=str)
59 | parser.add_argument('--prompt_json_path', type=str)
60 | parser.add_argument('--image_dir', type=str)
61 | args = parser.parse_args()
62 |
63 | model, _, preprocess = open_clip.create_model_and_transforms(args.model, pretrained=args.pretrained)
64 | tokenizer = open_clip.get_tokenizer(args.model)
65 | model = model.to(device)
66 |
67 | with open(args.prompt_json_path, 'r', encoding='utf-8') as f:
68 | prompts = json.load(f)
69 |
70 | score = calculate_vleu(args.image_dir, prompts, 0.01)
71 | print(score)
72 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # VLEU: a Method for Automatic Evaluation for Generalizability of Text-to-Image Models
2 |
3 | This repository contains the code and data for the paper "VLEU: a Method for Automatic Evaluation for Generalizability of Text-to-Image Models" presented at EMNLP 2024.
4 |
5 | ## Files
6 |
7 | - `code/`: Contains the Python scripts used for the project.
8 | - `data/`: Contains the datasets used for training and evaluation.
9 |
10 | ## Quick Start
11 |
12 | ### Sampling Prompts
13 |
14 | #### 1. GPT
15 |
16 | `code/Visual Text Sampling/sample_prompts_gpt.py`
17 | ```sh
18 | python sample_prompts_gpt.py \
19 | --api_key YOUR_KEY \
20 | --n_prompts 1000 \
21 | --output prompts.json \
22 | --n_threads 50 \
23 | --step 30 \
24 | --key_word dog
25 | ```
26 |
27 | #### 2. LLaMA
28 |
29 | `code/Visual Text Sampling/sample_prompts_llama.py`
30 | ```sh
31 | python sample_prompts_llama.py \
32 | --model meta-llama/Llama-2-13b-chat-hf \
33 | --n_prompts 1000 \
34 | --output prompts.json \
35 | --num_return_sequences 2
36 | ```
37 |
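Both sampling scripts write the sampled visual-text prompts to `--output` as a flat JSON list of strings, the same layout as `data/Constrained Subject/teddy_bear.json`, and the generation and scoring scripts read that file back with `json.load`. A minimal sketch of the expected file format (the two sentences below are hypothetical examples, not part of the released data):

```python
import json

# Hypothetical prompts, only to illustrate the file layout expected by
# text2img_sd.py / text2img_sdxl.py / cal_vleu_*.py.
prompts = [
    "A golden retriever chasing a frisbee across a sunlit park.",
    "A dog curled up asleep on a red armchair by a rainy window.",
]

with open("prompts.json", "w", encoding="utf-8") as f:
    json.dump(prompts, f)
```
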
38 | ### T2I Generation
39 |
40 | #### 1. Stable Diffusion
41 |
42 | `code/Visual Text Sampling/text2img_sd.py`
43 | ```sh
44 | python text2img_sd.py \
45 | --model stabilityai/stable-diffusion-2-1 \
46 | --prompt_json_path prompts.json \
47 | --output_dir image_output \
48 | --num 1000 \
49 | --batch_size 4
50 | ```
51 |
52 | #### 2. Stable Diffusion XL
53 |
54 | `code/Visual Text Sampling/text2img_sdxl.py`
55 | ```sh
56 | python text2img_sdxl.py \
57 | --model stabilityai/stable-diffusion-xl-base-1.0 \
58 | --prompt_json_path prompts.json \
59 | --output_dir image_output \
60 | --num 1000 \
61 | --batch_size 4
62 | ```
63 |
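Unlike `text2img_sd.py`, `text2img_sdxl.py` also accepts a `--start` index (default `0`) and renders prompts `start` through `start + num - 1`, saving each image as `<prompt index>.jpg`, so a long run can be split into shards. For example, a second shard covering the last 500 of 1000 prompts might look like this:

```sh
python text2img_sdxl.py \
    --model stabilityai/stable-diffusion-xl-base-1.0 \
    --prompt_json_path prompts.json \
    --output_dir image_output \
    --start 500 \
    --num 500 \
    --batch_size 4
```
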
64 | ### VLEU Calculation
65 |
66 | #### 1. CLIP
67 |
68 | `code/VLEU Calculation/cal_vleu_clip.py`
69 | ```sh
70 | python cal_vleu_clip.py \
71 | --model openai/clip-vit-base-patch16 \
72 | --prompt_json_path prompts.json \
73 | --image_dir image_output
74 | ```
75 |
76 | #### 2. OpenCLIP
77 |
78 | `code/VLEU Calculation/cal_vleu_openclip.py`
79 | ```sh
80 | python cal_vleu_openclip.py \
81 | --model ViT-L-14 \
82 | --pretrained open_clip_pytorch_model.bin \
83 | --prompt_json_path prompts.json \
84 | --image_dir image_output
85 | ```
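
Both `cal_vleu_clip.py` and `cal_vleu_openclip.py` compute the same quantity: every generated image is compared against every sampled prompt with CLIP cosine similarity, a softmax over the prompts (temperature 0.01) turns each image's similarities into a probability distribution, and VLEU is the exponential of the mean KL divergence between these per-image distributions and their marginal over all images. A condensed sketch of that computation, assuming `img_embs` (N_images x D) and `text_embs` (N_prompts x D) are already L2-normalized CLIP embeddings:

```python
import torch
import torch.nn.functional as F

def vleu(img_embs: torch.Tensor, text_embs: torch.Tensor, temperature: float = 0.01) -> float:
    sims = img_embs @ text_embs.T                           # cosine similarities, (N_images, N_prompts)
    probs = F.softmax(sims / temperature, dim=1)            # per-image distribution over prompts
    marginal = probs.mean(dim=0)                            # marginal prompt distribution over all images
    kl = (probs * torch.log(probs / marginal)).sum(dim=1)   # KL(per-image || marginal), one value per image
    return torch.exp(kl.mean()).item()                      # VLEU = exp(mean KL)
```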
86 |
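### Text-Image Scoring

`code/Text-Image Scoring/clip_score.py` combines the two steps above: it generates images from the sampled prompts with a Stable Diffusion (or, with `--sdxl`, SDXL) checkpoint and computes the same CLIP-based score in a single pass, optionally saving the images to `--image_dir`. A possible invocation, with argument names taken from the script's argparse:

```sh
python clip_score.py \
    --clip_model openai/clip-vit-base-patch16 \
    --model stabilityai/stable-diffusion-2-1 \
    --prompt_json_path prompts.json \
    --image_dir image_output \
    --n_prompt 200 \
    --batch_size 4
```
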
87 | ## Citation
88 |
89 | If you use this code or data in your research, please cite our paper:
90 |
91 | ```bibtex
92 | @misc{cao2024vleumethodautomaticevaluation,
93 | title={VLEU: a Method for Automatic Evaluation for Generalizability of Text-to-Image Models},
94 | author={Jingtao Cao and Zheng Zhang and Hongru Wang and Kam-Fai Wong},
95 | year={2024},
96 | eprint={2409.14704},
97 | archivePrefix={arXiv},
98 | primaryClass={cs.CV},
99 | url={https://arxiv.org/abs/2409.14704},
100 | }
101 | ```
--------------------------------------------------------------------------------
/code/Visual Text Sampling/sample_prompts_gpt.py:
--------------------------------------------------------------------------------
1 | import os, json
2 | import argparse
3 | from threading import Thread
4 | from langchain.schema import HumanMessage, SystemMessage, AIMessage
5 | from langchain.chat_models import ChatOpenAI
6 |
7 |
8 | def get_prompts(num, key_word=None, property=None):
9 | for i in range(0, num, args.step):
10 | if key_word and property:
11 | system_input = f'Please imagine a picture of {key_word} and describe it in one sentence, making sure to include the word "{key_word}" and words about {property}.'
12 | elif key_word:
13 | system_input = f'Please imagine a picture of {key_word} and describe it in one sentence, making sure to include the word "{key_word}".'
14 | else:
15 | system_input = 'Please imagine a random picture and describe it in one sentence.'
16 | human_input = [
17 | SystemMessage(content=system_input)
18 | ]
19 | ai_output = llm(human_input)
20 | n = 0
21 | limit = min(args.step, num-i)
22 | while True:
23 | human_input.append(AIMessage(content=ai_output.content))
24 | human_input.append(HumanMessage(content='Again'))
25 | ai_output = llm(human_input)
26 | while key_word and key_word not in ai_output.content:
27 | ai_output = llm(human_input)
28 | prompts.append(ai_output.content)
29 | print(ai_output.content)
30 | n += 1
31 | if n >= limit:
32 | break
33 |
34 |
35 | if __name__ == '__main__':
36 | parser = argparse.ArgumentParser()
37 | parser.add_argument('--api_key', type=str)
38 | parser.add_argument('--api_base', type=str, default=None)
39 | parser.add_argument('--n_prompts', type=int)
40 | parser.add_argument('--key_word', type=str, default=None)
41 | parser.add_argument('--output', type=str)
42 | parser.add_argument('--model_name', type=str, default='gpt-3.5-turbo')
43 | parser.add_argument('--temperature', type=float, default=0.3)
44 | parser.add_argument('--n_threads', type=int, default=1)
45 | parser.add_argument('--step', type=int, default=50)
46 | args = parser.parse_args()
47 | # global environment
48 | os.environ['OPENAI_API_KEY'] = args.api_key
49 | if args.api_base:
50 | os.environ['OPENAI_API_BASE'] = args.api_base
51 | # llm initialization
52 | llm = ChatOpenAI(model_name=args.model_name, temperature=args.temperature)
53 | # get prompts
54 | prompts = []
55 | threads = []
56 | step = args.n_prompts // args.n_threads
57 | for i in range(0, args.n_prompts, step):
58 | num = min(step, args.n_prompts-i)
59 | threads.append(Thread(target=get_prompts, args=(num, args.key_word)))
60 | for t in threads:
61 | t.start()
62 | for t in threads:
63 | t.join()
64 | with open(args.output,'w',encoding='utf-8') as f:
65 | json.dump(prompts, f)
66 |
--------------------------------------------------------------------------------
/code/Visual Text Sampling/sample_prompts_llama.py:
--------------------------------------------------------------------------------
1 | import json
2 | import argparse
3 | import torch
4 | import transformers
5 | from transformers import AutoTokenizer
6 | from tqdm import tqdm
7 |
8 |
9 | def get_prompt(prompt, num_return_sequences):
10 | prompt_len = len(prompt)
11 | sequences = pipeline(
12 | prompt,
13 | do_sample=True,
14 | temperature=0.3,
15 | top_k=10,
16 | num_return_sequences=num_return_sequences,
17 | eos_token_id=tokenizer.eos_token_id,
18 | max_length=2048
19 | )
20 | responses = [seq['generated_text'][prompt_len:].strip() for seq in sequences]
21 | return responses
22 |
23 |
24 | if __name__ == '__main__':
25 | parser = argparse.ArgumentParser()
26 | parser.add_argument('--model', type=str)
27 | parser.add_argument('--n_prompts', type=int)
28 | parser.add_argument('--key_word', type=str, default=None)
29 | parser.add_argument('--output', type=str)
30 | parser.add_argument('--num_return_sequences', type=int, default=3)
31 | args = parser.parse_args()
32 |
33 | tokenizer = AutoTokenizer.from_pretrained(args.model)
34 | pipeline = transformers.pipeline(
35 | "text-generation",
36 | model=args.model,
37 | torch_dtype=torch.bfloat16,
38 | device_map="auto",
39 | )
40 |
41 | system_prompt = ''
42 |
43 | if args.key_word:
44 | prompt = f'[INST] <<SYS>>\n{system_prompt}\n<</SYS>>\n\nPlease imagine a random picture and describe it in one sentence. [/INST] A lone wolf stands proudly atop a snow-covered mountain peak, its piercing gaze reflecting both strength and solitude. [INST] Again [/INST]'
45 | response = 'A lone tree stands tall in a vast, snowy landscape, its branches adorned with delicate icicles glistening in the winter sunlight.'
46 | else:
47 | prompt = f'[INST] <<SYS>>\n{system_prompt}\n<</SYS>>\n\nPlease imagine a random picture and describe it in one sentence. [/INST]'
48 | response = 'A lone tree stands tall in a vast, snowy landscape, its branches adorned with delicate icicles glistening in the winter sunlight.'
49 |
50 | prompts = []
51 | prompt_stack = []
52 |
53 | user_msg = 'Again'
54 | prompt += f' {response} [INST] {user_msg} [/INST]'
55 | prompt_stack.append(prompt)
56 |
57 | n_pad_turn = 3
58 | n_pad = args.num_return_sequences ** (n_pad_turn + 1) - args.num_return_sequences
59 | n_prompts = args.n_prompts + n_pad
60 |
61 | progress_bar = tqdm(total=n_prompts)
62 | while len(prompts) < n_prompts:
63 | prompt = prompt_stack.pop(0)
64 | responses = get_prompt(prompt, args.num_return_sequences)
65 | prompts.extend(responses)
66 | prompt_stack.extend([f'{prompt} {resp} [INST] {user_msg} [/INST]' for resp in responses])
67 | progress_bar.update(len(responses))
68 |
69 | with open(args.output, 'w', encoding='utf-8') as f:
70 | json.dump(prompts[n_pad:n_prompts], f)
71 |
--------------------------------------------------------------------------------
/code/VLEU Calculation/cal_vleu_clip.py:
--------------------------------------------------------------------------------
1 | import os
2 | import json
3 | import argparse
4 | import torch
5 | import torch.nn.functional as F
6 | import numpy as np
7 | from PIL import Image
8 | from tqdm import tqdm
9 | from transformers import CLIPProcessor, CLIPModel
10 |
11 |
12 | device = 'cuda'
13 |
14 | def calculate_vleu(image_dir, prompts, temperature=0.01):
15 | with torch.no_grad():
16 | text_embs = []
17 | for prompt in tqdm(prompts):
18 | inputs = processor([prompt], return_tensors='pt', truncation=True)
19 | inputs['input_ids'] = inputs['input_ids'].to(device)
20 | inputs['attention_mask'] = inputs['attention_mask'].to(device)
21 | outputs = model.get_text_features(**inputs)
22 | outputs /= outputs.norm(dim=-1, keepdim=True)
23 | text_embs.append(outputs)
24 |
25 | img_embs = []
26 | img_paths = [os.path.join(image_dir, f) for f in os.listdir(image_dir)]
27 | for img_path in tqdm(img_paths):
28 | image = Image.open(img_path)
29 | inputs = processor(images=image, return_tensors='pt')
30 | inputs['pixel_values'] = inputs['pixel_values'].to(device)
31 | outputs = model.get_image_features(**inputs)
32 | outputs /= outputs.norm(dim=-1, keepdim=True)
33 | img_embs.append(outputs)
34 |
35 | prob_matrix = []
36 | for i in range(len(img_embs)):
37 | cosine_sim = []
38 | for j in range(len(text_embs)):
39 | cosine_sim.append(img_embs[i] @ text_embs[j].T)
40 | prob = F.softmax(torch.tensor(cosine_sim) / temperature, dim=0)
41 | prob_matrix.append(prob)
42 |
43 | prob_matrix = torch.stack(prob_matrix)
44 |
45 | # marginal distribution for text embeddings
46 | text_emb_marginal_distribution = prob_matrix.sum(axis=0) / prob_matrix.shape[0]
47 |
48 | # KL divergence for each image
49 | image_kl_divergences = []
50 | for i in range(prob_matrix.shape[0]):
51 | kl_divergence = (prob_matrix[i, :] * torch.log(prob_matrix[i, :] / text_emb_marginal_distribution)).sum().item()
52 | image_kl_divergences.append(kl_divergence)
53 |
54 | vleu_score = np.exp(sum(image_kl_divergences) / prob_matrix.shape[0])
55 | return vleu_score
56 |
57 |
58 | if __name__ == '__main__':
59 | parser = argparse.ArgumentParser()
60 | parser.add_argument('--model', type=str)
61 | parser.add_argument('--prompt_json_path', type=str)
62 | parser.add_argument('--image_dir', type=str)
63 | args = parser.parse_args()
64 |
65 | model = CLIPModel.from_pretrained(args.model, local_files_only=True).to(device)
66 | processor = CLIPProcessor.from_pretrained(args.model, local_files_only=True)
67 |
68 | with open(args.prompt_json_path, 'r', encoding='utf-8') as f:
69 | prompts = json.load(f)
70 |
71 | score = calculate_vleu(args.image_dir, prompts, 0.01)
72 | print(score)
73 |
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 |
6 | # C extensions
7 | *.so
8 |
9 | # Distribution / packaging
10 | .Python
11 | build/
12 | develop-eggs/
13 | dist/
14 | downloads/
15 | eggs/
16 | .eggs/
17 | lib/
18 | lib64/
19 | parts/
20 | sdist/
21 | var/
22 | wheels/
23 | share/python-wheels/
24 | *.egg-info/
25 | .installed.cfg
26 | *.egg
27 | MANIFEST
28 |
29 | # PyInstaller
30 | # Usually these files are written by a python script from a template
31 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
32 | *.manifest
33 | *.spec
34 |
35 | # Installer logs
36 | pip-log.txt
37 | pip-delete-this-directory.txt
38 |
39 | # Unit test / coverage reports
40 | htmlcov/
41 | .tox/
42 | .nox/
43 | .coverage
44 | .coverage.*
45 | .cache
46 | nosetests.xml
47 | coverage.xml
48 | *.cover
49 | *.py,cover
50 | .hypothesis/
51 | .pytest_cache/
52 | cover/
53 |
54 | # Translations
55 | *.mo
56 | *.pot
57 |
58 | # Django stuff:
59 | *.log
60 | local_settings.py
61 | db.sqlite3
62 | db.sqlite3-journal
63 |
64 | # Flask stuff:
65 | instance/
66 | .webassets-cache
67 |
68 | # Scrapy stuff:
69 | .scrapy
70 |
71 | # Sphinx documentation
72 | docs/_build/
73 |
74 | # PyBuilder
75 | .pybuilder/
76 | target/
77 |
78 | # Jupyter Notebook
79 | .ipynb_checkpoints
80 |
81 | # IPython
82 | profile_default/
83 | ipython_config.py
84 |
85 | # pyenv
86 | # For a library or package, you might want to ignore these files since the code is
87 | # intended to run in multiple environments; otherwise, check them in:
88 | # .python-version
89 |
90 | # pipenv
91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
94 | # install all needed dependencies.
95 | #Pipfile.lock
96 |
97 | # poetry
98 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
99 | # This is especially recommended for binary packages to ensure reproducibility, and is more
100 | # commonly ignored for libraries.
101 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
102 | #poetry.lock
103 |
104 | # pdm
105 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
106 | #pdm.lock
107 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
108 | # in version control.
109 | # https://pdm.fming.dev/latest/usage/project/#working-with-version-control
110 | .pdm.toml
111 | .pdm-python
112 | .pdm-build/
113 |
114 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
115 | __pypackages__/
116 |
117 | # Celery stuff
118 | celerybeat-schedule
119 | celerybeat.pid
120 |
121 | # SageMath parsed files
122 | *.sage.py
123 |
124 | # Environments
125 | .env
126 | .venv
127 | env/
128 | venv/
129 | ENV/
130 | env.bak/
131 | venv.bak/
132 |
133 | # Spyder project settings
134 | .spyderproject
135 | .spyproject
136 |
137 | # Rope project settings
138 | .ropeproject
139 |
140 | # mkdocs documentation
141 | /site
142 |
143 | # mypy
144 | .mypy_cache/
145 | .dmypy.json
146 | dmypy.json
147 |
148 | # Pyre type checker
149 | .pyre/
150 |
151 | # pytype static type analyzer
152 | .pytype/
153 |
154 | # Cython debug symbols
155 | cython_debug/
156 |
157 | # PyCharm
158 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can
159 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
160 | # and can be added to the global gitignore or merged into this file. For a more nuclear
161 | # option (not recommended) you can uncomment the following to ignore the entire idea folder.
162 | #.idea/
163 |
--------------------------------------------------------------------------------
/code/Text-Image Scoring/clip_score.py:
--------------------------------------------------------------------------------
1 | from diffusers import StableDiffusionPipeline, DiffusionPipeline
2 | from transformers import CLIPProcessor, CLIPModel
3 | import argparse
4 | import torch
5 | import torch.nn.functional as F
6 | import numpy as np
7 | import json
8 | import os
9 |
10 |
11 | def get_score(pipe, model, processor, prompts, batch_size=1, output_dir=None, temperature=0.01, device='cuda'):
12 | if output_dir:
13 | os.makedirs(output_dir, exist_ok=True)
14 |
15 | model = model.to(device)
16 |
17 | with torch.no_grad():
18 | text_embs = []
19 | img_embs = []
20 | for prompt in prompts:
21 | inputs = processor([prompt], return_tensors='pt', truncation=True)
22 | inputs['input_ids'] = inputs['input_ids'].to(device)
23 | inputs['attention_mask'] = inputs['attention_mask'].to(device)
24 | outputs = model.get_text_features(**inputs)
25 | outputs /= outputs.norm(dim=-1, keepdim=True)
26 | text_embs.append(outputs)
27 |
28 |
29 | for i in range(0, len(prompts), batch_size):
30 | with torch.autocast("cuda"):
31 | images = pipe(prompts[i:i+batch_size]).images
32 | for j, image in enumerate(images):
33 | if output_dir:
34 | image.save(f'{output_dir}/{i+j}.jpg')
35 | inputs = processor(images=image, return_tensors='pt')
36 | inputs['pixel_values'] = inputs['pixel_values'].to(device)
37 | outputs = model.get_image_features(**inputs)
38 | outputs /= outputs.norm(dim=-1, keepdim=True)
39 | img_embs.append(outputs)
40 |
41 | prob_matrix = []
42 | for i in range(len(img_embs)):
43 | cosine_sim = []
44 | for j in range(len(text_embs)):
45 | cosine_sim.append(img_embs[i] @ text_embs[j].T)
46 | prob = F.softmax(torch.tensor(cosine_sim) / temperature, dim=0)
47 | prob_matrix.append(prob)
48 |
49 | prob_matrix = torch.stack(prob_matrix)
50 |
51 | # marginal distribution for text embeddings
52 | text_emb_marginal_distribution = prob_matrix.sum(axis=0) / prob_matrix.shape[0]
53 |
54 | # KL divergence for each image
55 | image_kl_divergences = []
56 | for i in range(prob_matrix.shape[0]):
57 | kl_divergence = (prob_matrix[i, :] * torch.log(prob_matrix[i, :] / text_emb_marginal_distribution)).sum().item()
58 | image_kl_divergences.append(kl_divergence)
59 |
60 | vleu_score = np.exp(sum(image_kl_divergences) / prob_matrix.shape[0])
61 |
62 | model = model.to('cpu')
63 |
64 | return vleu_score
65 |
66 |
67 | if __name__ == '__main__':
68 | parser = argparse.ArgumentParser()
69 | parser.add_argument('--clip_model', type=str)
70 | parser.add_argument('--model', type=str)
71 | parser.add_argument('--prompt_json_path', type=str)
72 | parser.add_argument('--image_dir', type=str)
73 | parser.add_argument('--batch_size', type=int, default=1)
74 | parser.add_argument('--n_prompt', type=int)
75 | parser.add_argument('--start', type=int, default=0)
76 | parser.add_argument('--sdxl', default=False, action='store_true')
77 | args = parser.parse_args()
78 |
79 | if args.sdxl:
80 | pipe = DiffusionPipeline.from_pretrained(args.model, torch_dtype=torch.float16, safety_checker=None)
81 | else:
82 | pipe = StableDiffusionPipeline.from_pretrained(args.model, torch_dtype=torch.float16, safety_checker=None)
83 | pipe = pipe.to("cuda")
84 |
85 | model = CLIPModel.from_pretrained(args.clip_model)
86 | processor = CLIPProcessor.from_pretrained(args.clip_model)
87 |
88 | with open(args.prompt_json_path,'r',encoding='utf-8') as f:
89 | prompts = json.load(f)
90 |
91 | final_score = get_score(pipe, model, processor, prompts[args.start:args.start+args.n_prompt], args.batch_size, args.image_dir)
92 | print(f'Score: {final_score}')
93 |
--------------------------------------------------------------------------------
/data/Constrained Subject/teddy_bear.json:
--------------------------------------------------------------------------------
1 | ["A fluffy, cream-colored teddy bear with a playful expression, wearing a polka dot bowtie.", "Sitting on a shelf, a vintage teddy bear with worn fur and a stitched smile exudes nostalgia and warmth.", "A small, fluffy teddy bear with soft, beige fur, wearing a tiny red bow around its neck, waiting to bring comfort and joy.", "Resting against a pile of colorful pillows, a worn and well-loved teddy bear with a faded bowtie and a stitched smile exudes warmth and nostalgia.", "Sitting on a shelf, a small, worn teddy bear with patches on its fur and a faded red bow around its neck exudes a timeless charm.", "A small, fluffy teddy bear with golden fur, adorned with a red bowtie, sitting on a bed, inviting endless snuggles.", "A small, fluffy teddy bear with golden fur, wearing a red bowtie and holding a heart-shaped pillow.", "A small, fluffy teddy bear with golden fur, wearing a polka dot bowtie and holding a heart-shaped pillow.", "A small, plush teddy bear with soft, golden fur, wearing a tiny blue bow around its neck, inviting warm hugs and endless snuggles.", "Resting against a pile of soft pillows, a well-loved, threadbare teddy bear with floppy arms and a worn-out bowtie exudes a timeless charm.", "In the picture, a small child holds a worn-out teddy bear tightly, finding comfort in its familiar embrace.", "A small, fluffy teddy bear with golden fur, wearing a polka dot bowtie and holding a heart-shaped pillow.", "A fluffy, cream-colored teddy bear with velvety paws, sparkling button eyes, and a sweet embroidered smile.", "A small, fluffy teddy bear with golden fur, adorned with a satin ribbon around its neck, sitting peacefully on a shelf.", "A small, fluffy teddy bear with soft, light brown fur, wearing a tiny red bowtie and sitting on a plaid cushion, ready to be loved and cherished.", "Resting against a pile of soft pillows, a huggable teddy bear with velvety brown fur, a shiny red bowtie, and a friendly expression invites endless cuddles and comforting companionship.", "A cuddly, brown teddy bear with button eyes and a soft, velvety texture eagerly awaits a loving embrace.", "A small, fluffy teddy bear with golden fur, wearing a red bowtie and holding a heart-shaped pillow.", "A cuddly, cream-colored teddy bear with a satin bowtie, patiently sitting on a shelf, ready to be cherished.", "A small, adorable teddy bear with light brown fur, button eyes, and a tiny red bowtie.", "Nestled among a pile of colorful toys, a worn and well-loved teddy bear with button eyes and a faded bow exudes comfort and nostalgia.", "A small, fluffy teddy bear with soft, golden fur, wearing a tiny blue bow around its neck, eagerly awaiting a warm embrace.", "A small, fluffy teddy bear with golden fur, wearing a dainty bowtie, sitting on a shelf surrounded by other cherished toys.", "A small, fluffy teddy bear with light brown fur, adorable button eyes, and a sweet smile, eagerly awaiting a loving hug.", "A small, fluffy teddy bear with soft, caramel-colored fur, wearing a dainty bowtie and sitting on a vintage rocking chair.", "Resting against a pile of soft pillows, a huggable teddy bear with velvety brown fur and a playful bowtie eagerly awaits its next adventure.", "A small, fluffy teddy bear with soft, golden fur, wearing a tiny blue bow around its neck, inviting endless hugs and snuggles.", "In this adorable picture, a small child clutches a well-loved teddy bear tightly, finding solace and companionship in its comforting presence.", "A fluffy, cream-colored teddy bear with velvety paws, sparkling 
black eyes, and a sweet, embroidered smile.", "Resting against a pile of soft pillows, a huggable teddy bear with chocolate-brown fur and a friendly smile eagerly awaited its next cuddle.", "A small, fluffy teddy bear with golden fur, wearing a dainty bowtie and clutching a tiny bouquet of flowers.", "Resting against a pile of colorful pillows, a huggable teddy bear with soft, caramel fur and a velvety nose eagerly awaits its next adventure.", "Resting against a pile of soft pillows, a huggable teddy bear with golden fur and a red bowtie sat, exuding comfort and companionship.", "Resting on a child's bed, a worn-out teddy bear with faded fur and a patchwork heart stitched on its chest brings comfort and nostalgia.", "Resting against a pile of pastel pillows, a vintage teddy bear with worn fur and a faded bow exudes timeless charm and cherished memories.", "A fluffy, cream-colored teddy bear with button eyes, a stitched nose, and outstretched arms, ready for a warm embrace.", "A cuddly, blue teddy bear with a satin bow, patiently waiting for a warm hug.", "A small, fluffy teddy bear with golden fur, wearing a red bowtie and sitting on a pile of colorful cushions.", "Nestled on a shelf, a brand new teddy bear with fluffy brown fur and a satin bow eagerly awaits its first cuddle.", "A fluffy, cream-colored teddy bear with velvety paws, wearing a dainty pink bow around its neck, sitting on a pile of colorful, confetti-filled balloons.", "A small, plush teddy bear with soft, golden fur, hugging a tiny bouquet of colorful flowers.", "A small, adorable teddy bear with light brown fur, button eyes, and a tiny stitched nose, patiently waiting for a warm embrace.", "Resting against a stack of storybooks, a well-loved teddy bear with soft, chocolate-brown fur and a velvety nose exudes warmth and companionship.", "In the picture, a cuddly, light brown teddy bear with button eyes and a satin bowtie sits on a grassy meadow, surrounded by colorful wildflowers, exuding warmth and comfort.", "In a sunlit meadow, a worn and well-loved teddy bear, adorned with a faded bow and stitched smile, sits atop a picnic blanket, patiently waiting for a child's warm embrace.", "A fluffy, adorable teddy bear with a bowtie and a mischievous twinkle in its button eyes eagerly awaits a new friend to share endless hugs and adventures with.", "A small, fluffy teddy bear with soft, caramel-colored fur, wearing a tiny bowtie and eagerly waiting for a warm embrace.", "A huggable teddy bear with velvety, chocolate-brown fur, adorable button eyes, and a heart-shaped nose, inviting warm embraces and endless snuggles.", "Resting on a sunlit windowsill, a worn and well-loved teddy bear with matted fur and a faded bow sits patiently, a silent guardian of cherished childhood memories.", "A well-loved, worn-out teddy bear with faded brown fur, missing an eye, and patched up with colorful stitches.", "A fluffy, cream-colored teddy bear with velvety paws and a twinkling button nose, sitting patiently with arms outstretched for a warm embrace.", "A cuddly, light brown teddy bear with button eyes and a soft, velvety texture sits patiently on a child's bed, eagerly waiting for bedtime snuggles and sweet dreams.", "A vintage teddy bear with worn, caramel-colored fur, stitched paws, and a well-loved expression on its face.", "In the picture, a group of teddy bears of various colors and sizes gather for a tea party, their adorable expressions filled with joy and playfulness.", "A vintage, well-loved teddy bear with worn-out patches, stitched 
together with love, sitting proudly on a rocking chair, reminiscing about countless childhood adventures.", "Nestled among a sea of stuffed animals, a plush teddy bear with soft, chocolate-brown fur and a friendly expression invites endless hugs and companionship.", "Perched on a wooden rocking chair, a fluffy teddy bear with floppy ears and a satin ribbon around its neck patiently awaits its next adventure with a gleeful sparkle in its button eyes.", "A well-loved teddy bear, worn with time, featuring patches of different fabrics and a heart-shaped button on its chest, holding cherished memories within its tattered seams.", "A well-loved, worn-out teddy bear with matted fur, missing an eye, and a stitched-up arm, but still cherished for all the memories it holds.", "A small, fluffy teddy bear with soft, light brown fur, adorable button eyes, and a sweet, embroidered smile, sitting patiently with its paws outstretched for a warm embrace.", "A well-loved teddy bear with worn fur, missing an eye, and a stitched-up paw, exuding a timeless charm and cherished memories.", "A well-loved teddy bear with faded golden fur, missing an eye and a few stitches, proudly displaying the wear and tear of countless adventures and comforting embraces.", "Nestled among a sea of vibrant wildflowers, a small, velvety teddy bear with a bowtie and a twinkle in its eyes brings comfort and joy to all who lay eyes upon it.", "Perched on a child's bed, a plush, rosy-cheeked teddy bear with twinkling black eyes and velvety fur patiently waits for bedtime stories and gentle hugs.", "A small, adorable teddy bear with golden fur and a red bowtie sits on a shelf, patiently waiting to be cuddled and cherished.", "A well-loved teddy bear, with worn-out fur and a few missing stitches, but still radiating warmth and comfort.", "A well-loved teddy bear, worn from years of affection, with button eyes and a patchwork of colorful patches, holding onto cherished memories.", "A huggable teddy bear with velvety, chocolate-brown fur, holding a tiny bouquet of wildflowers in its paws.", "Perched on a sunlit windowsill, a plush teddy bear with velvety fur and a heart-shaped nose gazes out with a sense of warmth and companionship.", "A well-loved teddy bear with worn-out patches, a faded smile, and a heart-shaped patch on its chest that reads \"Teddy Bear Hugs Forever\".", "Perched on a shelf, a well-loved teddy bear with worn fur and a faded blue ribbon around its neck stood as a cherished reminder of childhood innocence.", "A small, adorable teddy bear with soft, golden fur, sitting on a bed with a mischievous twinkle in its button eyes.", "A vintage, well-loved teddy bear with faded brown fur, missing an eye, and patched with colorful fabric, sitting patiently on a worn-out armchair.", "A small, plush teddy bear with soft, golden fur, a tiny red bowtie, and a mischievous twinkle in its button eyes.", "A worn-out, well-loved teddy bear with patches sewn onto its faded brown fur, sitting proudly on a rocking chair, a testament to years of comforting hugs.", "Nestled among a sea of colorful toys, a cherished teddy bear with plush, golden fur and a heartwarming smile brings joy and comfort to its young owner.", "A small, plush teddy bear with soft, golden fur, wearing a tiny blue sweater and holding a miniature picnic basket.", "A vintage, well-loved teddy bear with worn-out fur, missing an eye, and a patched-up paw, carrying years of cherished memories.", "Perched on a child's bed, a cherished teddy bear with golden fur and a playful 
twinkle in its button eyes patiently awaited bedtime stories and sweet dreams.", "A well-loved teddy bear with worn-out fur, missing an eye, and patched up with colorful stitches, radiating a timeless charm.", "A small, adorable teddy bear with soft, golden fur, sitting on a bed with a mischievous twinkle in its button eyes.", "Sitting on a shelf, a small, plush teddy bear with a red bowtie and a mischievous twinkle in its button eyes eagerly awaits a new adventure.", "Sitting on a patchwork quilt, a vintage teddy bear with golden fur and a well-loved appearance exudes warmth and nostalgia.", "Perched on a shelf, a cherished teddy bear with golden fur, a stitched nose, and outstretched arms exudes timeless charm, ready to be a steadfast friend through all of life's adventures.", "Perched on a child's bed, a fluffy teddy bear with a mischievous twinkle in its button eyes eagerly awaits bedtime stories and snuggles.", "A well-loved, worn-out teddy bear with faded golden fur, missing one eye, and a patched-up paw, holding onto cherished memories and endless love.", "Perched on a sunlit windowsill, a small, rosy-cheeked teddy bear with velvety fur and a satin bow eagerly awaits its next adventure with a gleam of anticipation in its glassy eyes.", "Perched on a sunlit windowsill, a cherished teddy bear with golden fur and a satin bow around its neck gazes out at the world with a sense of timeless love and warmth.", "A small, adorable teddy bear with soft, golden fur, a tiny red bow around its neck, and a playful expression on its face.", "Perched on a windowsill, a charming teddy bear with golden fur and a satin bow around its neck gazes out at the world with a sense of wonder and innocence.", "A well-loved teddy bear, with worn-out patches and a slightly faded fur, holding a heart-shaped pillow that says \"I love you.\"", "In the picture, a small, adorable teddy bear with soft, caramel-colored fur and a heart-shaped nose sits on a child's bed, patiently waiting for bedtime cuddles.", "A small, fuzzy teddy bear with chocolate-brown fur, button eyes, and a stitched smile, sitting on a child's bed surrounded by a collection of other stuffed animals.", "A well-loved teddy bear, worn from years of snuggles, with button eyes and a stitched smile, radiating warmth and nostalgia.", "A small, well-loved teddy bear with faded fur and a slightly crooked smile rests peacefully on a worn-out armchair, embodying years of cherished memories and comforting hugs.", "A beautifully handcrafted teddy bear, adorned with a satin bow and a gentle expression, sits patiently on a shelf, ready to be cherished by its lucky owner.", "Perched atop a child's bed, a plush teddy bear with soft, chocolate-brown fur and a playful twinkle in its button eyes eagerly awaits bedtime cuddles.", "A huggable teddy bear, made of plush, caramel-colored fur, with a velvety nose, twinkling black eyes, and a gentle expression that radiates comfort and love.", "Perched on a child's bed, a well-loved teddy bear with worn fur and a heart-shaped patch on its paw exudes a timeless charm and holds cherished memories within its stitched embrace.", "A small, adorable teddy bear with soft, golden fur, wearing a tiny blue bowtie and clutching a miniature book in its paws."]
--------------------------------------------------------------------------------
/code/others/train_text_to_image.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | # coding=utf-8
3 | # Copyright 2023 The HuggingFace Inc. team. All rights reserved.
4 | #
5 | # Licensed under the Apache License, Version 2.0 (the "License");
6 | # you may not use this file except in compliance with the License.
7 | # You may obtain a copy of the License at
8 | #
9 | # http://www.apache.org/licenses/LICENSE-2.0
10 | #
11 | # Unless required by applicable law or agreed to in writing, software
12 | # distributed under the License is distributed on an "AS IS" BASIS,
13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 | # See the License for the specific language governing permissions and
15 | # limitations under the License.
16 | import argparse
17 | import logging
18 | import math
19 | import os
20 | import json
21 | import random
22 | import shutil
23 | from pathlib import Path
24 |
25 | import accelerate
26 | import datasets
27 | import numpy as np
28 | import torch
29 | import torch.nn.functional as F
30 | import torch.utils.checkpoint
31 | import transformers
32 | from accelerate import Accelerator
33 | from accelerate.logging import get_logger
34 | from accelerate.state import AcceleratorState
35 | from accelerate.utils import ProjectConfiguration, set_seed
36 | from datasets import load_dataset
37 | from huggingface_hub import create_repo, upload_folder
38 | from packaging import version
39 | from torchvision import transforms
40 | from tqdm.auto import tqdm
41 | from transformers import CLIPTextModel, CLIPTokenizer, CLIPProcessor, CLIPModel
42 | from transformers.utils import ContextManagers
43 |
44 | import diffusers
45 | from diffusers import AutoencoderKL, DDPMScheduler, StableDiffusionPipeline, UNet2DConditionModel
46 | from diffusers.optimization import get_scheduler
47 | from diffusers.training_utils import EMAModel, compute_snr
48 | from diffusers.utils import check_min_version, deprecate, is_wandb_available, make_image_grid
49 | from diffusers.utils.import_utils import is_xformers_available
50 |
51 | from clip_score import get_score
52 |
53 |
54 | if is_wandb_available():
55 | import wandb
56 |
57 |
58 | # Will error if the minimal version of diffusers is not installed. Remove at your own risks.
59 | check_min_version("0.25.0.dev0")
60 |
61 | logger = get_logger(__name__, log_level="INFO")
62 |
63 | DATASET_NAME_MAPPING = {
64 | "lambdalabs/pokemon-blip-captions": ("image", "text"),
65 | }
66 |
67 |
68 | def save_model_card(
69 | args,
70 | repo_id: str,
71 | images=None,
72 | repo_folder=None,
73 | ):
74 | img_str = ""
75 | if len(images) > 0:
76 | image_grid = make_image_grid(images, 1, len(args.validation_prompts))
77 | image_grid.save(os.path.join(repo_folder, "val_imgs_grid.png"))
78 | img_str += "\n"
79 |
80 | yaml = f"""
81 | ---
82 | license: creativeml-openrail-m
83 | base_model: {args.pretrained_model_name_or_path}
84 | datasets:
85 | - {args.dataset_name}
86 | tags:
87 | - stable-diffusion
88 | - stable-diffusion-diffusers
89 | - text-to-image
90 | - diffusers
91 | inference: true
92 | ---
93 | """
94 | model_card = f"""
95 | # Text-to-image finetuning - {repo_id}
96 |
97 | This pipeline was finetuned from **{args.pretrained_model_name_or_path}** on the **{args.dataset_name}** dataset. Below are some example images generated with the finetuned pipeline using the following prompts: {args.validation_prompts}: \n
98 | {img_str}
99 |
100 | ## Pipeline usage
101 |
102 | You can use the pipeline like so:
103 |
104 | ```python
105 | from diffusers import DiffusionPipeline
106 | import torch
107 |
108 | pipeline = DiffusionPipeline.from_pretrained("{repo_id}", torch_dtype=torch.float16)
109 | prompt = "{args.validation_prompts[0]}"
110 | image = pipeline(prompt).images[0]
111 | image.save("my_image.png")
112 | ```
113 |
114 | ## Training info
115 |
116 | These are the key hyperparameters used during training:
117 |
118 | * Epochs: {args.num_train_epochs}
119 | * Learning rate: {args.learning_rate}
120 | * Batch size: {args.train_batch_size}
121 | * Gradient accumulation steps: {args.gradient_accumulation_steps}
122 | * Image resolution: {args.resolution}
123 | * Mixed-precision: {args.mixed_precision}
124 |
125 | """
126 | wandb_info = ""
127 | if is_wandb_available():
128 | wandb_run_url = None
129 | if wandb.run is not None:
130 | wandb_run_url = wandb.run.url
131 |
132 | if wandb_run_url is not None:
133 | wandb_info = f"""
134 | More information on all the CLI arguments and the environment are available on your [`wandb` run page]({wandb_run_url}).
135 | """
136 |
137 | model_card += wandb_info
138 |
139 | with open(os.path.join(repo_folder, "README.md"), "w") as f:
140 | f.write(yaml + model_card)
141 |
142 |
143 | def log_validation(vae, text_encoder, tokenizer, unet, args, accelerator, weight_dtype, epoch):
144 | logger.info("Running validation... ")
145 |
146 | pipeline = StableDiffusionPipeline.from_pretrained(
147 | args.pretrained_model_name_or_path,
148 | vae=accelerator.unwrap_model(vae),
149 | text_encoder=accelerator.unwrap_model(text_encoder),
150 | tokenizer=tokenizer,
151 | unet=accelerator.unwrap_model(unet),
152 | safety_checker=None,
153 | revision=args.revision,
154 | variant=args.variant,
155 | torch_dtype=weight_dtype,
156 | )
157 | pipeline = pipeline.to(accelerator.device)
158 | pipeline.set_progress_bar_config(disable=True)
159 |
160 | if args.enable_xformers_memory_efficient_attention:
161 | pipeline.enable_xformers_memory_efficient_attention()
162 |
163 | if args.seed is None:
164 | generator = None
165 | else:
166 | generator = torch.Generator(device=accelerator.device).manual_seed(args.seed)
167 |
168 | images = []
169 | for i in range(len(args.validation_prompts)):
170 | with torch.autocast("cuda"):
171 | image = pipeline(args.validation_prompts[i], num_inference_steps=20, generator=generator).images[0]
172 |
173 | images.append(image)
174 |
175 | for tracker in accelerator.trackers:
176 | if tracker.name == "tensorboard":
177 | np_images = np.stack([np.asarray(img) for img in images])
178 | tracker.writer.add_images("validation", np_images, epoch, dataformats="NHWC")
179 | elif tracker.name == "wandb":
180 | tracker.log(
181 | {
182 | "validation": [
183 | wandb.Image(image, caption=f"{i}: {args.validation_prompts[i]}")
184 | for i, image in enumerate(images)
185 | ]
186 | }
187 | )
188 | else:
189 | logger.warn(f"image logging not implemented for {tracker.name}")
190 |
191 | del pipeline
192 | torch.cuda.empty_cache()
193 |
194 | return images
195 |
196 |
197 | def parse_args():
198 | parser = argparse.ArgumentParser(description="Simple example of a training script.")
199 | parser.add_argument(
200 | "--input_perturbation", type=float, default=0, help="The scale of input perturbation. Recommended 0.1."
201 | )
202 | parser.add_argument(
203 | "--pretrained_model_name_or_path",
204 | type=str,
205 | default=None,
206 | required=True,
207 | help="Path to pretrained model or model identifier from huggingface.co/models.",
208 | )
209 | parser.add_argument(
210 | "--clip_model_name_or_path",
211 | type=str,
212 | default=None,
213 | required=True,
214 | help="Path to clip model or model identifier from huggingface.co/models.",
215 | )
216 | parser.add_argument(
217 | "--revision",
218 | type=str,
219 | default=None,
220 | required=False,
221 | help="Revision of pretrained model identifier from huggingface.co/models.",
222 | )
223 | parser.add_argument(
224 | "--variant",
225 | type=str,
226 | default=None,
227 | help="Variant of the model files of the pretrained model identifier from huggingface.co/models, 'e.g.' fp16",
228 | )
229 | parser.add_argument(
230 | "--dataset_name",
231 | type=str,
232 | default=None,
233 | help=(
234 | "The name of the Dataset (from the HuggingFace hub) to train on (could be your own, possibly private,"
235 | " dataset). It can also be a path pointing to a local copy of a dataset in your filesystem,"
236 | " or to a folder containing files that 🤗 Datasets can understand."
237 | ),
238 | )
239 | parser.add_argument(
240 | "--dataset_config_name",
241 | type=str,
242 | default=None,
243 | help="The config of the Dataset, leave as None if there's only one config.",
244 | )
245 | parser.add_argument(
246 | "--train_data_dir",
247 | type=str,
248 | default=None,
249 | help=(
250 | "A folder containing the training data. Folder contents must follow the structure described in"
251 | " https://huggingface.co/docs/datasets/image_dataset#imagefolder. In particular, a `metadata.jsonl` file"
252 | " must exist to provide the captions for the images. Ignored if `dataset_name` is specified."
253 | ),
254 | )
255 | parser.add_argument(
256 | "--image_column", type=str, default="image", help="The column of the dataset containing an image."
257 | )
258 | parser.add_argument(
259 | "--caption_column",
260 | type=str,
261 | default="text",
262 | help="The column of the dataset containing a caption or a list of captions.",
263 | )
264 | parser.add_argument(
265 | "--max_train_samples",
266 | type=int,
267 | default=None,
268 | help=(
269 | "For debugging purposes or quicker training, truncate the number of training examples to this "
270 | "value if set."
271 | ),
272 | )
273 | parser.add_argument(
274 | "--validation_prompts",
275 | type=str,
276 | default=None,
277 | nargs="+",
278 | help=("A set of prompts evaluated every `--validation_epochs` and logged to `--report_to`."),
279 | )
280 | parser.add_argument(
281 | "--validation_prompts_path",
282 | type=str,
283 | default=None
284 | )
285 | parser.add_argument(
286 | "--output_dir",
287 | type=str,
288 | default="sd-model-finetuned",
289 | help="The output directory where the model predictions and checkpoints will be written.",
290 | )
291 | parser.add_argument(
292 | "--cache_dir",
293 | type=str,
294 | default=None,
295 | help="The directory where the downloaded models and datasets will be stored.",
296 | )
297 | parser.add_argument("--seed", type=int, default=None, help="A seed for reproducible training.")
298 | parser.add_argument(
299 | "--resolution",
300 | type=int,
301 | default=512,
302 | help=(
303 | "The resolution for input images, all the images in the train/validation dataset will be resized to this"
304 | " resolution"
305 | ),
306 | )
307 | parser.add_argument(
308 | "--center_crop",
309 | default=False,
310 | action="store_true",
311 | help=(
312 | "Whether to center crop the input images to the resolution. If not set, the images will be randomly"
313 | " cropped. The images will be resized to the resolution first before cropping."
314 | ),
315 | )
316 | parser.add_argument(
317 | "--random_flip",
318 | action="store_true",
319 | help="whether to randomly flip images horizontally",
320 | )
321 | parser.add_argument(
322 | "--train_batch_size", type=int, default=16, help="Batch size (per device) for the training dataloader."
323 | )
324 | parser.add_argument("--num_train_epochs", type=int, default=100)
325 | parser.add_argument(
326 | "--max_train_steps",
327 | type=int,
328 | default=None,
329 | help="Total number of training steps to perform. If provided, overrides num_train_epochs.",
330 | )
331 | parser.add_argument(
332 | "--gradient_accumulation_steps",
333 | type=int,
334 | default=1,
335 | help="Number of updates steps to accumulate before performing a backward/update pass.",
336 | )
337 | parser.add_argument(
338 | "--gradient_checkpointing",
339 | action="store_true",
340 | help="Whether or not to use gradient checkpointing to save memory at the expense of slower backward pass.",
341 | )
342 | parser.add_argument(
343 | "--learning_rate",
344 | type=float,
345 | default=1e-4,
346 | help="Initial learning rate (after the potential warmup period) to use.",
347 | )
348 | parser.add_argument(
349 | "--scale_lr",
350 | action="store_true",
351 | default=False,
352 | help="Scale the learning rate by the number of GPUs, gradient accumulation steps, and batch size.",
353 | )
354 | parser.add_argument(
355 | "--lr_scheduler",
356 | type=str,
357 | default="constant",
358 | help=(
359 | 'The scheduler type to use. Choose between ["linear", "cosine", "cosine_with_restarts", "polynomial",'
360 | ' "constant", "constant_with_warmup"]'
361 | ),
362 | )
363 | parser.add_argument(
364 | "--lr_warmup_steps", type=int, default=500, help="Number of steps for the warmup in the lr scheduler."
365 | )
366 | parser.add_argument(
367 | "--snr_gamma",
368 | type=float,
369 | default=None,
370 | help="SNR weighting gamma to be used if rebalancing the loss. Recommended value is 5.0. "
371 | "More details here: https://arxiv.org/abs/2303.09556.",
372 | )
373 | parser.add_argument(
374 | "--use_8bit_adam", action="store_true", help="Whether or not to use 8-bit Adam from bitsandbytes."
375 | )
376 | parser.add_argument(
377 | "--allow_tf32",
378 | action="store_true",
379 | help=(
380 | "Whether or not to allow TF32 on Ampere GPUs. Can be used to speed up training. For more information, see"
381 | " https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices"
382 | ),
383 | )
384 | parser.add_argument("--use_ema", action="store_true", help="Whether to use EMA model.")
385 | parser.add_argument(
386 | "--non_ema_revision",
387 | type=str,
388 | default=None,
389 | required=False,
390 | help=(
391 | "Revision of pretrained non-ema model identifier. Must be a branch, tag or git identifier of the local or"
392 | " remote repository specified with --pretrained_model_name_or_path."
393 | ),
394 | )
395 | parser.add_argument(
396 | "--dataloader_num_workers",
397 | type=int,
398 | default=0,
399 | help=(
400 | "Number of subprocesses to use for data loading. 0 means that the data will be loaded in the main process."
401 | ),
402 | )
403 | parser.add_argument("--adam_beta1", type=float, default=0.9, help="The beta1 parameter for the Adam optimizer.")
404 | parser.add_argument("--adam_beta2", type=float, default=0.999, help="The beta2 parameter for the Adam optimizer.")
405 | parser.add_argument("--adam_weight_decay", type=float, default=1e-2, help="Weight decay to use.")
406 | parser.add_argument("--adam_epsilon", type=float, default=1e-08, help="Epsilon value for the Adam optimizer")
407 | parser.add_argument("--max_grad_norm", default=1.0, type=float, help="Max gradient norm.")
408 | parser.add_argument("--push_to_hub", action="store_true", help="Whether or not to push the model to the Hub.")
409 | parser.add_argument("--hub_token", type=str, default=None, help="The token to use to push to the Model Hub.")
410 | parser.add_argument(
411 | "--prediction_type",
412 | type=str,
413 | default=None,
414 |         help="The prediction_type that shall be used for training. Choose between 'epsilon' or 'v_prediction' or leave `None`. If left to `None` the default prediction type of the scheduler: `noise_scheduler.config.prediction_type` is chosen.",
415 | )
416 | parser.add_argument(
417 | "--hub_model_id",
418 | type=str,
419 | default=None,
420 | help="The name of the repository to keep in sync with the local `output_dir`.",
421 | )
422 | parser.add_argument(
423 | "--logging_dir",
424 | type=str,
425 | default="logs",
426 | help=(
427 | "[TensorBoard](https://www.tensorflow.org/tensorboard) log directory. Will default to"
428 | " *output_dir/runs/**CURRENT_DATETIME_HOSTNAME***."
429 | ),
430 | )
431 | parser.add_argument(
432 | "--mixed_precision",
433 | type=str,
434 | default=None,
435 | choices=["no", "fp16", "bf16"],
436 | help=(
437 | "Whether to use mixed precision. Choose between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >="
438 |             " 1.10 and an Nvidia Ampere GPU. Defaults to the value of the accelerate config of the current system or the"
439 | " flag passed with the `accelerate.launch` command. Use this argument to override the accelerate config."
440 | ),
441 | )
442 | parser.add_argument(
443 | "--report_to",
444 | type=str,
445 | default="tensorboard",
446 | help=(
447 | 'The integration to report the results and logs to. Supported platforms are `"tensorboard"`'
448 | ' (default), `"wandb"` and `"comet_ml"`. Use `"all"` to report to all integrations.'
449 | ),
450 | )
451 | parser.add_argument("--local_rank", type=int, default=-1, help="For distributed training: local_rank")
452 | parser.add_argument(
453 | "--checkpointing_steps",
454 | type=int,
455 | default=500,
456 | help=(
457 | "Save a checkpoint of the training state every X updates. These checkpoints are only suitable for resuming"
458 | " training using `--resume_from_checkpoint`."
459 | ),
460 | )
461 | parser.add_argument(
462 | "--checkpoints_total_limit",
463 | type=int,
464 | default=None,
465 | help=("Max number of checkpoints to store."),
466 | )
467 | parser.add_argument(
468 | "--resume_from_checkpoint",
469 | type=str,
470 | default=None,
471 | help=(
472 | "Whether training should be resumed from a previous checkpoint. Use a path saved by"
473 | ' `--checkpointing_steps`, or `"latest"` to automatically select the last available checkpoint.'
474 | ),
475 | )
476 | parser.add_argument(
477 | "--enable_xformers_memory_efficient_attention", action="store_true", help="Whether or not to use xformers."
478 | )
479 | parser.add_argument("--noise_offset", type=float, default=0, help="The scale of noise offset.")
480 | parser.add_argument(
481 | "--validation_epochs",
482 | type=int,
483 | default=5,
484 | help="Run validation every X epochs.",
485 | )
486 | parser.add_argument(
487 | "--validation_steps",
488 | type=int,
489 | default=50,
490 | help="Run validation every X steps.",
491 | )
492 | parser.add_argument(
493 | "--validation_output_dir",
494 | type=str,
495 |         default=None, help="Directory under which images generated during validation are saved."
496 | )
497 | parser.add_argument(
498 | "--validation_batch_size",
499 | type=int,
500 |         default=1, help="Batch size used when generating images during validation."
501 | )
502 | parser.add_argument(
503 | "--tracker_project_name",
504 | type=str,
505 | default="text2image-fine-tune",
506 | help=(
507 |             "The `project_name` argument passed to Accelerator.init_trackers. For"
508 | " more information see https://huggingface.co/docs/accelerate/v0.17.0/en/package_reference/accelerator#accelerate.Accelerator"
509 | ),
510 | )
511 |
512 | args = parser.parse_args()
513 | env_local_rank = int(os.environ.get("LOCAL_RANK", -1))
514 | if env_local_rank != -1 and env_local_rank != args.local_rank:
515 | args.local_rank = env_local_rank
516 |
517 | # Sanity checks
518 | if args.dataset_name is None and args.train_data_dir is None:
519 | raise ValueError("Need either a dataset name or a training folder.")
520 |
521 | # default to using the same revision for the non-ema model if not specified
522 | if args.non_ema_revision is None:
523 | args.non_ema_revision = args.revision
524 |
525 | return args
526 |
527 |
528 | def main():
529 | args = parse_args()
530 |
531 | if args.non_ema_revision is not None:
532 | deprecate(
533 | "non_ema_revision!=None",
534 | "0.15.0",
535 | message=(
536 | "Downloading 'non_ema' weights from revision branches of the Hub is deprecated. Please make sure to"
537 | " use `--variant=non_ema` instead."
538 | ),
539 | )
540 | logging_dir = os.path.join(args.output_dir, args.logging_dir)
541 |
542 | accelerator_project_config = ProjectConfiguration(project_dir=args.output_dir, logging_dir=logging_dir)
543 |
544 | accelerator = Accelerator(
545 | gradient_accumulation_steps=args.gradient_accumulation_steps,
546 | mixed_precision=args.mixed_precision,
547 | log_with=args.report_to,
548 | project_config=accelerator_project_config,
549 | )
550 |
551 | # Make one log on every process with the configuration for debugging.
552 | logging.basicConfig(
553 | format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
554 | datefmt="%m/%d/%Y %H:%M:%S",
555 | level=logging.INFO,
556 | )
557 | logger.info(accelerator.state, main_process_only=False)
558 | if accelerator.is_local_main_process:
559 | datasets.utils.logging.set_verbosity_warning()
560 | transformers.utils.logging.set_verbosity_warning()
561 | diffusers.utils.logging.set_verbosity_info()
562 | else:
563 | datasets.utils.logging.set_verbosity_error()
564 | transformers.utils.logging.set_verbosity_error()
565 | diffusers.utils.logging.set_verbosity_error()
566 |
567 | # If passed along, set the training seed now.
568 | if args.seed is not None:
569 | set_seed(args.seed)
570 |
571 | # Handle the repository creation
572 | if accelerator.is_main_process:
573 | if args.output_dir is not None:
574 | os.makedirs(args.output_dir, exist_ok=True)
575 |
576 | if args.push_to_hub:
577 | repo_id = create_repo(
578 | repo_id=args.hub_model_id or Path(args.output_dir).name, exist_ok=True, token=args.hub_token
579 | ).repo_id
580 |
581 | # Load scheduler, tokenizer and models.
582 | noise_scheduler = DDPMScheduler.from_pretrained(args.pretrained_model_name_or_path, subfolder="scheduler")
583 | tokenizer = CLIPTokenizer.from_pretrained(
584 | args.pretrained_model_name_or_path, subfolder="tokenizer", revision=args.revision
585 | )
586 |
587 | if args.clip_model_name_or_path:
588 | clip_model = CLIPModel.from_pretrained(args.clip_model_name_or_path)
589 | clip_processor = CLIPProcessor.from_pretrained(args.clip_model_name_or_path)
590 |
591 | def deepspeed_zero_init_disabled_context_manager():
592 | """
593 |         Returns either a context list that includes one that will disable zero.Init or an empty context list.
594 | """
595 | deepspeed_plugin = AcceleratorState().deepspeed_plugin if accelerate.state.is_initialized() else None
596 | if deepspeed_plugin is None:
597 | return []
598 |
599 | return [deepspeed_plugin.zero3_init_context_manager(enable=False)]
600 |
601 | # Currently Accelerate doesn't know how to handle multiple models under Deepspeed ZeRO stage 3.
602 | # For this to work properly all models must be run through `accelerate.prepare`. But accelerate
603 | # will try to assign the same optimizer with the same weights to all models during
604 | # `deepspeed.initialize`, which of course doesn't work.
605 | #
606 | # For now the following workaround will partially support Deepspeed ZeRO-3, by excluding the 2
607 | # frozen models from being partitioned during `zero.Init` which gets called during
608 |     # `from_pretrained`. So CLIPTextModel and AutoencoderKL will not enjoy the parameter sharding
609 |     # across multiple GPUs and only UNet2DConditionModel will get ZeRO sharded.
610 | with ContextManagers(deepspeed_zero_init_disabled_context_manager()):
611 | text_encoder = CLIPTextModel.from_pretrained(
612 | args.pretrained_model_name_or_path, subfolder="text_encoder", revision=args.revision, variant=args.variant
613 | )
614 | vae = AutoencoderKL.from_pretrained(
615 | args.pretrained_model_name_or_path, subfolder="vae", revision=args.revision, variant=args.variant
616 | )
617 |
618 | unet = UNet2DConditionModel.from_pretrained(
619 | args.pretrained_model_name_or_path, subfolder="unet", revision=args.non_ema_revision
620 | )
621 |
622 | # Freeze vae and text_encoder and set unet to trainable
623 | vae.requires_grad_(False)
624 | text_encoder.requires_grad_(False)
625 | unet.train()
626 |
627 | # Create EMA for the unet.
628 | if args.use_ema:
629 | ema_unet = UNet2DConditionModel.from_pretrained(
630 | args.pretrained_model_name_or_path, subfolder="unet", revision=args.revision, variant=args.variant
631 | )
632 | ema_unet = EMAModel(ema_unet.parameters(), model_cls=UNet2DConditionModel, model_config=ema_unet.config)
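        # Note: EMAModel keeps an exponential moving average of the UNet parameters; later in this script the
        # EMA weights are swapped in via copy_to()/restore() for validation and copied into the final saved pipeline.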
633 |
634 | if args.enable_xformers_memory_efficient_attention:
635 | if is_xformers_available():
636 | import xformers
637 |
638 | xformers_version = version.parse(xformers.__version__)
639 | if xformers_version == version.parse("0.0.16"):
640 |                 logger.warning(
641 | "xFormers 0.0.16 cannot be used for training in some GPUs. If you observe problems during training, please update xFormers to at least 0.0.17. See https://huggingface.co/docs/diffusers/main/en/optimization/xformers for more details."
642 | )
643 | unet.enable_xformers_memory_efficient_attention()
644 | else:
645 | raise ValueError("xformers is not available. Make sure it is installed correctly")
646 |
647 | # `accelerate` 0.16.0 will have better support for customized saving
648 | if version.parse(accelerate.__version__) >= version.parse("0.16.0"):
649 | # create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format
650 | def save_model_hook(models, weights, output_dir):
651 | if accelerator.is_main_process:
652 | if args.use_ema:
653 | ema_unet.save_pretrained(os.path.join(output_dir, "unet_ema"))
654 |
655 | for i, model in enumerate(models):
656 | model.save_pretrained(os.path.join(output_dir, "unet"))
657 |
658 | # make sure to pop weight so that corresponding model is not saved again
659 | weights.pop()
660 |
661 | def load_model_hook(models, input_dir):
662 | if args.use_ema:
663 | load_model = EMAModel.from_pretrained(os.path.join(input_dir, "unet_ema"), UNet2DConditionModel)
664 | ema_unet.load_state_dict(load_model.state_dict())
665 | ema_unet.to(accelerator.device)
666 | del load_model
667 |
668 | for i in range(len(models)):
669 | # pop models so that they are not loaded again
670 | model = models.pop()
671 |
672 | # load diffusers style into model
673 | load_model = UNet2DConditionModel.from_pretrained(input_dir, subfolder="unet")
674 | model.register_to_config(**load_model.config)
675 |
676 | model.load_state_dict(load_model.state_dict())
677 | del load_model
678 |
679 | accelerator.register_save_state_pre_hook(save_model_hook)
680 | accelerator.register_load_state_pre_hook(load_model_hook)
681 |
682 | if args.gradient_checkpointing:
683 | unet.enable_gradient_checkpointing()
684 |
685 | # Enable TF32 for faster training on Ampere GPUs,
686 | # cf https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices
687 | if args.allow_tf32:
688 | torch.backends.cuda.matmul.allow_tf32 = True
689 |
690 | if args.scale_lr:
691 | args.learning_rate = (
692 | args.learning_rate * args.gradient_accumulation_steps * args.train_batch_size * accelerator.num_processes
693 | )
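        # Illustrative arithmetic: with --learning_rate 1e-4, --gradient_accumulation_steps 2,
        # --train_batch_size 16 and 4 processes, the scaled learning rate is 1e-4 * 2 * 16 * 4 = 1.28e-2.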
694 |
695 | # Initialize the optimizer
696 | if args.use_8bit_adam:
697 | try:
698 | import bitsandbytes as bnb
699 | except ImportError:
700 | raise ImportError(
701 | "Please install bitsandbytes to use 8-bit Adam. You can do so by running `pip install bitsandbytes`"
702 | )
703 |
704 | optimizer_cls = bnb.optim.AdamW8bit
705 | else:
706 | optimizer_cls = torch.optim.AdamW
707 |
708 | optimizer = optimizer_cls(
709 | unet.parameters(),
710 | lr=args.learning_rate,
711 | betas=(args.adam_beta1, args.adam_beta2),
712 | weight_decay=args.adam_weight_decay,
713 | eps=args.adam_epsilon,
714 | )
715 |
716 | # Get the datasets: you can either provide your own training and evaluation files (see below)
717 | # or specify a Dataset from the hub (the dataset will be downloaded automatically from the datasets Hub).
718 |
719 | # In distributed training, the load_dataset function guarantees that only one local process can concurrently
720 | # download the dataset.
721 | if args.dataset_name is not None:
722 | # Downloading and loading a dataset from the hub.
723 | dataset = load_dataset(
724 | args.dataset_name,
725 | args.dataset_config_name,
726 | cache_dir=args.cache_dir,
727 | data_dir=args.train_data_dir,
728 | )
729 | else:
730 | data_files = {}
731 | if args.train_data_dir is not None:
732 | data_files["train"] = os.path.join(args.train_data_dir, "**")
733 | dataset = load_dataset(
734 | "imagefolder",
735 | data_files=data_files,
736 | cache_dir=args.cache_dir,
737 | )
738 | # See more about loading custom images at
739 | # https://huggingface.co/docs/datasets/v2.4.0/en/image_load#imagefolder
740 |
741 | # Preprocessing the datasets.
742 | # We need to tokenize inputs and targets.
743 | column_names = dataset["train"].column_names
744 |
745 | # 6. Get the column names for input/target.
746 | dataset_columns = DATASET_NAME_MAPPING.get(args.dataset_name, None)
747 | if args.image_column is None:
748 | image_column = dataset_columns[0] if dataset_columns is not None else column_names[0]
749 | else:
750 | image_column = args.image_column
751 | if image_column not in column_names:
752 | raise ValueError(
753 |             f"'--image_column' value '{args.image_column}' needs to be one of: {', '.join(column_names)}"
754 | )
755 | if args.caption_column is None:
756 | caption_column = dataset_columns[1] if dataset_columns is not None else column_names[1]
757 | else:
758 | caption_column = args.caption_column
759 | if caption_column not in column_names:
760 | raise ValueError(
761 |             f"'--caption_column' value '{args.caption_column}' needs to be one of: {', '.join(column_names)}"
762 | )
763 |
764 | # Preprocessing the datasets.
765 | # We need to tokenize input captions and transform the images.
766 | def tokenize_captions(examples, is_train=True):
767 | captions = []
768 | for caption in examples[caption_column]:
769 | if isinstance(caption, str):
770 | captions.append(caption)
771 | elif isinstance(caption, (list, np.ndarray)):
772 | # take a random caption if there are multiple
773 | captions.append(random.choice(caption) if is_train else caption[0])
774 | else:
775 | raise ValueError(
776 | f"Caption column `{caption_column}` should contain either strings or lists of strings."
777 | )
778 | inputs = tokenizer(
779 | captions, max_length=tokenizer.model_max_length, padding="max_length", truncation=True, return_tensors="pt"
780 | )
781 | return inputs.input_ids
782 |
783 | # Preprocessing the datasets.
784 | train_transforms = transforms.Compose(
785 | [
786 | transforms.Resize(args.resolution, interpolation=transforms.InterpolationMode.BILINEAR),
787 | transforms.CenterCrop(args.resolution) if args.center_crop else transforms.RandomCrop(args.resolution),
788 | transforms.RandomHorizontalFlip() if args.random_flip else transforms.Lambda(lambda x: x),
789 | transforms.ToTensor(),
790 | transforms.Normalize([0.5], [0.5]),
791 | ]
792 | )
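    # In the pipeline above, Normalize([0.5], [0.5]) maps pixel values from [0, 1] to [-1, 1], the input range
    # expected by the Stable Diffusion VAE.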
793 |
794 | def preprocess_train(examples):
795 | images = [image.convert("RGB") for image in examples[image_column]]
796 | examples["pixel_values"] = [train_transforms(image) for image in images]
797 | examples["input_ids"] = tokenize_captions(examples)
798 | return examples
799 |
800 | with accelerator.main_process_first():
801 | if args.max_train_samples is not None:
802 | dataset["train"] = dataset["train"].shuffle(seed=args.seed).select(range(args.max_train_samples))
803 | # Set the training transforms
804 | train_dataset = dataset["train"].with_transform(preprocess_train)
805 |
806 | def collate_fn(examples):
807 | pixel_values = torch.stack([example["pixel_values"] for example in examples])
808 | pixel_values = pixel_values.to(memory_format=torch.contiguous_format).float()
809 | input_ids = torch.stack([example["input_ids"] for example in examples])
810 | return {"pixel_values": pixel_values, "input_ids": input_ids}
811 |
812 | # DataLoaders creation:
813 | train_dataloader = torch.utils.data.DataLoader(
814 | train_dataset,
815 | shuffle=True,
816 | collate_fn=collate_fn,
817 | batch_size=args.train_batch_size,
818 | num_workers=args.dataloader_num_workers,
819 | )
820 |
821 | # Scheduler and math around the number of training steps.
822 | overrode_max_train_steps = False
823 | num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
824 | if args.max_train_steps is None:
825 | args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
826 | overrode_max_train_steps = True
827 |
828 | lr_scheduler = get_scheduler(
829 | args.lr_scheduler,
830 | optimizer=optimizer,
831 | num_warmup_steps=args.lr_warmup_steps * accelerator.num_processes,
832 | num_training_steps=args.max_train_steps * accelerator.num_processes,
833 | )
834 |
835 | # Prepare everything with our `accelerator`.
836 | unet, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
837 | unet, optimizer, train_dataloader, lr_scheduler
838 | )
839 |
840 | if args.use_ema:
841 | ema_unet.to(accelerator.device)
842 |
843 |     # For mixed precision training we cast all non-trainable weights (vae and text_encoder) to half-precision
844 | # as these weights are only used for inference, keeping weights in full precision is not required.
845 | weight_dtype = torch.float32
846 | if accelerator.mixed_precision == "fp16":
847 | weight_dtype = torch.float16
848 | args.mixed_precision = accelerator.mixed_precision
849 | elif accelerator.mixed_precision == "bf16":
850 | weight_dtype = torch.bfloat16
851 | args.mixed_precision = accelerator.mixed_precision
852 |
853 |     # Move text_encoder and vae to GPU and cast to weight_dtype
854 | text_encoder.to(accelerator.device, dtype=weight_dtype)
855 | vae.to(accelerator.device, dtype=weight_dtype)
856 |
857 | # We need to recalculate our total training steps as the size of the training dataloader may have changed.
858 | num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
859 | if overrode_max_train_steps:
860 | args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
861 | # Afterwards we recalculate our number of training epochs
862 | args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)
863 |
864 | # We need to initialize the trackers we use, and also store our configuration.
865 |     # The trackers are initialized automatically on the main process.
866 | if accelerator.is_main_process:
867 | tracker_config = dict(vars(args))
868 | tracker_config.pop("validation_prompts")
869 | accelerator.init_trackers(args.tracker_project_name, tracker_config)
870 |
871 | # Train!
872 | total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps
873 |
874 | logger.info("***** Running training *****")
875 | logger.info(f" Num examples = {len(train_dataset)}")
876 | logger.info(f" Num Epochs = {args.num_train_epochs}")
877 | logger.info(f" Instantaneous batch size per device = {args.train_batch_size}")
878 | logger.info(f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}")
879 | logger.info(f" Gradient Accumulation steps = {args.gradient_accumulation_steps}")
880 | logger.info(f" Total optimization steps = {args.max_train_steps}")
881 | global_step = 0
882 | first_epoch = 0
883 |
884 | # Potentially load in the weights and states from a previous save
885 | if args.resume_from_checkpoint:
886 | if args.resume_from_checkpoint != "latest":
887 | path = os.path.basename(args.resume_from_checkpoint)
888 | else:
889 | # Get the most recent checkpoint
890 | dirs = os.listdir(args.output_dir)
891 | dirs = [d for d in dirs if d.startswith("checkpoint")]
892 | dirs = sorted(dirs, key=lambda x: int(x.split("-")[1]))
893 | path = dirs[-1] if len(dirs) > 0 else None
894 |
895 | if path is None:
896 | accelerator.print(
897 | f"Checkpoint '{args.resume_from_checkpoint}' does not exist. Starting a new training run."
898 | )
899 | args.resume_from_checkpoint = None
900 | initial_global_step = 0
901 | else:
902 | accelerator.print(f"Resuming from checkpoint {path}")
903 | accelerator.load_state(os.path.join(args.output_dir, path))
904 | global_step = int(path.split("-")[1])
905 |
906 | initial_global_step = global_step
907 | first_epoch = global_step // num_update_steps_per_epoch
908 |
909 | else:
910 | initial_global_step = 0
911 |
912 | progress_bar = tqdm(
913 | range(0, args.max_train_steps),
914 | initial=initial_global_step,
915 | desc="Steps",
916 | # Only show the progress bar once on each machine.
917 | disable=not accelerator.is_local_main_process,
918 | )
919 |
920 | for epoch in range(first_epoch, args.num_train_epochs):
921 | train_loss = 0.0
922 | for step, batch in enumerate(train_dataloader):
923 | with accelerator.accumulate(unet):
924 | # Convert images to latent space
925 | latents = vae.encode(batch["pixel_values"].to(weight_dtype)).latent_dist.sample()
926 | latents = latents * vae.config.scaling_factor
927 |
928 | # Sample noise that we'll add to the latents
929 | noise = torch.randn_like(latents)
930 | if args.noise_offset:
931 | # https://www.crosslabs.org//blog/diffusion-with-offset-noise
932 | noise += args.noise_offset * torch.randn(
933 | (latents.shape[0], latents.shape[1], 1, 1), device=latents.device
934 | )
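                    # The offset adds a small per-image constant shift to the noise; as described in the blog post
                    # linked above, this has been reported to help the model produce very dark and very bright images.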
935 | if args.input_perturbation:
936 | new_noise = noise + args.input_perturbation * torch.randn_like(noise)
937 | bsz = latents.shape[0]
938 | # Sample a random timestep for each image
939 | timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, (bsz,), device=latents.device)
940 | timesteps = timesteps.long()
941 |
942 | # Add noise to the latents according to the noise magnitude at each timestep
943 | # (this is the forward diffusion process)
944 | if args.input_perturbation:
945 | noisy_latents = noise_scheduler.add_noise(latents, new_noise, timesteps)
946 | else:
947 | noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)
948 |
949 | # Get the text embedding for conditioning
950 | encoder_hidden_states = text_encoder(batch["input_ids"])[0]
951 |
952 | # Get the target for loss depending on the prediction type
953 | if args.prediction_type is not None:
954 | # set prediction_type of scheduler if defined
955 | noise_scheduler.register_to_config(prediction_type=args.prediction_type)
956 |
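                # For "epsilon" the UNet regresses the added noise itself; for "v_prediction" the target is the
                # velocity v = sqrt(alpha_bar_t) * noise - sqrt(1 - alpha_bar_t) * x_0 (Salimans & Ho, 2022).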
957 | if noise_scheduler.config.prediction_type == "epsilon":
958 | target = noise
959 | elif noise_scheduler.config.prediction_type == "v_prediction":
960 | target = noise_scheduler.get_velocity(latents, noise, timesteps)
961 | else:
962 | raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}")
963 |
964 | # Predict the noise residual and compute loss
965 | model_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample
966 |
967 | if args.snr_gamma is None:
968 | loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean")
969 | else:
970 | # Compute loss-weights as per Section 3.4 of https://arxiv.org/abs/2303.09556.
971 | # Since we predict the noise instead of x_0, the original formulation is slightly changed.
972 | # This is discussed in Section 4.2 of the same paper.
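                    # Concretely, each sample's squared error is weighted by min(SNR_t, snr_gamma) / SNR_t, so
                    # low-noise (high-SNR) timesteps are down-weighted while noisier timesteps keep a weight of 1.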
973 | snr = compute_snr(noise_scheduler, timesteps)
974 | if noise_scheduler.config.prediction_type == "v_prediction":
975 | # Velocity objective requires that we add one to SNR values before we divide by them.
976 | snr = snr + 1
977 | mse_loss_weights = (
978 | torch.stack([snr, args.snr_gamma * torch.ones_like(timesteps)], dim=1).min(dim=1)[0] / snr
979 | )
980 |
981 | loss = F.mse_loss(model_pred.float(), target.float(), reduction="none")
982 | loss = loss.mean(dim=list(range(1, len(loss.shape)))) * mse_loss_weights
983 | loss = loss.mean()
984 |
985 | # Gather the losses across all processes for logging (if we use distributed training).
986 | avg_loss = accelerator.gather(loss.repeat(args.train_batch_size)).mean()
987 | train_loss += avg_loss.item() / args.gradient_accumulation_steps
988 |
989 | # Backpropagate
990 | accelerator.backward(loss)
991 | if accelerator.sync_gradients:
992 | accelerator.clip_grad_norm_(unet.parameters(), args.max_grad_norm)
993 | optimizer.step()
994 | lr_scheduler.step()
995 | optimizer.zero_grad()
996 |
997 | # Checks if the accelerator has performed an optimization step behind the scenes
998 | if accelerator.sync_gradients:
999 | if args.use_ema:
1000 | ema_unet.step(unet.parameters())
1001 | progress_bar.update(1)
1002 | global_step += 1
1003 | accelerator.log({"train_loss": train_loss}, step=global_step)
1004 | train_loss = 0.0
1005 |
1006 | if global_step % args.checkpointing_steps == 0:
1007 | if accelerator.is_main_process:
1008 | # _before_ saving state, check if this save would set us over the `checkpoints_total_limit`
1009 | if args.checkpoints_total_limit is not None:
1010 | checkpoints = os.listdir(args.output_dir)
1011 | checkpoints = [d for d in checkpoints if d.startswith("checkpoint")]
1012 | checkpoints = sorted(checkpoints, key=lambda x: int(x.split("-")[1]))
1013 |
1014 | # before we save the new checkpoint, we need to have at _most_ `checkpoints_total_limit - 1` checkpoints
1015 | if len(checkpoints) >= args.checkpoints_total_limit:
1016 | num_to_remove = len(checkpoints) - args.checkpoints_total_limit + 1
1017 | removing_checkpoints = checkpoints[0:num_to_remove]
1018 |
1019 | logger.info(
1020 | f"{len(checkpoints)} checkpoints already exist, removing {len(removing_checkpoints)} checkpoints"
1021 | )
1022 | logger.info(f"removing checkpoints: {', '.join(removing_checkpoints)}")
1023 |
1024 | for removing_checkpoint in removing_checkpoints:
1025 | removing_checkpoint = os.path.join(args.output_dir, removing_checkpoint)
1026 | shutil.rmtree(removing_checkpoint)
1027 |
1028 | save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}")
1029 | accelerator.save_state(save_path)
1030 | logger.info(f"Saved state to {save_path}")
1031 |
1032 | logs = {"step_loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0]}
1033 | progress_bar.set_postfix(**logs)
1034 |
1035 | if global_step >= args.max_train_steps:
1036 | break
1037 |
1038 | if accelerator.is_main_process:
1039 | if args.validation_prompts is not None and epoch % args.validation_epochs == 0:
1040 | if args.use_ema:
1041 | # Store the UNet parameters temporarily and load the EMA parameters to perform inference.
1042 | ema_unet.store(unet.parameters())
1043 | ema_unet.copy_to(unet.parameters())
1044 | log_validation(
1045 | vae,
1046 | text_encoder,
1047 | tokenizer,
1048 | unet,
1049 | args,
1050 | accelerator,
1051 | weight_dtype,
1052 | global_step,
1053 | )
1054 | if args.use_ema:
1055 | # Switch back to the original UNet parameters.
1056 | ema_unet.restore(unet.parameters())
1057 |
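            # Custom validation hook: at the end of an epoch, when global_step is a multiple of --validation_steps,
            # build an inference pipeline from the current weights and call get_score (defined earlier in this file,
            # not shown here), which is presumably responsible for generating images for the prompts in
            # --validation_prompts_path and scoring them with the CLIP model loaded via --clip_model_name_or_path.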
1058 | if args.validation_prompts_path is not None and global_step % args.validation_steps == 0:
1059 | pipeline = StableDiffusionPipeline.from_pretrained(
1060 | args.pretrained_model_name_or_path,
1061 | vae=accelerator.unwrap_model(vae),
1062 | text_encoder=accelerator.unwrap_model(text_encoder),
1063 | tokenizer=tokenizer,
1064 | unet=accelerator.unwrap_model(unet),
1065 | safety_checker=None,
1066 | revision=args.revision,
1067 | variant=args.variant,
1068 | torch_dtype=weight_dtype,
1069 | )
1070 | pipeline = pipeline.to(accelerator.device)
1071 | pipeline.set_progress_bar_config(disable=True)
1072 | if args.enable_xformers_memory_efficient_attention:
1073 | pipeline.enable_xformers_memory_efficient_attention()
1074 |
1075 | with open(args.validation_prompts_path,'r',encoding='utf-8') as f:
1076 | prompts = json.load(f)[:25]
1077 |
1078 | score = get_score(pipeline, clip_model, clip_processor, prompts, args.validation_batch_size, f'{args.validation_output_dir}/{global_step}')
1079 | print(f'{global_step}: {score}')
1080 |
1081 | # Create the pipeline using the trained modules and save it.
1082 | accelerator.wait_for_everyone()
1083 | if accelerator.is_main_process:
1084 | unet = accelerator.unwrap_model(unet)
1085 | if args.use_ema:
1086 | ema_unet.copy_to(unet.parameters())
1087 |
1088 | pipeline = StableDiffusionPipeline.from_pretrained(
1089 | args.pretrained_model_name_or_path,
1090 | text_encoder=text_encoder,
1091 | vae=vae,
1092 | unet=unet,
1093 | revision=args.revision,
1094 | variant=args.variant,
1095 | )
1096 | pipeline.save_pretrained(args.output_dir)
1097 |
1098 | # Run a final round of inference.
1099 | images = []
1100 | if args.validation_prompts is not None:
1101 | logger.info("Running inference for collecting generated images...")
1102 | pipeline = pipeline.to(accelerator.device)
1103 | pipeline.torch_dtype = weight_dtype
1104 | pipeline.set_progress_bar_config(disable=True)
1105 |
1106 | if args.enable_xformers_memory_efficient_attention:
1107 | pipeline.enable_xformers_memory_efficient_attention()
1108 |
1109 | if args.seed is None:
1110 | generator = None
1111 | else:
1112 | generator = torch.Generator(device=accelerator.device).manual_seed(args.seed)
1113 |
1114 | for i in range(len(args.validation_prompts)):
1115 | with torch.autocast("cuda"):
1116 | image = pipeline(args.validation_prompts[i], num_inference_steps=20, generator=generator).images[0]
1117 | images.append(image)
1118 |
1119 | if args.push_to_hub:
1120 | save_model_card(args, repo_id, images, repo_folder=args.output_dir)
1121 | upload_folder(
1122 | repo_id=repo_id,
1123 | folder_path=args.output_dir,
1124 | commit_message="End of training",
1125 | ignore_patterns=["step_*", "epoch_*"],
1126 | )
1127 |
1128 | accelerator.end_training()
1129 |
1130 |
1131 | if __name__ == "__main__":
1132 | main()
1133 |
--------------------------------------------------------------------------------
/code/others/train_text_to_image_sdxl.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | # coding=utf-8
3 | # Copyright 2023 The HuggingFace Inc. team. All rights reserved.
4 | #
5 | # Licensed under the Apache License, Version 2.0 (the "License");
6 | # you may not use this file except in compliance with the License.
7 | # You may obtain a copy of the License at
8 | #
9 | # http://www.apache.org/licenses/LICENSE-2.0
10 | #
11 | # Unless required by applicable law or agreed to in writing, software
12 | # distributed under the License is distributed on an "AS IS" BASIS,
13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 | # See the License for the specific language governing permissions and
15 | # limitations under the License.
16 | """Fine-tuning script for Stable Diffusion XL for text2image."""
17 |
18 | import argparse
19 | import functools
20 | import gc
21 | import logging
22 | import math
23 | import os
24 | import random
25 | import shutil
26 | from pathlib import Path
27 |
28 | import accelerate
29 | import datasets
30 | import numpy as np
31 | import torch
32 | import torch.nn.functional as F
33 | import torch.utils.checkpoint
34 | import transformers
35 | from accelerate import Accelerator
36 | from accelerate.logging import get_logger
37 | from accelerate.utils import ProjectConfiguration, set_seed
38 | from datasets import load_dataset
39 | from huggingface_hub import create_repo, upload_folder
40 | from packaging import version
41 | from torchvision import transforms
42 | from torchvision.transforms.functional import crop
43 | from tqdm.auto import tqdm
44 | from transformers import AutoTokenizer, PretrainedConfig
45 |
46 | import diffusers
47 | from diffusers import (
48 | AutoencoderKL,
49 | DDPMScheduler,
50 | StableDiffusionXLPipeline,
51 | UNet2DConditionModel,
52 | )
53 | from diffusers.optimization import get_scheduler
54 | from diffusers.training_utils import EMAModel, compute_snr
55 | from diffusers.utils import check_min_version, is_wandb_available
56 | from diffusers.utils.import_utils import is_xformers_available
57 |
58 |
59 | # Will error if the minimal version of diffusers is not installed. Remove at your own risks.
60 | check_min_version("0.25.0.dev0")
61 |
62 | logger = get_logger(__name__)
63 |
64 |
65 | DATASET_NAME_MAPPING = {
66 | "lambdalabs/pokemon-blip-captions": ("image", "text"),
67 | }
68 |
69 |
70 | def save_model_card(
71 | repo_id: str,
72 | images=None,
73 | validation_prompt=None,
74 | base_model=str,
75 | dataset_name=str,
76 | repo_folder=None,
77 | vae_path=None,
78 | ):
79 | img_str = ""
80 | for i, image in enumerate(images):
81 | image.save(os.path.join(repo_folder, f"image_{i}.png"))
82 |         img_str += f"![img_{i}](./image_{i}.png)\n"
83 |
84 | yaml = f"""
85 | ---
86 | license: creativeml-openrail-m
87 | base_model: {base_model}
88 | dataset: {dataset_name}
89 | tags:
90 | - stable-diffusion-xl
91 | - stable-diffusion-xl-diffusers
92 | - text-to-image
93 | - diffusers
94 | inference: true
95 | ---
96 | """
97 | model_card = f"""
98 | # Text-to-image finetuning - {repo_id}
99 |
100 | This pipeline was finetuned from **{base_model}** on the **{dataset_name}** dataset. Below are some example images generated with the finetuned pipeline using the following prompt: {validation_prompt}\n
101 | {img_str}
102 |
103 | Special VAE used for training: {vae_path}.
104 | """
105 | with open(os.path.join(repo_folder, "README.md"), "w") as f:
106 | f.write(yaml + model_card)
107 |
108 |
109 | def import_model_class_from_model_name_or_path(
110 | pretrained_model_name_or_path: str, revision: str, subfolder: str = "text_encoder"
111 | ):
112 | text_encoder_config = PretrainedConfig.from_pretrained(
113 | pretrained_model_name_or_path, subfolder=subfolder, revision=revision
114 | )
115 | model_class = text_encoder_config.architectures[0]
116 |
117 | if model_class == "CLIPTextModel":
118 | from transformers import CLIPTextModel
119 |
120 | return CLIPTextModel
121 | elif model_class == "CLIPTextModelWithProjection":
122 | from transformers import CLIPTextModelWithProjection
123 |
124 | return CLIPTextModelWithProjection
125 | else:
126 | raise ValueError(f"{model_class} is not supported.")
127 |
128 |
129 | def parse_args(input_args=None):
130 | parser = argparse.ArgumentParser(description="Simple example of a training script.")
131 | parser.add_argument(
132 | "--pretrained_model_name_or_path",
133 | type=str,
134 | default=None,
135 | required=True,
136 | help="Path to pretrained model or model identifier from huggingface.co/models.",
137 | )
138 | parser.add_argument(
139 | "--pretrained_vae_model_name_or_path",
140 | type=str,
141 | default=None,
142 | help="Path to pretrained VAE model with better numerical stability. More details: https://github.com/huggingface/diffusers/pull/4038.",
143 | )
144 | parser.add_argument(
145 | "--revision",
146 | type=str,
147 | default=None,
148 | required=False,
149 | help="Revision of pretrained model identifier from huggingface.co/models.",
150 | )
151 | parser.add_argument(
152 | "--variant",
153 | type=str,
154 | default=None,
155 |         help="Variant of the model files of the pretrained model identifier from huggingface.co/models, e.g. fp16",
156 | )
157 | parser.add_argument(
158 | "--dataset_name",
159 | type=str,
160 | default=None,
161 | help=(
162 | "The name of the Dataset (from the HuggingFace hub) to train on (could be your own, possibly private,"
163 | " dataset). It can also be a path pointing to a local copy of a dataset in your filesystem,"
164 | " or to a folder containing files that 🤗 Datasets can understand."
165 | ),
166 | )
167 | parser.add_argument(
168 | "--dataset_config_name",
169 | type=str,
170 | default=None,
171 | help="The config of the Dataset, leave as None if there's only one config.",
172 | )
173 | parser.add_argument(
174 | "--train_data_dir",
175 | type=str,
176 | default=None,
177 | help=(
178 | "A folder containing the training data. Folder contents must follow the structure described in"
179 | " https://huggingface.co/docs/datasets/image_dataset#imagefolder. In particular, a `metadata.jsonl` file"
180 | " must exist to provide the captions for the images. Ignored if `dataset_name` is specified."
181 | ),
182 | )
183 | parser.add_argument(
184 | "--image_column", type=str, default="image", help="The column of the dataset containing an image."
185 | )
186 | parser.add_argument(
187 | "--caption_column",
188 | type=str,
189 | default="text",
190 | help="The column of the dataset containing a caption or a list of captions.",
191 | )
192 | parser.add_argument(
193 | "--validation_prompt",
194 | type=str,
195 | default=None,
196 | help="A prompt that is used during validation to verify that the model is learning.",
197 | )
198 | parser.add_argument(
199 | "--num_validation_images",
200 | type=int,
201 | default=4,
202 | help="Number of images that should be generated during validation with `validation_prompt`.",
203 | )
204 | parser.add_argument(
205 | "--validation_epochs",
206 | type=int,
207 | default=1,
208 | help=(
209 | "Run fine-tuning validation every X epochs. The validation process consists of running the prompt"
210 |             " `args.validation_prompt` `args.num_validation_images` times."
211 | ),
212 | )
213 | parser.add_argument(
214 | "--max_train_samples",
215 | type=int,
216 | default=None,
217 | help=(
218 | "For debugging purposes or quicker training, truncate the number of training examples to this "
219 | "value if set."
220 | ),
221 | )
222 | parser.add_argument(
223 | "--proportion_empty_prompts",
224 | type=float,
225 | default=0,
226 | help="Proportion of image prompts to be replaced with empty strings. Defaults to 0 (no prompt replacement).",
227 | )
228 | parser.add_argument(
229 | "--output_dir",
230 | type=str,
231 | default="sdxl-model-finetuned",
232 | help="The output directory where the model predictions and checkpoints will be written.",
233 | )
234 | parser.add_argument(
235 | "--cache_dir",
236 | type=str,
237 | default=None,
238 | help="The directory where the downloaded models and datasets will be stored.",
239 | )
240 | parser.add_argument("--seed", type=int, default=None, help="A seed for reproducible training.")
241 | parser.add_argument(
242 | "--resolution",
243 | type=int,
244 | default=1024,
245 | help=(
246 |             "The resolution for input images; all the images in the train/validation dataset will be resized to this"
247 | " resolution"
248 | ),
249 | )
250 | parser.add_argument(
251 | "--center_crop",
252 | default=False,
253 | action="store_true",
254 | help=(
255 | "Whether to center crop the input images to the resolution. If not set, the images will be randomly"
256 | " cropped. The images will be resized to the resolution first before cropping."
257 | ),
258 | )
259 | parser.add_argument(
260 | "--random_flip",
261 | action="store_true",
262 | help="whether to randomly flip images horizontally",
263 | )
264 | parser.add_argument(
265 | "--train_batch_size", type=int, default=16, help="Batch size (per device) for the training dataloader."
266 | )
267 | parser.add_argument("--num_train_epochs", type=int, default=100)
268 | parser.add_argument(
269 | "--max_train_steps",
270 | type=int,
271 | default=None,
272 | help="Total number of training steps to perform. If provided, overrides num_train_epochs.",
273 | )
274 | parser.add_argument(
275 | "--checkpointing_steps",
276 | type=int,
277 | default=500,
278 | help=(
279 | "Save a checkpoint of the training state every X updates. These checkpoints can be used both as final"
280 | " checkpoints in case they are better than the last checkpoint, and are also suitable for resuming"
281 | " training using `--resume_from_checkpoint`."
282 | ),
283 | )
284 | parser.add_argument(
285 | "--checkpoints_total_limit",
286 | type=int,
287 | default=None,
288 | help=("Max number of checkpoints to store."),
289 | )
290 | parser.add_argument(
291 | "--resume_from_checkpoint",
292 | type=str,
293 | default=None,
294 | help=(
295 | "Whether training should be resumed from a previous checkpoint. Use a path saved by"
296 | ' `--checkpointing_steps`, or `"latest"` to automatically select the last available checkpoint.'
297 | ),
298 | )
299 | parser.add_argument(
300 | "--gradient_accumulation_steps",
301 | type=int,
302 | default=1,
303 |         help="Number of update steps to accumulate before performing a backward/update pass.",
304 | )
305 | parser.add_argument(
306 | "--gradient_checkpointing",
307 | action="store_true",
308 | help="Whether or not to use gradient checkpointing to save memory at the expense of slower backward pass.",
309 | )
310 | parser.add_argument(
311 | "--learning_rate",
312 | type=float,
313 | default=1e-4,
314 | help="Initial learning rate (after the potential warmup period) to use.",
315 | )
316 | parser.add_argument(
317 | "--scale_lr",
318 | action="store_true",
319 | default=False,
320 | help="Scale the learning rate by the number of GPUs, gradient accumulation steps, and batch size.",
321 | )
322 | parser.add_argument(
323 | "--lr_scheduler",
324 | type=str,
325 | default="constant",
326 | help=(
327 | 'The scheduler type to use. Choose between ["linear", "cosine", "cosine_with_restarts", "polynomial",'
328 | ' "constant", "constant_with_warmup"]'
329 | ),
330 | )
331 | parser.add_argument(
332 | "--lr_warmup_steps", type=int, default=500, help="Number of steps for the warmup in the lr scheduler."
333 | )
334 | parser.add_argument(
335 | "--timestep_bias_strategy",
336 | type=str,
337 | default="none",
338 | choices=["earlier", "later", "range", "none"],
339 | help=(
340 | "The timestep bias strategy, which may help direct the model toward learning low or high frequency details."
341 | " Choices: ['earlier', 'later', 'range', 'none']."
342 | " The default is 'none', which means no bias is applied, and training proceeds normally."
343 | " The value of 'later' will increase the frequency of the model's final training timesteps."
344 | ),
345 | )
346 | parser.add_argument(
347 | "--timestep_bias_multiplier",
348 | type=float,
349 | default=1.0,
350 | help=(
351 | "The multiplier for the bias. Defaults to 1.0, which means no bias is applied."
352 | " A value of 2.0 will double the weight of the bias, and a value of 0.5 will halve it."
353 | ),
354 | )
355 | parser.add_argument(
356 | "--timestep_bias_begin",
357 | type=int,
358 | default=0,
359 | help=(
360 | "When using `--timestep_bias_strategy=range`, the beginning (inclusive) timestep to bias."
361 | " Defaults to zero, which equates to having no specific bias."
362 | ),
363 | )
364 | parser.add_argument(
365 | "--timestep_bias_end",
366 | type=int,
367 | default=1000,
368 | help=(
369 | "When using `--timestep_bias_strategy=range`, the final timestep (inclusive) to bias."
370 | " Defaults to 1000, which is the number of timesteps that Stable Diffusion is trained on."
371 | ),
372 | )
373 | parser.add_argument(
374 | "--timestep_bias_portion",
375 | type=float,
376 | default=0.25,
377 | help=(
378 |             "The portion of timesteps to bias. Defaults to 0.25, meaning 25% of timesteps will be biased."
379 | " A value of 0.5 will bias one half of the timesteps. The value provided for `--timestep_bias_strategy` determines"
380 | " whether the biased portions are in the earlier or later timesteps."
381 | ),
382 | )
383 | parser.add_argument(
384 | "--snr_gamma",
385 | type=float,
386 | default=None,
387 | help="SNR weighting gamma to be used if rebalancing the loss. Recommended value is 5.0. "
388 | "More details here: https://arxiv.org/abs/2303.09556.",
389 | )
390 | parser.add_argument("--use_ema", action="store_true", help="Whether to use EMA model.")
391 | parser.add_argument(
392 | "--allow_tf32",
393 | action="store_true",
394 | help=(
395 | "Whether or not to allow TF32 on Ampere GPUs. Can be used to speed up training. For more information, see"
396 | " https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices"
397 | ),
398 | )
399 | parser.add_argument(
400 | "--dataloader_num_workers",
401 | type=int,
402 | default=0,
403 | help=(
404 | "Number of subprocesses to use for data loading. 0 means that the data will be loaded in the main process."
405 | ),
406 | )
407 | parser.add_argument(
408 | "--use_8bit_adam", action="store_true", help="Whether or not to use 8-bit Adam from bitsandbytes."
409 | )
410 | parser.add_argument("--adam_beta1", type=float, default=0.9, help="The beta1 parameter for the Adam optimizer.")
411 | parser.add_argument("--adam_beta2", type=float, default=0.999, help="The beta2 parameter for the Adam optimizer.")
412 | parser.add_argument("--adam_weight_decay", type=float, default=1e-2, help="Weight decay to use.")
413 | parser.add_argument("--adam_epsilon", type=float, default=1e-08, help="Epsilon value for the Adam optimizer")
414 | parser.add_argument("--max_grad_norm", default=1.0, type=float, help="Max gradient norm.")
415 | parser.add_argument("--push_to_hub", action="store_true", help="Whether or not to push the model to the Hub.")
416 | parser.add_argument("--hub_token", type=str, default=None, help="The token to use to push to the Model Hub.")
417 | parser.add_argument(
418 | "--prediction_type",
419 | type=str,
420 | default=None,
421 |         help="The prediction_type that shall be used for training. Choose between 'epsilon' or 'v_prediction' or leave `None`. If left to `None` the default prediction type of the scheduler: `noise_scheduler.config.prediction_type` is chosen.",
422 | )
423 | parser.add_argument(
424 | "--hub_model_id",
425 | type=str,
426 | default=None,
427 | help="The name of the repository to keep in sync with the local `output_dir`.",
428 | )
429 | parser.add_argument(
430 | "--logging_dir",
431 | type=str,
432 | default="logs",
433 | help=(
434 | "[TensorBoard](https://www.tensorflow.org/tensorboard) log directory. Will default to"
435 | " *output_dir/runs/**CURRENT_DATETIME_HOSTNAME***."
436 | ),
437 | )
438 | parser.add_argument(
439 | "--report_to",
440 | type=str,
441 | default="tensorboard",
442 | help=(
443 | 'The integration to report the results and logs to. Supported platforms are `"tensorboard"`'
444 | ' (default), `"wandb"` and `"comet_ml"`. Use `"all"` to report to all integrations.'
445 | ),
446 | )
447 | parser.add_argument(
448 | "--mixed_precision",
449 | type=str,
450 | default=None,
451 | choices=["no", "fp16", "bf16"],
452 | help=(
453 | "Whether to use mixed precision. Choose between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >="
454 |             " 1.10 and an Nvidia Ampere GPU. Defaults to the value of the accelerate config of the current system or the"
455 | " flag passed with the `accelerate.launch` command. Use this argument to override the accelerate config."
456 | ),
457 | )
458 | parser.add_argument("--local_rank", type=int, default=-1, help="For distributed training: local_rank")
459 | parser.add_argument(
460 | "--enable_xformers_memory_efficient_attention", action="store_true", help="Whether or not to use xformers."
461 | )
462 | parser.add_argument("--noise_offset", type=float, default=0, help="The scale of noise offset.")
463 |
464 | if input_args is not None:
465 | args = parser.parse_args(input_args)
466 | else:
467 | args = parser.parse_args()
468 |
469 | env_local_rank = int(os.environ.get("LOCAL_RANK", -1))
470 | if env_local_rank != -1 and env_local_rank != args.local_rank:
471 | args.local_rank = env_local_rank
472 |
473 | # Sanity checks
474 | if args.dataset_name is None and args.train_data_dir is None:
475 | raise ValueError("Need either a dataset name or a training folder.")
476 |
477 | if args.proportion_empty_prompts < 0 or args.proportion_empty_prompts > 1:
478 | raise ValueError("`--proportion_empty_prompts` must be in the range [0, 1].")
479 |
480 | return args
481 |
482 |
483 | # Adapted from pipelines.StableDiffusionXLPipeline.encode_prompt
484 | def encode_prompt(batch, text_encoders, tokenizers, proportion_empty_prompts, caption_column, is_train=True):
485 | prompt_embeds_list = []
486 | prompt_batch = batch[caption_column]
487 |
488 | captions = []
489 | for caption in prompt_batch:
490 | if random.random() < proportion_empty_prompts:
491 | captions.append("")
492 | elif isinstance(caption, str):
493 | captions.append(caption)
494 | elif isinstance(caption, (list, np.ndarray)):
495 | # take a random caption if there are multiple
496 | captions.append(random.choice(caption) if is_train else caption[0])
497 |
498 | with torch.no_grad():
499 | for tokenizer, text_encoder in zip(tokenizers, text_encoders):
500 | text_inputs = tokenizer(
501 | captions,
502 | padding="max_length",
503 | max_length=tokenizer.model_max_length,
504 | truncation=True,
505 | return_tensors="pt",
506 | )
507 | text_input_ids = text_inputs.input_ids
508 | prompt_embeds = text_encoder(
509 | text_input_ids.to(text_encoder.device),
510 | output_hidden_states=True,
511 | )
512 |
513 |             # We are only interested in the pooled output of the final text encoder
514 | pooled_prompt_embeds = prompt_embeds[0]
515 | prompt_embeds = prompt_embeds.hidden_states[-2]
516 | bs_embed, seq_len, _ = prompt_embeds.shape
517 | prompt_embeds = prompt_embeds.view(bs_embed, seq_len, -1)
518 | prompt_embeds_list.append(prompt_embeds)
519 |
520 | prompt_embeds = torch.concat(prompt_embeds_list, dim=-1)
521 | pooled_prompt_embeds = pooled_prompt_embeds.view(bs_embed, -1)
522 | return {"prompt_embeds": prompt_embeds.cpu(), "pooled_prompt_embeds": pooled_prompt_embeds.cpu()}
523 |
524 |
525 | def compute_vae_encodings(batch, vae):
526 | images = batch.pop("pixel_values")
527 | pixel_values = torch.stack(list(images))
528 | pixel_values = pixel_values.to(memory_format=torch.contiguous_format).float()
529 | pixel_values = pixel_values.to(vae.device, dtype=vae.dtype)
530 |
531 | with torch.no_grad():
532 | model_input = vae.encode(pixel_values).latent_dist.sample()
533 | model_input = model_input * vae.config.scaling_factor
534 | return {"model_input": model_input.cpu()}
535 |
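# Note: encode_prompt and compute_vae_encodings both return CPU tensors, which suggests they are intended to be
# applied ahead of the training loop (e.g. with datasets.map and functools.partial) so that text embeddings and
# VAE latents are precomputed once rather than recomputed at every step.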
536 |
537 | def generate_timestep_weights(args, num_timesteps):
538 | weights = torch.ones(num_timesteps)
539 |
540 | # Determine the indices to bias
541 | num_to_bias = int(args.timestep_bias_portion * num_timesteps)
542 |
543 | if args.timestep_bias_strategy == "later":
544 | bias_indices = slice(-num_to_bias, None)
545 | elif args.timestep_bias_strategy == "earlier":
546 | bias_indices = slice(0, num_to_bias)
547 | elif args.timestep_bias_strategy == "range":
548 |         # Out of the possible 1000 timesteps, we might want to focus on e.g. 200-500.
549 | range_begin = args.timestep_bias_begin
550 | range_end = args.timestep_bias_end
551 | if range_begin < 0:
552 | raise ValueError(
553 | "When using the range strategy for timestep bias, you must provide a beginning timestep greater or equal to zero."
554 | )
555 | if range_end > num_timesteps:
556 | raise ValueError(
557 | "When using the range strategy for timestep bias, you must provide an ending timestep smaller than the number of timesteps."
558 | )
559 | bias_indices = slice(range_begin, range_end)
560 | else: # 'none' or any other string
561 | return weights
562 | if args.timestep_bias_multiplier <= 0:
563 |         raise ValueError(
564 | "The parameter --timestep_bias_multiplier is not intended to be used to disable the training of specific timesteps."
565 | " If it was intended to disable timestep bias, use `--timestep_bias_strategy none` instead."
566 | " A timestep bias multiplier less than or equal to 0 is not allowed."
567 | )
568 |
569 | # Apply the bias
570 | weights[bias_indices] *= args.timestep_bias_multiplier
571 |
572 | # Normalize
573 | weights /= weights.sum()
574 |
575 | return weights
576 |
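# Illustrative usage sketch (an assumption about how the weights are consumed, not a verbatim copy of the training
# loop below): the returned probabilities can be passed to torch.multinomial to draw biased timesteps, e.g.
#   weights = generate_timestep_weights(args, noise_scheduler.config.num_train_timesteps).to(latents.device)
#   timesteps = torch.multinomial(weights, bsz, replacement=True).long()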
577 |
578 | def main(args):
579 | logging_dir = Path(args.output_dir, args.logging_dir)
580 |
581 | accelerator_project_config = ProjectConfiguration(project_dir=args.output_dir, logging_dir=logging_dir)
582 |
583 | accelerator = Accelerator(
584 | gradient_accumulation_steps=args.gradient_accumulation_steps,
585 | mixed_precision=args.mixed_precision,
586 | log_with=args.report_to,
587 | project_config=accelerator_project_config,
588 | )
589 |
590 | if args.report_to == "wandb":
591 | if not is_wandb_available():
592 | raise ImportError("Make sure to install wandb if you want to use it for logging during training.")
593 | import wandb
594 |
595 | # Make one log on every process with the configuration for debugging.
596 | logging.basicConfig(
597 | format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
598 | datefmt="%m/%d/%Y %H:%M:%S",
599 | level=logging.INFO,
600 | )
601 | logger.info(accelerator.state, main_process_only=False)
602 | if accelerator.is_local_main_process:
603 | datasets.utils.logging.set_verbosity_warning()
604 | transformers.utils.logging.set_verbosity_warning()
605 | diffusers.utils.logging.set_verbosity_info()
606 | else:
607 | datasets.utils.logging.set_verbosity_error()
608 | transformers.utils.logging.set_verbosity_error()
609 | diffusers.utils.logging.set_verbosity_error()
610 |
611 | # If passed along, set the training seed now.
612 | if args.seed is not None:
613 | set_seed(args.seed)
614 |
615 | # Handle the repository creation
616 | if accelerator.is_main_process:
617 | if args.output_dir is not None:
618 | os.makedirs(args.output_dir, exist_ok=True)
619 |
620 | if args.push_to_hub:
621 | repo_id = create_repo(
622 | repo_id=args.hub_model_id or Path(args.output_dir).name, exist_ok=True, token=args.hub_token
623 | ).repo_id
624 |
625 | # Load the tokenizers
626 | tokenizer_one = AutoTokenizer.from_pretrained(
627 | args.pretrained_model_name_or_path,
628 | subfolder="tokenizer",
629 | revision=args.revision,
630 | use_fast=False,
631 | )
632 | tokenizer_two = AutoTokenizer.from_pretrained(
633 | args.pretrained_model_name_or_path,
634 | subfolder="tokenizer_2",
635 | revision=args.revision,
636 | use_fast=False,
637 | )
638 |
639 | # import correct text encoder classes
640 | text_encoder_cls_one = import_model_class_from_model_name_or_path(
641 | args.pretrained_model_name_or_path, args.revision
642 | )
643 | text_encoder_cls_two = import_model_class_from_model_name_or_path(
644 | args.pretrained_model_name_or_path, args.revision, subfolder="text_encoder_2"
645 | )
646 |
647 | # Load scheduler and models
648 | noise_scheduler = DDPMScheduler.from_pretrained(args.pretrained_model_name_or_path, subfolder="scheduler")
649 | # Check for terminal SNR in combination with SNR Gamma
650 | text_encoder_one = text_encoder_cls_one.from_pretrained(
651 | args.pretrained_model_name_or_path, subfolder="text_encoder", revision=args.revision, variant=args.variant
652 | )
653 | text_encoder_two = text_encoder_cls_two.from_pretrained(
654 | args.pretrained_model_name_or_path, subfolder="text_encoder_2", revision=args.revision, variant=args.variant
655 | )
656 | vae_path = (
657 | args.pretrained_model_name_or_path
658 | if args.pretrained_vae_model_name_or_path is None
659 | else args.pretrained_vae_model_name_or_path
660 | )
661 | vae = AutoencoderKL.from_pretrained(
662 | vae_path,
663 | subfolder="vae" if args.pretrained_vae_model_name_or_path is None else None,
664 | revision=args.revision,
665 | variant=args.variant,
666 | )
667 | unet = UNet2DConditionModel.from_pretrained(
668 | args.pretrained_model_name_or_path, subfolder="unet", revision=args.revision, variant=args.variant
669 | )
670 |
671 | # Freeze vae and text encoders.
672 | vae.requires_grad_(False)
673 | text_encoder_one.requires_grad_(False)
674 | text_encoder_two.requires_grad_(False)
675 | # Set unet as trainable.
676 | unet.train()
677 |
678 |     # For mixed precision training we cast all non-trainable weights to half-precision
679 | # as these weights are only used for inference, keeping weights in full precision is not required.
680 | weight_dtype = torch.float32
681 | if accelerator.mixed_precision == "fp16":
682 | weight_dtype = torch.float16
683 | elif accelerator.mixed_precision == "bf16":
684 | weight_dtype = torch.bfloat16
685 |
686 | # Move unet, vae and text_encoder to device and cast to weight_dtype
687 | # The VAE is in float32 to avoid NaN losses.
688 | vae.to(accelerator.device, dtype=torch.float32)
689 | text_encoder_one.to(accelerator.device, dtype=weight_dtype)
690 | text_encoder_two.to(accelerator.device, dtype=weight_dtype)
691 |
692 | # Create EMA for the unet.
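    # (The EMA copy is swapped in for validation and for the final exported pipeline further below.)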
693 | if args.use_ema:
694 | ema_unet = UNet2DConditionModel.from_pretrained(
695 | args.pretrained_model_name_or_path, subfolder="unet", revision=args.revision, variant=args.variant
696 | )
697 | ema_unet = EMAModel(ema_unet.parameters(), model_cls=UNet2DConditionModel, model_config=ema_unet.config)
698 |
699 | if args.enable_xformers_memory_efficient_attention:
700 | if is_xformers_available():
701 | import xformers
702 |
703 | xformers_version = version.parse(xformers.__version__)
704 | if xformers_version == version.parse("0.0.16"):
705 | logger.warn(
706 | "xFormers 0.0.16 cannot be used for training in some GPUs. If you observe problems during training, please update xFormers to at least 0.0.17. See https://huggingface.co/docs/diffusers/main/en/optimization/xformers for more details."
707 | )
708 | unet.enable_xformers_memory_efficient_attention()
709 | else:
710 | raise ValueError("xformers is not available. Make sure it is installed correctly")
711 |
712 | # `accelerate` 0.16.0 will have better support for customized saving
713 | if version.parse(accelerate.__version__) >= version.parse("0.16.0"):
714 | # create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format
715 | def save_model_hook(models, weights, output_dir):
716 | if accelerator.is_main_process:
717 | if args.use_ema:
718 | ema_unet.save_pretrained(os.path.join(output_dir, "unet_ema"))
719 |
720 | for i, model in enumerate(models):
721 | model.save_pretrained(os.path.join(output_dir, "unet"))
722 |
723 | # make sure to pop weight so that corresponding model is not saved again
724 | weights.pop()
725 |
726 | def load_model_hook(models, input_dir):
727 | if args.use_ema:
728 | load_model = EMAModel.from_pretrained(os.path.join(input_dir, "unet_ema"), UNet2DConditionModel)
729 | ema_unet.load_state_dict(load_model.state_dict())
730 | ema_unet.to(accelerator.device)
731 | del load_model
732 |
733 | for i in range(len(models)):
734 | # pop models so that they are not loaded again
735 | model = models.pop()
736 |
737 | # load diffusers style into model
738 | load_model = UNet2DConditionModel.from_pretrained(input_dir, subfolder="unet")
739 | model.register_to_config(**load_model.config)
740 |
741 | model.load_state_dict(load_model.state_dict())
742 | del load_model
743 |
744 | accelerator.register_save_state_pre_hook(save_model_hook)
745 | accelerator.register_load_state_pre_hook(load_model_hook)
746 |
747 | if args.gradient_checkpointing:
748 | unet.enable_gradient_checkpointing()
749 |
750 | # Enable TF32 for faster training on Ampere GPUs,
751 | # cf https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices
752 | if args.allow_tf32:
753 | torch.backends.cuda.matmul.allow_tf32 = True
754 |
755 | if args.scale_lr:
756 | args.learning_rate = (
757 | args.learning_rate * args.gradient_accumulation_steps * args.train_batch_size * accelerator.num_processes
758 | )
759 |
760 |     # Use 8-bit Adam for lower memory usage or to fine-tune the model on 16GB GPUs
761 | if args.use_8bit_adam:
762 | try:
763 | import bitsandbytes as bnb
764 | except ImportError:
765 | raise ImportError(
766 | "To use 8-bit Adam, please install the bitsandbytes library: `pip install bitsandbytes`."
767 | )
768 |
769 | optimizer_class = bnb.optim.AdamW8bit
770 | else:
771 | optimizer_class = torch.optim.AdamW
772 |
773 | # Optimizer creation
774 | params_to_optimize = unet.parameters()
775 | optimizer = optimizer_class(
776 | params_to_optimize,
777 | lr=args.learning_rate,
778 | betas=(args.adam_beta1, args.adam_beta2),
779 | weight_decay=args.adam_weight_decay,
780 | eps=args.adam_epsilon,
781 | )
782 |
783 | # Get the datasets: you can either provide your own training and evaluation files (see below)
784 | # or specify a Dataset from the hub (the dataset will be downloaded automatically from the datasets Hub).
785 |
786 | # In distributed training, the load_dataset function guarantees that only one local process can concurrently
787 | # download the dataset.
788 | if args.dataset_name is not None:
789 | # Downloading and loading a dataset from the hub.
790 | dataset = load_dataset(
791 | args.dataset_name,
792 | args.dataset_config_name,
793 | cache_dir=args.cache_dir,
794 | )
795 | else:
796 | data_files = {}
797 | if args.train_data_dir is not None:
798 | data_files["train"] = os.path.join(args.train_data_dir, "**")
799 | dataset = load_dataset(
800 | "imagefolder",
801 | data_files=data_files,
802 | cache_dir=args.cache_dir,
803 | )
804 | # See more about loading custom images at
805 | # https://huggingface.co/docs/datasets/v2.4.0/en/image_load#imagefolder
806 |
807 | # Preprocessing the datasets.
808 | # We need to tokenize inputs and targets.
809 | column_names = dataset["train"].column_names
810 |
811 | # 6. Get the column names for input/target.
812 | dataset_columns = DATASET_NAME_MAPPING.get(args.dataset_name, None)
813 | if args.image_column is None:
814 | image_column = dataset_columns[0] if dataset_columns is not None else column_names[0]
815 | else:
816 | image_column = args.image_column
817 | if image_column not in column_names:
818 | raise ValueError(
819 |             f"'--image_column' value '{args.image_column}' needs to be one of: {', '.join(column_names)}"
820 | )
821 | if args.caption_column is None:
822 | caption_column = dataset_columns[1] if dataset_columns is not None else column_names[1]
823 | else:
824 | caption_column = args.caption_column
825 | if caption_column not in column_names:
826 | raise ValueError(
827 |             f"'--caption_column' value '{args.caption_column}' needs to be one of: {', '.join(column_names)}"
828 | )
829 |
830 | # Preprocessing the datasets.
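    # ToTensor maps pixel values into [0, 1]; Normalize([0.5], [0.5]) then rescales them to [-1, 1], the input range used by the SD/SDXL VAE.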
831 | train_resize = transforms.Resize(args.resolution, interpolation=transforms.InterpolationMode.BILINEAR)
832 | train_crop = transforms.CenterCrop(args.resolution) if args.center_crop else transforms.RandomCrop(args.resolution)
833 | train_flip = transforms.RandomHorizontalFlip(p=1.0)
834 | train_transforms = transforms.Compose([transforms.ToTensor(), transforms.Normalize([0.5], [0.5])])
835 |
836 | def preprocess_train(examples):
837 | images = [image.convert("RGB") for image in examples[image_column]]
838 | # image aug
839 | original_sizes = []
840 | all_images = []
841 | crop_top_lefts = []
842 | for image in images:
843 | original_sizes.append((image.height, image.width))
844 | image = train_resize(image)
845 | if args.center_crop:
846 | y1 = max(0, int(round((image.height - args.resolution) / 2.0)))
847 | x1 = max(0, int(round((image.width - args.resolution) / 2.0)))
848 | image = train_crop(image)
849 | else:
850 | y1, x1, h, w = train_crop.get_params(image, (args.resolution, args.resolution))
851 | image = crop(image, y1, x1, h, w)
852 | if args.random_flip and random.random() < 0.5:
853 | # flip
854 | x1 = image.width - x1
855 | image = train_flip(image)
856 | crop_top_left = (y1, x1)
857 | crop_top_lefts.append(crop_top_left)
858 | image = train_transforms(image)
859 | all_images.append(image)
860 |
861 | examples["original_sizes"] = original_sizes
862 | examples["crop_top_lefts"] = crop_top_lefts
863 | examples["pixel_values"] = all_images
864 | return examples
865 |
866 | with accelerator.main_process_first():
867 | if args.max_train_samples is not None:
868 | dataset["train"] = dataset["train"].shuffle(seed=args.seed).select(range(args.max_train_samples))
869 | # Set the training transforms
870 | train_dataset = dataset["train"].with_transform(preprocess_train)
871 |
872 | # Let's first compute all the embeddings so that we can free up the text encoders
873 | # from memory. We will pre-compute the VAE encodings too.
874 | text_encoders = [text_encoder_one, text_encoder_two]
875 | tokenizers = [tokenizer_one, tokenizer_two]
876 | compute_embeddings_fn = functools.partial(
877 | encode_prompt,
878 | text_encoders=text_encoders,
879 | tokenizers=tokenizers,
880 | proportion_empty_prompts=args.proportion_empty_prompts,
881 | caption_column=args.caption_column,
882 | )
883 | compute_vae_encodings_fn = functools.partial(compute_vae_encodings, vae=vae)
884 | with accelerator.main_process_first():
885 | from datasets.fingerprint import Hasher
886 |
887 | # fingerprint used by the cache for the other processes to load the result
888 | # details: https://github.com/huggingface/diffusers/pull/4038#discussion_r1266078401
889 | new_fingerprint = Hasher.hash(args)
890 | new_fingerprint_for_vae = Hasher.hash("vae")
891 | train_dataset = train_dataset.map(compute_embeddings_fn, batched=True, new_fingerprint=new_fingerprint)
892 | train_dataset = train_dataset.map(
893 | compute_vae_encodings_fn,
894 | batched=True,
895 | batch_size=args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps,
896 | new_fingerprint=new_fingerprint_for_vae,
897 | )
898 |
899 | del text_encoders, tokenizers, vae
900 | gc.collect()
901 | torch.cuda.empty_cache()
902 |
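    # Because prompt embeddings and VAE latents were precomputed above, each batch carries
    # latents ("model_input") and text embeddings rather than raw images and captions.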
903 | def collate_fn(examples):
904 | model_input = torch.stack([torch.tensor(example["model_input"]) for example in examples])
905 | original_sizes = [example["original_sizes"] for example in examples]
906 | crop_top_lefts = [example["crop_top_lefts"] for example in examples]
907 | prompt_embeds = torch.stack([torch.tensor(example["prompt_embeds"]) for example in examples])
908 | pooled_prompt_embeds = torch.stack([torch.tensor(example["pooled_prompt_embeds"]) for example in examples])
909 |
910 | return {
911 | "model_input": model_input,
912 | "prompt_embeds": prompt_embeds,
913 | "pooled_prompt_embeds": pooled_prompt_embeds,
914 | "original_sizes": original_sizes,
915 | "crop_top_lefts": crop_top_lefts,
916 | }
917 |
918 | # DataLoaders creation:
919 | train_dataloader = torch.utils.data.DataLoader(
920 | train_dataset,
921 | shuffle=True,
922 | collate_fn=collate_fn,
923 | batch_size=args.train_batch_size,
924 | num_workers=args.dataloader_num_workers,
925 | )
926 |
927 | # Scheduler and math around the number of training steps.
928 | overrode_max_train_steps = False
929 | num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
930 | if args.max_train_steps is None:
931 | args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
932 | overrode_max_train_steps = True
933 |
934 | lr_scheduler = get_scheduler(
935 | args.lr_scheduler,
936 | optimizer=optimizer,
937 | num_warmup_steps=args.lr_warmup_steps * args.gradient_accumulation_steps,
938 | num_training_steps=args.max_train_steps * args.gradient_accumulation_steps,
939 | )
940 |
941 | # Prepare everything with our `accelerator`.
942 | unet, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
943 | unet, optimizer, train_dataloader, lr_scheduler
944 | )
945 |
946 | # We need to recalculate our total training steps as the size of the training dataloader may have changed.
947 | num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
948 | if overrode_max_train_steps:
949 | args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
950 | # Afterwards we recalculate our number of training epochs
951 | args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)
952 |
953 | # We need to initialize the trackers we use, and also store our configuration.
954 |     # The trackers initialize automatically on the main process.
955 | if accelerator.is_main_process:
956 | accelerator.init_trackers("text2image-fine-tune-sdxl", config=vars(args))
957 |
958 | # Train!
959 | total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps
960 |
961 | logger.info("***** Running training *****")
962 | logger.info(f" Num examples = {len(train_dataset)}")
963 | logger.info(f" Num Epochs = {args.num_train_epochs}")
964 | logger.info(f" Instantaneous batch size per device = {args.train_batch_size}")
965 | logger.info(f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}")
966 | logger.info(f" Gradient Accumulation steps = {args.gradient_accumulation_steps}")
967 | logger.info(f" Total optimization steps = {args.max_train_steps}")
968 | global_step = 0
969 | first_epoch = 0
970 |
971 | # Potentially load in the weights and states from a previous save
972 | if args.resume_from_checkpoint:
973 | if args.resume_from_checkpoint != "latest":
974 | path = os.path.basename(args.resume_from_checkpoint)
975 | else:
976 | # Get the most recent checkpoint
977 | dirs = os.listdir(args.output_dir)
978 | dirs = [d for d in dirs if d.startswith("checkpoint")]
979 | dirs = sorted(dirs, key=lambda x: int(x.split("-")[1]))
980 | path = dirs[-1] if len(dirs) > 0 else None
981 |
982 | if path is None:
983 | accelerator.print(
984 | f"Checkpoint '{args.resume_from_checkpoint}' does not exist. Starting a new training run."
985 | )
986 | args.resume_from_checkpoint = None
987 | initial_global_step = 0
988 | else:
989 | accelerator.print(f"Resuming from checkpoint {path}")
990 | accelerator.load_state(os.path.join(args.output_dir, path))
991 | global_step = int(path.split("-")[1])
992 |
993 | initial_global_step = global_step
994 | first_epoch = global_step // num_update_steps_per_epoch
995 |
996 | else:
997 | initial_global_step = 0
998 |
999 | progress_bar = tqdm(
1000 | range(0, args.max_train_steps),
1001 | initial=initial_global_step,
1002 | desc="Steps",
1003 | # Only show the progress bar once on each machine.
1004 | disable=not accelerator.is_local_main_process,
1005 | )
1006 |
1007 | for epoch in range(first_epoch, args.num_train_epochs):
1008 | train_loss = 0.0
1009 | for step, batch in enumerate(train_dataloader):
1010 | with accelerator.accumulate(unet):
1011 | # Sample noise that we'll add to the latents
1012 | model_input = batch["model_input"].to(accelerator.device)
1013 | noise = torch.randn_like(model_input)
1014 | if args.noise_offset:
1015 | # https://www.crosslabs.org//blog/diffusion-with-offset-noise
1016 | noise += args.noise_offset * torch.randn(
1017 | (model_input.shape[0], model_input.shape[1], 1, 1), device=model_input.device
1018 | )
1019 |
1020 | bsz = model_input.shape[0]
1021 | if args.timestep_bias_strategy == "none":
1022 | # Sample a random timestep for each image without bias.
1023 | timesteps = torch.randint(
1024 | 0, noise_scheduler.config.num_train_timesteps, (bsz,), device=model_input.device
1025 | )
1026 | else:
1027 | # Sample a random timestep for each image, potentially biased by the timestep weights.
1028 | # Biasing the timestep weights allows us to spend less time training irrelevant timesteps.
1029 | weights = generate_timestep_weights(args, noise_scheduler.config.num_train_timesteps).to(
1030 | model_input.device
1031 | )
1032 | timesteps = torch.multinomial(weights, bsz, replacement=True).long()
1033 |
1034 | # Add noise to the model input according to the noise magnitude at each timestep
1035 | # (this is the forward diffusion process)
1036 | noisy_model_input = noise_scheduler.add_noise(model_input, noise, timesteps)
1037 |
1038 | # time ids
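                    # SDXL micro-conditioning: six integers per sample, (original_h, original_w, crop_top, crop_left, target_h, target_w).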
1039 | def compute_time_ids(original_size, crops_coords_top_left):
1040 | # Adapted from pipeline.StableDiffusionXLPipeline._get_add_time_ids
1041 | target_size = (args.resolution, args.resolution)
1042 | add_time_ids = list(original_size + crops_coords_top_left + target_size)
1043 | add_time_ids = torch.tensor([add_time_ids])
1044 | add_time_ids = add_time_ids.to(accelerator.device, dtype=weight_dtype)
1045 | return add_time_ids
1046 |
1047 | add_time_ids = torch.cat(
1048 | [compute_time_ids(s, c) for s, c in zip(batch["original_sizes"], batch["crop_top_lefts"])]
1049 | )
1050 |
1051 | # Predict the noise residual
1052 | unet_added_conditions = {"time_ids": add_time_ids}
1053 | prompt_embeds = batch["prompt_embeds"].to(accelerator.device)
1054 | pooled_prompt_embeds = batch["pooled_prompt_embeds"].to(accelerator.device)
1055 | unet_added_conditions.update({"text_embeds": pooled_prompt_embeds})
1056 | model_pred = unet(
1057 | noisy_model_input, timesteps, prompt_embeds, added_cond_kwargs=unet_added_conditions
1058 | ).sample
1059 |
1060 | # Get the target for loss depending on the prediction type
1061 | if args.prediction_type is not None:
1062 | # set prediction_type of scheduler if defined
1063 | noise_scheduler.register_to_config(prediction_type=args.prediction_type)
1064 |
1065 | if noise_scheduler.config.prediction_type == "epsilon":
1066 | target = noise
1067 | elif noise_scheduler.config.prediction_type == "v_prediction":
1068 | target = noise_scheduler.get_velocity(model_input, noise, timesteps)
1069 | elif noise_scheduler.config.prediction_type == "sample":
1070 | # We set the target to latents here, but the model_pred will return the noise sample prediction.
1071 | target = model_input
1072 | # We will have to subtract the noise residual from the prediction to get the target sample.
1073 | model_pred = model_pred - noise
1074 | else:
1075 | raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}")
1076 |
1077 | if args.snr_gamma is None:
1078 | loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean")
1079 | else:
1080 | # Compute loss-weights as per Section 3.4 of https://arxiv.org/abs/2303.09556.
1081 | # Since we predict the noise instead of x_0, the original formulation is slightly changed.
1082 | # This is discussed in Section 4.2 of the same paper.
1083 | snr = compute_snr(noise_scheduler, timesteps)
1084 | if noise_scheduler.config.prediction_type == "v_prediction":
1085 | # Velocity objective requires that we add one to SNR values before we divide by them.
1086 | snr = snr + 1
1087 | mse_loss_weights = (
1088 | torch.stack([snr, args.snr_gamma * torch.ones_like(timesteps)], dim=1).min(dim=1)[0] / snr
1089 | )
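                    # i.e. each sample is weighted by min(SNR_t, snr_gamma) / SNR_t, down-weighting
                    # high-SNR (low-noise) timesteps so they do not dominate the loss (Min-SNR weighting).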
1090 |
1091 | loss = F.mse_loss(model_pred.float(), target.float(), reduction="none")
1092 | loss = loss.mean(dim=list(range(1, len(loss.shape)))) * mse_loss_weights
1093 | loss = loss.mean()
1094 |
1095 | # Gather the losses across all processes for logging (if we use distributed training).
1096 | avg_loss = accelerator.gather(loss.repeat(args.train_batch_size)).mean()
1097 | train_loss += avg_loss.item() / args.gradient_accumulation_steps
1098 |
1099 | # Backpropagate
1100 | accelerator.backward(loss)
1101 | if accelerator.sync_gradients:
1102 | params_to_clip = unet.parameters()
1103 | accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm)
1104 | optimizer.step()
1105 | lr_scheduler.step()
1106 | optimizer.zero_grad()
1107 |
1108 | # Checks if the accelerator has performed an optimization step behind the scenes
1109 | if accelerator.sync_gradients:
1110 | progress_bar.update(1)
1111 | global_step += 1
1112 | accelerator.log({"train_loss": train_loss}, step=global_step)
1113 | train_loss = 0.0
1114 |
1115 | if accelerator.is_main_process:
1116 | if global_step % args.checkpointing_steps == 0:
1117 | # _before_ saving state, check if this save would set us over the `checkpoints_total_limit`
1118 | if args.checkpoints_total_limit is not None:
1119 | checkpoints = os.listdir(args.output_dir)
1120 | checkpoints = [d for d in checkpoints if d.startswith("checkpoint")]
1121 | checkpoints = sorted(checkpoints, key=lambda x: int(x.split("-")[1]))
1122 |
1123 | # before we save the new checkpoint, we need to have at _most_ `checkpoints_total_limit - 1` checkpoints
1124 | if len(checkpoints) >= args.checkpoints_total_limit:
1125 | num_to_remove = len(checkpoints) - args.checkpoints_total_limit + 1
1126 | removing_checkpoints = checkpoints[0:num_to_remove]
1127 |
1128 | logger.info(
1129 | f"{len(checkpoints)} checkpoints already exist, removing {len(removing_checkpoints)} checkpoints"
1130 | )
1131 | logger.info(f"removing checkpoints: {', '.join(removing_checkpoints)}")
1132 |
1133 | for removing_checkpoint in removing_checkpoints:
1134 | removing_checkpoint = os.path.join(args.output_dir, removing_checkpoint)
1135 | shutil.rmtree(removing_checkpoint)
1136 |
1137 | save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}")
1138 | accelerator.save_state(save_path)
1139 | logger.info(f"Saved state to {save_path}")
1140 |
1141 | logs = {"step_loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0]}
1142 | progress_bar.set_postfix(**logs)
1143 |
1144 | if global_step >= args.max_train_steps:
1145 | break
1146 |
1147 | if accelerator.is_main_process:
1148 | if args.validation_prompt is not None and epoch % args.validation_epochs == 0:
1149 | logger.info(
1150 | f"Running validation... \n Generating {args.num_validation_images} images with prompt:"
1151 | f" {args.validation_prompt}."
1152 | )
1153 | if args.use_ema:
1154 | # Store the UNet parameters temporarily and load the EMA parameters to perform inference.
1155 | ema_unet.store(unet.parameters())
1156 | ema_unet.copy_to(unet.parameters())
1157 |
1158 | # create pipeline
1159 | vae = AutoencoderKL.from_pretrained(
1160 | vae_path,
1161 | subfolder="vae" if args.pretrained_vae_model_name_or_path is None else None,
1162 | revision=args.revision,
1163 | variant=args.variant,
1164 | )
1165 | pipeline = StableDiffusionXLPipeline.from_pretrained(
1166 | args.pretrained_model_name_or_path,
1167 | vae=vae,
1168 | unet=accelerator.unwrap_model(unet),
1169 | revision=args.revision,
1170 | variant=args.variant,
1171 | torch_dtype=weight_dtype,
1172 | )
1173 | if args.prediction_type is not None:
1174 | scheduler_args = {"prediction_type": args.prediction_type}
1175 | pipeline.scheduler = pipeline.scheduler.from_config(pipeline.scheduler.config, **scheduler_args)
1176 |
1177 | pipeline = pipeline.to(accelerator.device)
1178 | pipeline.set_progress_bar_config(disable=True)
1179 |
1180 | # run inference
1181 | generator = torch.Generator(device=accelerator.device).manual_seed(args.seed) if args.seed else None
1182 | pipeline_args = {"prompt": args.validation_prompt}
1183 |
1184 | with torch.cuda.amp.autocast():
1185 | images = [
1186 | pipeline(**pipeline_args, generator=generator, num_inference_steps=25).images[0]
1187 | for _ in range(args.num_validation_images)
1188 | ]
1189 |
1190 | for tracker in accelerator.trackers:
1191 | if tracker.name == "tensorboard":
1192 | np_images = np.stack([np.asarray(img) for img in images])
1193 | tracker.writer.add_images("validation", np_images, epoch, dataformats="NHWC")
1194 | if tracker.name == "wandb":
1195 | tracker.log(
1196 | {
1197 | "validation": [
1198 | wandb.Image(image, caption=f"{i}: {args.validation_prompt}")
1199 | for i, image in enumerate(images)
1200 | ]
1201 | }
1202 | )
1203 |
1204 | del pipeline
1205 | torch.cuda.empty_cache()
1206 |
1207 | accelerator.wait_for_everyone()
1208 | if accelerator.is_main_process:
1209 | unet = accelerator.unwrap_model(unet)
1210 | if args.use_ema:
1211 | ema_unet.copy_to(unet.parameters())
1212 |
1213 | # Serialize pipeline.
1214 | vae = AutoencoderKL.from_pretrained(
1215 | vae_path,
1216 | subfolder="vae" if args.pretrained_vae_model_name_or_path is None else None,
1217 | revision=args.revision,
1218 | variant=args.variant,
1219 | torch_dtype=weight_dtype,
1220 | )
1221 | pipeline = StableDiffusionXLPipeline.from_pretrained(
1222 | args.pretrained_model_name_or_path,
1223 | unet=unet,
1224 | vae=vae,
1225 | revision=args.revision,
1226 | variant=args.variant,
1227 | torch_dtype=weight_dtype,
1228 | )
1229 | if args.prediction_type is not None:
1230 | scheduler_args = {"prediction_type": args.prediction_type}
1231 | pipeline.scheduler = pipeline.scheduler.from_config(pipeline.scheduler.config, **scheduler_args)
1232 | pipeline.save_pretrained(args.output_dir)
1233 |
1234 | # run inference
1235 | images = []
1236 | if args.validation_prompt and args.num_validation_images > 0:
1237 | pipeline = pipeline.to(accelerator.device)
1238 | generator = torch.Generator(device=accelerator.device).manual_seed(args.seed) if args.seed else None
1239 | with torch.cuda.amp.autocast():
1240 | images = [
1241 | pipeline(args.validation_prompt, num_inference_steps=25, generator=generator).images[0]
1242 | for _ in range(args.num_validation_images)
1243 | ]
1244 |
1245 | for tracker in accelerator.trackers:
1246 | if tracker.name == "tensorboard":
1247 | np_images = np.stack([np.asarray(img) for img in images])
1248 | tracker.writer.add_images("test", np_images, epoch, dataformats="NHWC")
1249 | if tracker.name == "wandb":
1250 | tracker.log(
1251 | {
1252 | "test": [
1253 | wandb.Image(image, caption=f"{i}: {args.validation_prompt}")
1254 | for i, image in enumerate(images)
1255 | ]
1256 | }
1257 | )
1258 |
1259 | if args.push_to_hub:
1260 | save_model_card(
1261 | repo_id=repo_id,
1262 | images=images,
1263 | validation_prompt=args.validation_prompt,
1264 | base_model=args.pretrained_model_name_or_path,
1265 | dataset_name=args.dataset_name,
1266 | repo_folder=args.output_dir,
1267 | vae_path=args.pretrained_vae_model_name_or_path,
1268 | )
1269 | upload_folder(
1270 | repo_id=repo_id,
1271 | folder_path=args.output_dir,
1272 | commit_message="End of training",
1273 | ignore_patterns=["step_*", "epoch_*"],
1274 | )
1275 |
1276 | accelerator.end_training()
1277 |
1278 |
1279 | if __name__ == "__main__":
1280 | args = parse_args()
1281 | main(args)
1282 |
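# Example launch (a minimal sketch, not part of the original script): the flag names below mirror
# the `args.*` attributes used in this script and its argparse definitions; the model ID, data path
# and hyperparameters are placeholders to adapt to your own setup.
#
#   accelerate launch train_text_to_image_sdxl.py \
#     --pretrained_model_name_or_path=stabilityai/stable-diffusion-xl-base-1.0 \
#     --train_data_dir=/path/to/images \
#     --resolution=1024 \
#     --train_batch_size=1 \
#     --gradient_accumulation_steps=4 \
#     --learning_rate=1e-5 \
#     --max_train_steps=10000 \
#     --checkpointing_steps=500 \
#     --output_dir=sdxl-finetune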
--------------------------------------------------------------------------------
/code/others/train_dreambooth.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | # coding=utf-8
3 | # Copyright 2023 The HuggingFace Inc. team. All rights reserved.
4 | #
5 | # Licensed under the Apache License, Version 2.0 (the "License");
6 | # you may not use this file except in compliance with the License.
7 | # You may obtain a copy of the License at
8 | #
9 | # http://www.apache.org/licenses/LICENSE-2.0
10 | #
11 | # Unless required by applicable law or agreed to in writing, software
12 | # distributed under the License is distributed on an "AS IS" BASIS,
13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 | # See the License for the specific language governing permissions and limitations under the License.
15 |
16 | import argparse
17 | import copy
18 | import gc
19 | import importlib
20 | import itertools
21 | import logging
22 | import math
23 | import os
24 | import shutil
25 | import warnings
26 | from pathlib import Path
27 | import json
28 | import re
29 |
30 | import numpy as np
31 | import torch
32 | import torch.nn.functional as F
33 | import torch.utils.checkpoint
34 | import transformers
35 | from accelerate import Accelerator
36 | from accelerate.logging import get_logger
37 | from accelerate.utils import ProjectConfiguration, set_seed
38 | from huggingface_hub import create_repo, model_info, upload_folder
39 | from huggingface_hub.utils import insecure_hashlib
40 | from packaging import version
41 | from PIL import Image
42 | from PIL.ImageOps import exif_transpose
43 | from torch.utils.data import Dataset
44 | from torchvision import transforms
45 | from tqdm.auto import tqdm
46 | from transformers import AutoTokenizer, PretrainedConfig, CLIPProcessor, CLIPModel
47 |
48 | import diffusers
49 | from diffusers import (
50 | AutoencoderKL,
51 | DDPMScheduler,
52 | DiffusionPipeline,
53 | StableDiffusionPipeline,
54 | UNet2DConditionModel,
55 | )
56 | from diffusers.optimization import get_scheduler
57 | from diffusers.training_utils import compute_snr
58 | from diffusers.utils import check_min_version, is_wandb_available
59 | from diffusers.utils.import_utils import is_xformers_available
60 |
61 | from clip_score import get_score
62 |
63 |
64 | if is_wandb_available():
65 | import wandb
66 |
67 | # Will error if the minimal version of diffusers is not installed. Remove at your own risk.
68 | check_min_version("0.25.0.dev0")
69 |
70 | logger = get_logger(__name__)
71 |
72 |
73 | def save_model_card(
74 | repo_id: str,
75 | images=None,
76 | base_model=str,
77 | train_text_encoder=False,
78 | prompt=str,
79 | repo_folder=None,
80 | pipeline: DiffusionPipeline = None,
81 | ):
82 | img_str = ""
83 | for i, image in enumerate(images):
84 | image.save(os.path.join(repo_folder, f"image_{i}.png"))
85 |         img_str += f"![img_{i}](./image_{i}.png)\n"
86 |
87 | yaml = f"""
88 | ---
89 | license: creativeml-openrail-m
90 | base_model: {base_model}
91 | instance_prompt: {prompt}
92 | tags:
93 | - {'stable-diffusion' if isinstance(pipeline, StableDiffusionPipeline) else 'if'}
94 | - {'stable-diffusion-diffusers' if isinstance(pipeline, StableDiffusionPipeline) else 'if-diffusers'}
95 | - text-to-image
96 | - diffusers
97 | - dreambooth
98 | inference: true
99 | ---
100 | """
101 | model_card = f"""
102 | # DreamBooth - {repo_id}
103 |
104 | This is a dreambooth model derived from {base_model}. The weights were trained on {prompt} using [DreamBooth](https://dreambooth.github.io/).
105 | You can find some example images in the following. \n
106 | {img_str}
107 |
108 | DreamBooth for the text encoder was enabled: {train_text_encoder}.
109 | """
110 | with open(os.path.join(repo_folder, "README.md"), "w") as f:
111 | f.write(yaml + model_card)
112 |
113 |
114 | def log_validation(
115 | text_encoder,
116 | tokenizer,
117 | unet,
118 | vae,
119 | args,
120 | accelerator,
121 | weight_dtype,
122 | global_step,
123 | prompt_embeds,
124 | negative_prompt_embeds,
125 | ):
126 | logger.info(
127 | f"Running validation... \n Generating {args.num_validation_images} images with prompt:"
128 | f" {args.validation_prompt}."
129 | )
130 |
131 | pipeline_args = {}
132 |
133 | if vae is not None:
134 | pipeline_args["vae"] = vae
135 |
136 | if text_encoder is not None:
137 | text_encoder = accelerator.unwrap_model(text_encoder)
138 |
139 | # create pipeline (note: unet and vae are loaded again in float32)
140 | pipeline = DiffusionPipeline.from_pretrained(
141 | args.pretrained_model_name_or_path,
142 | tokenizer=tokenizer,
143 | text_encoder=text_encoder,
144 | unet=accelerator.unwrap_model(unet),
145 | revision=args.revision,
146 | variant=args.variant,
147 | torch_dtype=weight_dtype,
148 | **pipeline_args,
149 | )
150 |
151 | # We train on the simplified learning objective. If we were previously predicting a variance, we need the scheduler to ignore it
152 | scheduler_args = {}
153 |
154 | if "variance_type" in pipeline.scheduler.config:
155 | variance_type = pipeline.scheduler.config.variance_type
156 |
157 | if variance_type in ["learned", "learned_range"]:
158 | variance_type = "fixed_small"
159 |
160 | scheduler_args["variance_type"] = variance_type
161 |
162 | module = importlib.import_module("diffusers")
163 | scheduler_class = getattr(module, args.validation_scheduler)
164 | pipeline.scheduler = scheduler_class.from_config(pipeline.scheduler.config, **scheduler_args)
165 | pipeline = pipeline.to(accelerator.device)
166 | pipeline.set_progress_bar_config(disable=True)
167 |
168 | if args.pre_compute_text_embeddings:
169 | pipeline_args = {
170 | "prompt_embeds": prompt_embeds,
171 | "negative_prompt_embeds": negative_prompt_embeds,
172 | }
173 | else:
174 | pipeline_args = {"prompt": args.validation_prompt}
175 |
176 | # run inference
177 | generator = None if args.seed is None else torch.Generator(device=accelerator.device).manual_seed(args.seed)
178 | images = []
179 | if args.validation_images is None:
180 | for _ in range(args.num_validation_images):
181 | with torch.autocast("cuda"):
182 | image = pipeline(**pipeline_args, num_inference_steps=25, generator=generator).images[0]
183 | images.append(image)
184 | else:
185 | for image in args.validation_images:
186 | image = Image.open(image)
187 | image = pipeline(**pipeline_args, image=image, generator=generator).images[0]
188 | images.append(image)
189 |
190 | for tracker in accelerator.trackers:
191 | if tracker.name == "tensorboard":
192 | np_images = np.stack([np.asarray(img) for img in images])
193 | tracker.writer.add_images("validation", np_images, global_step, dataformats="NHWC")
194 | if tracker.name == "wandb":
195 | tracker.log(
196 | {
197 | "validation": [
198 | wandb.Image(image, caption=f"{i}: {args.validation_prompt}") for i, image in enumerate(images)
199 | ]
200 | }
201 | )
202 |
203 | del pipeline
204 | torch.cuda.empty_cache()
205 |
206 | return images
207 |
208 |
209 | def import_model_class_from_model_name_or_path(pretrained_model_name_or_path: str, revision: str):
210 | text_encoder_config = PretrainedConfig.from_pretrained(
211 | pretrained_model_name_or_path,
212 | subfolder="text_encoder",
213 | revision=revision,
214 | )
215 | model_class = text_encoder_config.architectures[0]
216 |
217 | if model_class == "CLIPTextModel":
218 | from transformers import CLIPTextModel
219 |
220 | return CLIPTextModel
221 | elif model_class == "RobertaSeriesModelWithTransformation":
222 | from diffusers.pipelines.alt_diffusion.modeling_roberta_series import RobertaSeriesModelWithTransformation
223 |
224 | return RobertaSeriesModelWithTransformation
225 | elif model_class == "T5EncoderModel":
226 | from transformers import T5EncoderModel
227 |
228 | return T5EncoderModel
229 | else:
230 | raise ValueError(f"{model_class} is not supported.")
231 |
232 |
233 | def parse_args(input_args=None):
234 | parser = argparse.ArgumentParser(description="Simple example of a training script.")
235 | parser.add_argument(
236 | "--pretrained_model_name_or_path",
237 | type=str,
238 | default=None,
239 | required=True,
240 | help="Path to pretrained model or model identifier from huggingface.co/models.",
241 | )
242 | parser.add_argument(
243 | "--clip_model_name_or_path",
244 | type=str,
245 | default=None,
246 | required=True,
247 | help="Path to clip model or model identifier from huggingface.co/models.",
248 | )
249 | parser.add_argument(
250 | "--revision",
251 | type=str,
252 | default=None,
253 | required=False,
254 | help="Revision of pretrained model identifier from huggingface.co/models.",
255 | )
256 | parser.add_argument(
257 | "--variant",
258 | type=str,
259 | default=None,
260 |         help="Variant of the model files of the pretrained model identifier from huggingface.co/models, e.g. fp16",
261 | )
262 | parser.add_argument(
263 | "--tokenizer_name",
264 | type=str,
265 | default=None,
266 | help="Pretrained tokenizer name or path if not the same as model_name",
267 | )
268 | parser.add_argument(
269 | "--instance_data_dir",
270 | type=str,
271 | default=None,
272 | required=True,
273 | help="A folder containing the training data of instance images.",
274 | )
275 | parser.add_argument(
276 | "--class_data_dir",
277 | type=str,
278 | default=None,
279 | required=False,
280 | help="A folder containing the training data of class images.",
281 | )
282 | parser.add_argument(
283 | "--instance_prompt",
284 | type=str,
285 | default=None,
286 | required=True,
287 | help="The prompt with identifier specifying the instance",
288 | )
289 | parser.add_argument(
290 | "--class_prompt",
291 | type=str,
292 | default=None,
293 | help="The prompt to specify images in the same class as provided instance images.",
294 | )
295 | parser.add_argument(
296 | "--class_prompts_json_path",
297 | type=str,
298 | default=None,
299 | )
300 | parser.add_argument(
301 | "--with_prior_preservation",
302 | default=False,
303 | action="store_true",
304 | help="Flag to add prior preservation loss.",
305 | )
306 | parser.add_argument("--prior_loss_weight", type=float, default=1.0, help="The weight of prior preservation loss.")
307 | parser.add_argument(
308 | "--num_class_images",
309 | type=int,
310 | default=100,
311 | help=(
312 | "Minimal class images for prior preservation loss. If there are not enough images already present in"
313 | " class_data_dir, additional images will be sampled with class_prompt."
314 | ),
315 | )
316 | parser.add_argument(
317 | "--output_dir",
318 | type=str,
319 | default="dreambooth-model",
320 | help="The output directory where the model predictions and checkpoints will be written.",
321 | )
322 | parser.add_argument("--seed", type=int, default=None, help="A seed for reproducible training.")
323 | parser.add_argument(
324 | "--resolution",
325 | type=int,
326 | default=512,
327 | help=(
328 | "The resolution for input images, all the images in the train/validation dataset will be resized to this"
329 | " resolution"
330 | ),
331 | )
332 | parser.add_argument(
333 | "--center_crop",
334 | default=False,
335 | action="store_true",
336 | help=(
337 | "Whether to center crop the input images to the resolution. If not set, the images will be randomly"
338 | " cropped. The images will be resized to the resolution first before cropping."
339 | ),
340 | )
341 | parser.add_argument(
342 | "--train_text_encoder",
343 | action="store_true",
344 | help="Whether to train the text encoder. If set, the text encoder should be float32 precision.",
345 | )
346 | parser.add_argument(
347 | "--train_batch_size", type=int, default=4, help="Batch size (per device) for the training dataloader."
348 | )
349 | parser.add_argument(
350 | "--sample_batch_size", type=int, default=4, help="Batch size (per device) for sampling images."
351 | )
352 | parser.add_argument("--num_train_epochs", type=int, default=1)
353 | parser.add_argument(
354 | "--max_train_steps",
355 | type=int,
356 | default=None,
357 | help="Total number of training steps to perform. If provided, overrides num_train_epochs.",
358 | )
359 | parser.add_argument(
360 | "--checkpointing_steps",
361 | type=int,
362 | default=500,
363 | help=(
364 | "Save a checkpoint of the training state every X updates. Checkpoints can be used for resuming training via `--resume_from_checkpoint`. "
365 |             "Save a checkpoint of the training state every X updates. Checkpoints can be used for resuming training via `--resume_from_checkpoint`. "
366 |             "In the case that the checkpoint is better than the final trained model, the checkpoint can also be used for inference. "
367 |             "Using a checkpoint for inference requires separate loading of the original pipeline and the individual checkpointed model components. "
368 |             "See https://huggingface.co/docs/diffusers/main/en/training/dreambooth#performing-inference-using-a-saved-checkpoint for step-by-step "
369 |             "instructions."
369 | ),
370 | )
371 | parser.add_argument(
372 | "--checkpoints_total_limit",
373 | type=int,
374 | default=None,
375 | help=(
376 | "Max number of checkpoints to store. Passed as `total_limit` to the `Accelerator` `ProjectConfiguration`."
377 | " See Accelerator::save_state https://huggingface.co/docs/accelerate/package_reference/accelerator#accelerate.Accelerator.save_state"
378 | " for more details"
379 | ),
380 | )
381 | parser.add_argument(
382 | "--resume_from_checkpoint",
383 | type=str,
384 | default=None,
385 | help=(
386 | "Whether training should be resumed from a previous checkpoint. Use a path saved by"
387 | ' `--checkpointing_steps`, or `"latest"` to automatically select the last available checkpoint.'
388 | ),
389 | )
390 | parser.add_argument(
391 | "--gradient_accumulation_steps",
392 | type=int,
393 | default=1,
394 | help="Number of updates steps to accumulate before performing a backward/update pass.",
395 | )
396 | parser.add_argument(
397 | "--gradient_checkpointing",
398 | action="store_true",
399 | help="Whether or not to use gradient checkpointing to save memory at the expense of slower backward pass.",
400 | )
401 | parser.add_argument(
402 | "--learning_rate",
403 | type=float,
404 | default=5e-6,
405 | help="Initial learning rate (after the potential warmup period) to use.",
406 | )
407 | parser.add_argument(
408 | "--scale_lr",
409 | action="store_true",
410 | default=False,
411 | help="Scale the learning rate by the number of GPUs, gradient accumulation steps, and batch size.",
412 | )
413 | parser.add_argument(
414 | "--lr_scheduler",
415 | type=str,
416 | default="constant",
417 | help=(
418 | 'The scheduler type to use. Choose between ["linear", "cosine", "cosine_with_restarts", "polynomial",'
419 | ' "constant", "constant_with_warmup"]'
420 | ),
421 | )
422 | parser.add_argument(
423 | "--lr_warmup_steps", type=int, default=500, help="Number of steps for the warmup in the lr scheduler."
424 | )
425 | parser.add_argument(
426 | "--lr_num_cycles",
427 | type=int,
428 | default=1,
429 | help="Number of hard resets of the lr in cosine_with_restarts scheduler.",
430 | )
431 | parser.add_argument("--lr_power", type=float, default=1.0, help="Power factor of the polynomial scheduler.")
432 | parser.add_argument(
433 | "--use_8bit_adam", action="store_true", help="Whether or not to use 8-bit Adam from bitsandbytes."
434 | )
435 | parser.add_argument(
436 | "--dataloader_num_workers",
437 | type=int,
438 | default=0,
439 | help=(
440 | "Number of subprocesses to use for data loading. 0 means that the data will be loaded in the main process."
441 | ),
442 | )
443 | parser.add_argument("--adam_beta1", type=float, default=0.9, help="The beta1 parameter for the Adam optimizer.")
444 | parser.add_argument("--adam_beta2", type=float, default=0.999, help="The beta2 parameter for the Adam optimizer.")
445 | parser.add_argument("--adam_weight_decay", type=float, default=1e-2, help="Weight decay to use.")
446 | parser.add_argument("--adam_epsilon", type=float, default=1e-08, help="Epsilon value for the Adam optimizer")
447 | parser.add_argument("--max_grad_norm", default=1.0, type=float, help="Max gradient norm.")
448 | parser.add_argument("--push_to_hub", action="store_true", help="Whether or not to push the model to the Hub.")
449 | parser.add_argument("--hub_token", type=str, default=None, help="The token to use to push to the Model Hub.")
450 | parser.add_argument(
451 | "--hub_model_id",
452 | type=str,
453 | default=None,
454 | help="The name of the repository to keep in sync with the local `output_dir`.",
455 | )
456 | parser.add_argument(
457 | "--logging_dir",
458 | type=str,
459 | default="logs",
460 | help=(
461 | "[TensorBoard](https://www.tensorflow.org/tensorboard) log directory. Will default to"
462 | " *output_dir/runs/**CURRENT_DATETIME_HOSTNAME***."
463 | ),
464 | )
465 | parser.add_argument(
466 | "--allow_tf32",
467 | action="store_true",
468 | help=(
469 | "Whether or not to allow TF32 on Ampere GPUs. Can be used to speed up training. For more information, see"
470 | " https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices"
471 | ),
472 | )
473 | parser.add_argument(
474 | "--report_to",
475 | type=str,
476 | default="tensorboard",
477 | help=(
478 | 'The integration to report the results and logs to. Supported platforms are `"tensorboard"`'
479 | ' (default), `"wandb"` and `"comet_ml"`. Use `"all"` to report to all integrations.'
480 | ),
481 | )
482 | parser.add_argument(
483 | "--validation_prompt",
484 | type=str,
485 | default=None,
486 | help="A prompt that is used during validation to verify that the model is learning.",
487 | )
488 | parser.add_argument(
489 | "--validation_prompts_path",
490 | type=str,
491 | default=None
492 | )
493 | parser.add_argument(
494 | "--validation_batch_size",
495 | type=int
496 | )
497 | parser.add_argument(
498 | "--validation_output_dir",
499 | type=str,
500 | default=None,
501 | help="The output directory where images will be written.",
502 | )
503 | parser.add_argument(
504 | "--num_prompts_per_group",
505 | type=int
506 | )
507 | parser.add_argument(
508 | "--num_groups",
509 | type=int
510 | )
511 | parser.add_argument(
512 | "--num_images_per_prompt",
513 | type=int
514 | )
515 | parser.add_argument(
516 | "--num_validation_images",
517 | type=int,
518 | default=4,
519 | help="Number of images that should be generated during validation with `validation_prompt`.",
520 | )
521 | parser.add_argument(
522 | "--validation_steps",
523 | type=int,
524 | default=100,
525 | help=(
526 | "Run validation every X steps. Validation consists of running the prompt"
527 | " `args.validation_prompt` multiple times: `args.num_validation_images`"
528 | " and logging the images."
529 | ),
530 | )
531 | parser.add_argument(
532 | "--mixed_precision",
533 | type=str,
534 | default=None,
535 | choices=["no", "fp16", "bf16"],
536 | help=(
537 | "Whether to use mixed precision. Choose between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >="
538 |             " 1.10 and an Nvidia Ampere GPU. Defaults to the value of the accelerate config of the current system or the"
539 | " flag passed with the `accelerate.launch` command. Use this argument to override the accelerate config."
540 | ),
541 | )
542 | parser.add_argument(
543 | "--prior_generation_precision",
544 | type=str,
545 | default=None,
546 | choices=["no", "fp32", "fp16", "bf16"],
547 | help=(
548 | "Choose prior generation precision between fp32, fp16 and bf16 (bfloat16). Bf16 requires PyTorch >="
549 | " 1.10.and an Nvidia Ampere GPU. Default to fp16 if a GPU is available else fp32."
550 | ),
551 | )
552 | parser.add_argument("--local_rank", type=int, default=-1, help="For distributed training: local_rank")
553 | parser.add_argument(
554 | "--enable_xformers_memory_efficient_attention", action="store_true", help="Whether or not to use xformers."
555 | )
556 | parser.add_argument(
557 | "--set_grads_to_none",
558 | action="store_true",
559 | help=(
560 |             "Save more memory by setting grads to None instead of zero. Be aware that this changes certain"
561 | " behaviors, so disable this argument if it causes any problems. More info:"
562 | " https://pytorch.org/docs/stable/generated/torch.optim.Optimizer.zero_grad.html"
563 | ),
564 | )
565 |
566 | parser.add_argument(
567 | "--offset_noise",
568 | action="store_true",
569 | default=False,
570 | help=(
571 |             "Fine-tuning against a modified noise."
572 | " See: https://www.crosslabs.org//blog/diffusion-with-offset-noise for more information."
573 | ),
574 | )
575 | parser.add_argument(
576 | "--snr_gamma",
577 | type=float,
578 | default=None,
579 | help="SNR weighting gamma to be used if rebalancing the loss. Recommended value is 5.0. "
580 | "More details here: https://arxiv.org/abs/2303.09556.",
581 | )
582 | parser.add_argument(
583 | "--pre_compute_text_embeddings",
584 | action="store_true",
585 | help="Whether or not to pre-compute text embeddings. If text embeddings are pre-computed, the text encoder will not be kept in memory during training and will leave more GPU memory available for training the rest of the model. This is not compatible with `--train_text_encoder`.",
586 | )
587 | parser.add_argument(
588 | "--tokenizer_max_length",
589 | type=int,
590 | default=None,
591 | required=False,
592 | help="The maximum length of the tokenizer. If not set, will default to the tokenizer's max length.",
593 | )
594 | parser.add_argument(
595 | "--text_encoder_use_attention_mask",
596 | action="store_true",
597 | required=False,
598 | help="Whether to use attention mask for the text encoder",
599 | )
600 | parser.add_argument(
601 | "--skip_save_text_encoder", action="store_true", required=False, help="Set to not save text encoder"
602 | )
603 | parser.add_argument(
604 | "--validation_images",
605 | required=False,
606 | default=None,
607 | nargs="+",
608 | help="Optional set of images to use for validation. Used when the target pipeline takes an initial image as input such as when training image variation or superresolution.",
609 | )
610 | parser.add_argument(
611 | "--class_labels_conditioning",
612 | required=False,
613 | default=None,
614 | help="The optional `class_label` conditioning to pass to the unet, available values are `timesteps`.",
615 | )
616 | parser.add_argument(
617 | "--validation_scheduler",
618 | type=str,
619 | default="DPMSolverMultistepScheduler",
620 | choices=["DPMSolverMultistepScheduler", "DDPMScheduler"],
621 | help="Select which scheduler to use for validation. DDPMScheduler is recommended for DeepFloyd IF.",
622 | )
623 |
624 | if input_args is not None:
625 | args = parser.parse_args(input_args)
626 | else:
627 | args = parser.parse_args()
628 |
629 | env_local_rank = int(os.environ.get("LOCAL_RANK", -1))
630 | if env_local_rank != -1 and env_local_rank != args.local_rank:
631 | args.local_rank = env_local_rank
632 |
633 | if args.with_prior_preservation:
634 | if args.class_data_dir is None:
635 | raise ValueError("You must specify a data directory for class images.")
636 | if args.class_prompt is None and args.class_prompts_json_path is None:
637 | raise ValueError("You must specify prompt for class images.")
638 | else:
639 | # logger is not available yet
640 | if args.class_data_dir is not None:
641 | warnings.warn("You need not use --class_data_dir without --with_prior_preservation.")
642 | if args.class_prompt is not None:
643 | warnings.warn("You need not use --class_prompt without --with_prior_preservation.")
644 |
645 | if args.train_text_encoder and args.pre_compute_text_embeddings:
646 | raise ValueError("`--train_text_encoder` cannot be used with `--pre_compute_text_embeddings`")
647 |
648 | return args
649 |
650 |
651 | class DreamBoothDataset(Dataset):
652 | """
653 | A dataset to prepare the instance and class images with the prompts for fine-tuning the model.
654 |     It pre-processes the images and tokenizes the prompts.
655 | """
656 |
657 | def __init__(
658 | self,
659 | instance_data_root,
660 | instance_prompt,
661 | tokenizer,
662 | class_data_root=None,
663 | class_prompt=None,
664 | class_num=None,
665 | size=512,
666 | center_crop=False,
667 | encoder_hidden_states=None,
668 | class_prompt_encoder_hidden_states=None,
669 | tokenizer_max_length=None,
670 | class_prompts_json_path=None
671 | ):
672 | self.size = size
673 | self.center_crop = center_crop
674 | self.tokenizer = tokenizer
675 | self.encoder_hidden_states = encoder_hidden_states
676 | self.class_prompt_encoder_hidden_states = class_prompt_encoder_hidden_states
677 | self.tokenizer_max_length = tokenizer_max_length
678 |
679 | self.instance_data_root = Path(instance_data_root)
680 | if not self.instance_data_root.exists():
681 |             raise ValueError(f"Instance images root {self.instance_data_root} doesn't exist.")
682 |
683 | self.instance_images_path = [p for p in Path(instance_data_root).iterdir() if p.suffix.lower() in {'.jpg', '.jpeg', '.png'}]
684 | self.num_instance_images = len(self.instance_images_path)
685 | self.instance_prompt = instance_prompt
686 | self._length = self.num_instance_images
687 |
688 | if class_data_root is not None:
689 | self.class_data_root = Path(class_data_root)
690 | self.class_data_root.mkdir(parents=True, exist_ok=True)
691 | self.class_images_path = [p for p in self.class_data_root.iterdir() if p.suffix.lower() in {'.jpg', '.jpeg', '.png'}]
692 | if class_num is not None:
693 | self.num_class_images = min(len(self.class_images_path), class_num)
694 | else:
695 | self.num_class_images = len(self.class_images_path)
696 | self._length = max(self.num_class_images, self.num_instance_images)
697 | self.class_prompts = [class_prompt]
698 | if class_prompts_json_path:
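                # Sort class images by the trailing number in their filenames so they line up with the prompt list loaded from the JSON.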
699 | self.class_images_path.sort(key=lambda x:int(re.findall(r'\d+',x.name)[-1]))
700 | self.class_prompts = json.load(open(class_prompts_json_path,'r',encoding='utf-8'))
701 | self.num_class_prompts = len(self.class_prompts)
702 | else:
703 | self.class_data_root = None
704 |
705 | self.image_transforms = transforms.Compose(
706 | [
707 | transforms.Resize(size, interpolation=transforms.InterpolationMode.BILINEAR),
708 | transforms.CenterCrop(size) if center_crop else transforms.RandomCrop(size),
709 | transforms.ToTensor(),
710 | transforms.Normalize([0.5], [0.5]),
711 | ]
712 | )
713 |
714 | def __len__(self):
715 | return self._length
716 |
717 | def __getitem__(self, index):
718 | example = {}
719 | instance_image = Image.open(self.instance_images_path[index % self.num_instance_images])
720 | instance_image = exif_transpose(instance_image)
721 |
722 | if not instance_image.mode == "RGB":
723 | instance_image = instance_image.convert("RGB")
724 | example["instance_images"] = self.image_transforms(instance_image)
725 |
726 | if self.encoder_hidden_states is not None:
727 | example["instance_prompt_ids"] = self.encoder_hidden_states
728 | else:
729 | text_inputs = tokenize_prompt(
730 | self.tokenizer, self.instance_prompt, tokenizer_max_length=self.tokenizer_max_length
731 | )
732 | example["instance_prompt_ids"] = text_inputs.input_ids
733 | example["instance_attention_mask"] = text_inputs.attention_mask
734 |
735 | if self.class_data_root:
736 | class_image = Image.open(self.class_images_path[index % self.num_class_images])
737 | class_image = exif_transpose(class_image)
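            # Pick the prompt for this class image, cycling through the prompt list when there are more images than prompts.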
738 | class_prompt = self.class_prompts[index % self.num_class_images % self.num_class_prompts]
739 |
740 | if not class_image.mode == "RGB":
741 | class_image = class_image.convert("RGB")
742 | example["class_images"] = self.image_transforms(class_image)
743 |
744 | if self.class_prompt_encoder_hidden_states is not None:
745 | example["class_prompt_ids"] = self.class_prompt_encoder_hidden_states
746 | else:
747 | class_text_inputs = tokenize_prompt(
748 | self.tokenizer, class_prompt, tokenizer_max_length=self.tokenizer_max_length
749 | )
750 | example["class_prompt_ids"] = class_text_inputs.input_ids
751 | example["class_attention_mask"] = class_text_inputs.attention_mask
752 |
753 | return example
754 |
755 |
756 | def collate_fn(examples, with_prior_preservation=False):
757 | has_attention_mask = "instance_attention_mask" in examples[0]
758 |
759 | input_ids = [example["instance_prompt_ids"] for example in examples]
760 | pixel_values = [example["instance_images"] for example in examples]
761 |
762 | if has_attention_mask:
763 | attention_mask = [example["instance_attention_mask"] for example in examples]
764 |
765 | # Concat class and instance examples for prior preservation.
766 | # We do this to avoid doing two forward passes.
767 | if with_prior_preservation:
768 | input_ids += [example["class_prompt_ids"] for example in examples]
769 | pixel_values += [example["class_images"] for example in examples]
770 |
771 | if has_attention_mask:
772 | attention_mask += [example["class_attention_mask"] for example in examples]
773 |
774 | pixel_values = torch.stack(pixel_values)
775 | pixel_values = pixel_values.to(memory_format=torch.contiguous_format).float()
776 |
777 | input_ids = torch.cat(input_ids, dim=0)
778 |
779 | batch = {
780 | "input_ids": input_ids,
781 | "pixel_values": pixel_values,
782 | }
783 |
784 | if has_attention_mask:
785 | attention_mask = torch.cat(attention_mask, dim=0)
786 | batch["attention_mask"] = attention_mask
787 |
788 | return batch
789 |
790 |
791 | class PromptDataset(Dataset):
792 | "A simple dataset to prepare the prompts to generate class images on multiple GPUs."
793 |
794 | def __init__(self, prompt, num_samples):
795 | self.prompt = prompt
796 | self.num_samples = num_samples
797 |
798 | def __len__(self):
799 | return self.num_samples
800 |
801 | def __getitem__(self, index):
802 | example = {}
803 | example["prompt"] = self.prompt
804 | example["index"] = index
805 | return example
806 |
807 |
808 | def model_has_vae(args):
809 | config_file_name = os.path.join("vae", AutoencoderKL.config_name)
810 | if os.path.isdir(args.pretrained_model_name_or_path):
811 | config_file_name = os.path.join(args.pretrained_model_name_or_path, config_file_name)
812 | return os.path.isfile(config_file_name)
813 | else:
814 | files_in_repo = model_info(args.pretrained_model_name_or_path, revision=args.revision).siblings
815 | return any(file.rfilename == config_file_name for file in files_in_repo)
816 |
817 |
818 | def tokenize_prompt(tokenizer, prompt, tokenizer_max_length=None):
819 | if tokenizer_max_length is not None:
820 | max_length = tokenizer_max_length
821 | else:
822 | max_length = tokenizer.model_max_length
823 |
824 | text_inputs = tokenizer(
825 | prompt,
826 | truncation=True,
827 | padding="max_length",
828 | max_length=max_length,
829 | return_tensors="pt",
830 | )
831 |
832 | return text_inputs
833 |
834 |
835 | def encode_prompt(text_encoder, input_ids, attention_mask, text_encoder_use_attention_mask=None):
836 | text_input_ids = input_ids.to(text_encoder.device)
837 |
838 | if text_encoder_use_attention_mask:
839 | attention_mask = attention_mask.to(text_encoder.device)
840 | else:
841 | attention_mask = None
842 |
843 | prompt_embeds = text_encoder(
844 | text_input_ids,
845 | attention_mask=attention_mask,
846 | )
847 | prompt_embeds = prompt_embeds[0]
848 |
849 | return prompt_embeds
850 |
851 |
852 | def main(args):
853 | logging_dir = Path(args.output_dir, args.logging_dir)
854 |
855 | accelerator_project_config = ProjectConfiguration(project_dir=args.output_dir, logging_dir=logging_dir)
856 |
857 | accelerator = Accelerator(
858 | gradient_accumulation_steps=args.gradient_accumulation_steps,
859 | mixed_precision=args.mixed_precision,
860 | log_with=args.report_to,
861 | project_config=accelerator_project_config,
862 | )
863 |
864 | if args.report_to == "wandb":
865 | if not is_wandb_available():
866 | raise ImportError("Make sure to install wandb if you want to use it for logging during training.")
867 |
868 | # Currently, it's not possible to do gradient accumulation when training two models with accelerate.accumulate
869 | # This will be enabled soon in accelerate. For now, we don't allow gradient accumulation when training two models.
870 | # TODO (patil-suraj): Remove this check when gradient accumulation with two models is enabled in accelerate.
871 | if args.train_text_encoder and args.gradient_accumulation_steps > 1 and accelerator.num_processes > 1:
872 | raise ValueError(
873 | "Gradient accumulation is not supported when training the text encoder in distributed training. "
874 | "Please set gradient_accumulation_steps to 1. This feature will be supported in the future."
875 | )
876 |
877 | # Make one log on every process with the configuration for debugging.
878 | logging.basicConfig(
879 | format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
880 | datefmt="%m/%d/%Y %H:%M:%S",
881 | level=logging.INFO,
882 | )
883 | logger.info(accelerator.state, main_process_only=False)
884 | if accelerator.is_local_main_process:
885 | transformers.utils.logging.set_verbosity_warning()
886 | diffusers.utils.logging.set_verbosity_info()
887 | else:
888 | transformers.utils.logging.set_verbosity_error()
889 | diffusers.utils.logging.set_verbosity_error()
890 |
891 | # If passed along, set the training seed now.
892 | if args.seed is not None:
893 | set_seed(args.seed)
894 |
895 | # Generate class images if prior preservation is enabled.
896 | if args.with_prior_preservation:
897 | class_images_dir = Path(args.class_data_dir)
898 | if not class_images_dir.exists():
899 | class_images_dir.mkdir(parents=True)
900 | cur_class_images = len(list(class_images_dir.iterdir()))
901 |
902 | if cur_class_images < args.num_class_images:
903 | torch_dtype = torch.float16 if accelerator.device.type == "cuda" else torch.float32
904 | if args.prior_generation_precision == "fp32":
905 | torch_dtype = torch.float32
906 | elif args.prior_generation_precision == "fp16":
907 | torch_dtype = torch.float16
908 | elif args.prior_generation_precision == "bf16":
909 | torch_dtype = torch.bfloat16
910 | pipeline = DiffusionPipeline.from_pretrained(
911 | args.pretrained_model_name_or_path,
912 | torch_dtype=torch_dtype,
913 | safety_checker=None,
914 | revision=args.revision,
915 | variant=args.variant,
916 | )
917 | pipeline.set_progress_bar_config(disable=True)
918 |
919 | num_new_images = args.num_class_images - cur_class_images
920 | logger.info(f"Number of class images to sample: {num_new_images}.")
921 |
922 | sample_dataset = PromptDataset(args.class_prompt, num_new_images)
923 | sample_dataloader = torch.utils.data.DataLoader(sample_dataset, batch_size=args.sample_batch_size)
924 |
925 | sample_dataloader = accelerator.prepare(sample_dataloader)
926 | pipeline.to(accelerator.device)
927 |
928 | for example in tqdm(
929 | sample_dataloader, desc="Generating class images", disable=not accelerator.is_local_main_process
930 | ):
931 | images = pipeline(example["prompt"]).images
932 |
933 | for i, image in enumerate(images):
934 | hash_image = insecure_hashlib.sha1(image.tobytes()).hexdigest()
935 | image_filename = class_images_dir / f"{example['index'][i] + cur_class_images}-{hash_image}.jpg"
936 | image.save(image_filename)
937 |
938 | del pipeline
939 | if torch.cuda.is_available():
940 | torch.cuda.empty_cache()
941 |
942 | # Handle the repository creation
943 | if accelerator.is_main_process:
944 | if args.output_dir is not None:
945 | os.makedirs(args.output_dir, exist_ok=True)
946 |
947 | if args.push_to_hub:
948 | repo_id = create_repo(
949 | repo_id=args.hub_model_id or Path(args.output_dir).name, exist_ok=True, token=args.hub_token
950 | ).repo_id
951 |
952 | # Load the tokenizer
953 | if args.tokenizer_name:
954 | tokenizer = AutoTokenizer.from_pretrained(args.tokenizer_name, revision=args.revision, use_fast=False)
955 | elif args.pretrained_model_name_or_path:
956 | tokenizer = AutoTokenizer.from_pretrained(
957 | args.pretrained_model_name_or_path,
958 | subfolder="tokenizer",
959 | revision=args.revision,
960 | use_fast=False,
961 | )
962 |
963 | # import correct text encoder class
964 | text_encoder_cls = import_model_class_from_model_name_or_path(args.pretrained_model_name_or_path, args.revision)
965 |
966 | # Load scheduler and models
967 | noise_scheduler = DDPMScheduler.from_pretrained(args.pretrained_model_name_or_path, subfolder="scheduler")
968 | text_encoder = text_encoder_cls.from_pretrained(
969 | args.pretrained_model_name_or_path, subfolder="text_encoder", revision=args.revision, variant=args.variant
970 | )
971 |
972 | if args.clip_model_name_or_path:
973 | clip_model = CLIPModel.from_pretrained(args.clip_model_name_or_path)
974 | clip_processor = CLIPProcessor.from_pretrained(args.clip_model_name_or_path)
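    # NOTE: clip_model / clip_processor are used by get_score() during validation below, so
    # --clip_model_name_or_path must be provided whenever --validation_prompts_path is set.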
975 |
976 | if model_has_vae(args):
977 | vae = AutoencoderKL.from_pretrained(
978 | args.pretrained_model_name_or_path, subfolder="vae", revision=args.revision, variant=args.variant
979 | )
980 | else:
981 | vae = None
982 |
983 | unet = UNet2DConditionModel.from_pretrained(
984 | args.pretrained_model_name_or_path, subfolder="unet", revision=args.revision, variant=args.variant
985 | )
986 |
987 | # create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format
988 | def save_model_hook(models, weights, output_dir):
989 | if accelerator.is_main_process:
990 | for model in models:
991 | sub_dir = "unet" if isinstance(model, type(accelerator.unwrap_model(unet))) else "text_encoder"
992 | model.save_pretrained(os.path.join(output_dir, sub_dir))
993 |
994 |                 # make sure to pop the weight so that the corresponding model is not saved again
995 | weights.pop()
996 |
997 | def load_model_hook(models, input_dir):
998 | while len(models) > 0:
999 | # pop models so that they are not loaded again
1000 | model = models.pop()
1001 |
1002 | if isinstance(model, type(accelerator.unwrap_model(text_encoder))):
1003 | # load transformers style into model
1004 | load_model = text_encoder_cls.from_pretrained(input_dir, subfolder="text_encoder")
1005 | model.config = load_model.config
1006 | else:
1007 | # load diffusers style into model
1008 | load_model = UNet2DConditionModel.from_pretrained(input_dir, subfolder="unet")
1009 | model.register_to_config(**load_model.config)
1010 |
1011 | model.load_state_dict(load_model.state_dict())
1012 | del load_model
1013 |
1014 | accelerator.register_save_state_pre_hook(save_model_hook)
1015 | accelerator.register_load_state_pre_hook(load_model_hook)
1016 |
1017 | if vae is not None:
1018 | vae.requires_grad_(False)
1019 |
1020 | if not args.train_text_encoder:
1021 | text_encoder.requires_grad_(False)
1022 |
1023 | if args.enable_xformers_memory_efficient_attention:
1024 | if is_xformers_available():
1025 | import xformers
1026 |
1027 | xformers_version = version.parse(xformers.__version__)
1028 | if xformers_version == version.parse("0.0.16"):
1029 |                 logger.warning(
1030 | "xFormers 0.0.16 cannot be used for training in some GPUs. If you observe problems during training, please update xFormers to at least 0.0.17. See https://huggingface.co/docs/diffusers/main/en/optimization/xformers for more details."
1031 | )
1032 | unet.enable_xformers_memory_efficient_attention()
1033 | else:
1034 | raise ValueError("xformers is not available. Make sure it is installed correctly")
1035 |
1036 | if args.gradient_checkpointing:
1037 | unet.enable_gradient_checkpointing()
1038 | if args.train_text_encoder:
1039 | text_encoder.gradient_checkpointing_enable()
1040 |
1041 | # Check that all trainable models are in full precision
1042 | low_precision_error_string = (
1043 | "Please make sure to always have all model weights in full float32 precision when starting training - even if"
1044 |         " doing mixed precision training, a copy of the weights should still be kept in float32."
1045 | )
1046 |
1047 | if accelerator.unwrap_model(unet).dtype != torch.float32:
1048 | raise ValueError(
1049 | f"Unet loaded as datatype {accelerator.unwrap_model(unet).dtype}. {low_precision_error_string}"
1050 | )
1051 |
1052 | if args.train_text_encoder and accelerator.unwrap_model(text_encoder).dtype != torch.float32:
1053 | raise ValueError(
1054 | f"Text encoder loaded as datatype {accelerator.unwrap_model(text_encoder).dtype}."
1055 | f" {low_precision_error_string}"
1056 | )
1057 |
1058 | # Enable TF32 for faster training on Ampere GPUs,
1059 | # cf https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices
1060 | if args.allow_tf32:
1061 | torch.backends.cuda.matmul.allow_tf32 = True
1062 |
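    # Optionally scale the base learning rate linearly with the effective (global) batch size.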
1063 | if args.scale_lr:
1064 | args.learning_rate = (
1065 | args.learning_rate * args.gradient_accumulation_steps * args.train_batch_size * accelerator.num_processes
1066 | )
1067 |
1068 |     # Use 8-bit Adam for lower memory usage or to fine-tune the model on 16GB GPUs
1069 | if args.use_8bit_adam:
1070 | try:
1071 | import bitsandbytes as bnb
1072 | except ImportError:
1073 | raise ImportError(
1074 | "To use 8-bit Adam, please install the bitsandbytes library: `pip install bitsandbytes`."
1075 | )
1076 |
1077 | optimizer_class = bnb.optim.AdamW8bit
1078 | else:
1079 | optimizer_class = torch.optim.AdamW
1080 |
1081 | # Optimizer creation
1082 | params_to_optimize = (
1083 | itertools.chain(unet.parameters(), text_encoder.parameters()) if args.train_text_encoder else unet.parameters()
1084 | )
1085 | optimizer = optimizer_class(
1086 | params_to_optimize,
1087 | lr=args.learning_rate,
1088 | betas=(args.adam_beta1, args.adam_beta2),
1089 | weight_decay=args.adam_weight_decay,
1090 | eps=args.adam_epsilon,
1091 | )
1092 |
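    # Optionally pre-compute the prompt embeddings once up front so the text encoder and tokenizer
    # can be dropped afterwards, freeing GPU memory during training.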
1093 | if args.pre_compute_text_embeddings:
1094 |
1095 | def compute_text_embeddings(prompt):
1096 | with torch.no_grad():
1097 | text_inputs = tokenize_prompt(tokenizer, prompt, tokenizer_max_length=args.tokenizer_max_length)
1098 | prompt_embeds = encode_prompt(
1099 | text_encoder,
1100 | text_inputs.input_ids,
1101 | text_inputs.attention_mask,
1102 | text_encoder_use_attention_mask=args.text_encoder_use_attention_mask,
1103 | )
1104 |
1105 | return prompt_embeds
1106 |
1107 | pre_computed_encoder_hidden_states = compute_text_embeddings(args.instance_prompt)
1108 | validation_prompt_negative_prompt_embeds = compute_text_embeddings("")
1109 |
1110 | if args.validation_prompt is not None:
1111 | validation_prompt_encoder_hidden_states = compute_text_embeddings(args.validation_prompt)
1112 | else:
1113 | validation_prompt_encoder_hidden_states = None
1114 |
1115 | if args.class_prompt is not None:
1116 | pre_computed_class_prompt_encoder_hidden_states = compute_text_embeddings(args.class_prompt)
1117 | else:
1118 | pre_computed_class_prompt_encoder_hidden_states = None
1119 |
1120 | text_encoder = None
1121 | tokenizer = None
1122 |
1123 | gc.collect()
1124 | torch.cuda.empty_cache()
1125 | else:
1126 | pre_computed_encoder_hidden_states = None
1127 | validation_prompt_encoder_hidden_states = None
1128 | validation_prompt_negative_prompt_embeds = None
1129 | pre_computed_class_prompt_encoder_hidden_states = None
1130 |
1131 | # Dataset and DataLoaders creation:
1132 | train_dataset = DreamBoothDataset(
1133 | instance_data_root=args.instance_data_dir,
1134 | instance_prompt=args.instance_prompt,
1135 | class_data_root=args.class_data_dir if args.with_prior_preservation else None,
1136 | class_prompt=args.class_prompt,
1137 | class_num=args.num_class_images,
1138 | tokenizer=tokenizer,
1139 | size=args.resolution,
1140 | center_crop=args.center_crop,
1141 | encoder_hidden_states=pre_computed_encoder_hidden_states,
1142 | class_prompt_encoder_hidden_states=pre_computed_class_prompt_encoder_hidden_states,
1143 | tokenizer_max_length=args.tokenizer_max_length,
1144 | class_prompts_json_path=args.class_prompts_json_path
1145 | )
1146 |
1147 | train_dataloader = torch.utils.data.DataLoader(
1148 | train_dataset,
1149 | batch_size=args.train_batch_size,
1150 | shuffle=True,
1151 | collate_fn=lambda examples: collate_fn(examples, args.with_prior_preservation),
1152 | num_workers=args.dataloader_num_workers,
1153 | )
1154 |
1155 | # Scheduler and math around the number of training steps.
1156 | overrode_max_train_steps = False
1157 | num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
1158 | if args.max_train_steps is None:
1159 | args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
1160 | overrode_max_train_steps = True
1161 |
1162 | lr_scheduler = get_scheduler(
1163 | args.lr_scheduler,
1164 | optimizer=optimizer,
1165 | num_warmup_steps=args.lr_warmup_steps * accelerator.num_processes,
1166 | num_training_steps=args.max_train_steps * accelerator.num_processes,
1167 | num_cycles=args.lr_num_cycles,
1168 | power=args.lr_power,
1169 | )
1170 |
1171 | # Prepare everything with our `accelerator`.
1172 | if args.train_text_encoder:
1173 | unet, text_encoder, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
1174 | unet, text_encoder, optimizer, train_dataloader, lr_scheduler
1175 | )
1176 | else:
1177 | unet, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
1178 | unet, optimizer, train_dataloader, lr_scheduler
1179 | )
1180 |
1181 |     # For mixed precision training we cast the non-trainable weights (the vae, and the text encoder when it is not being trained)
1182 |     # to half-precision, as these weights are only used for inference and keeping them in full precision is not required.
1183 | weight_dtype = torch.float32
1184 | if accelerator.mixed_precision == "fp16":
1185 | weight_dtype = torch.float16
1186 | elif accelerator.mixed_precision == "bf16":
1187 | weight_dtype = torch.bfloat16
1188 |
1189 | # Move vae and text_encoder to device and cast to weight_dtype
1190 | if vae is not None:
1191 | vae.to(accelerator.device, dtype=weight_dtype)
1192 |
1193 | if not args.train_text_encoder and text_encoder is not None:
1194 | text_encoder.to(accelerator.device, dtype=weight_dtype)
1195 |
1196 | # We need to recalculate our total training steps as the size of the training dataloader may have changed.
1197 | num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
1198 | if overrode_max_train_steps:
1199 | args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
1200 | # Afterwards we recalculate our number of training epochs
1201 | args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)
1202 |
1203 | # We need to initialize the trackers we use, and also store our configuration.
1204 |     # The trackers are initialized automatically on the main process.
1205 | if accelerator.is_main_process:
1206 | tracker_config = vars(copy.deepcopy(args))
1207 | tracker_config.pop("validation_images")
1208 | accelerator.init_trackers("dreambooth", config=tracker_config)
1209 |
1210 | # Train!
1211 | total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps
1212 |
1213 | logger.info("***** Running training *****")
1214 | logger.info(f" Num examples = {len(train_dataset)}")
1215 | logger.info(f" Num batches each epoch = {len(train_dataloader)}")
1216 | logger.info(f" Num Epochs = {args.num_train_epochs}")
1217 | logger.info(f" Instantaneous batch size per device = {args.train_batch_size}")
1218 | logger.info(f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}")
1219 | logger.info(f" Gradient Accumulation steps = {args.gradient_accumulation_steps}")
1220 | logger.info(f" Total optimization steps = {args.max_train_steps}")
1221 | global_step = 0
1222 | first_epoch = 0
1223 |
1224 | # Potentially load in the weights and states from a previous save
1225 | if args.resume_from_checkpoint:
1226 | if args.resume_from_checkpoint != "latest":
1227 | path = os.path.basename(args.resume_from_checkpoint)
1228 | else:
1229 | # Get the most recent checkpoint
1230 | dirs = os.listdir(args.output_dir)
1231 | dirs = [d for d in dirs if d.startswith("checkpoint")]
1232 | dirs = sorted(dirs, key=lambda x: int(x.split("-")[1]))
1233 | path = dirs[-1] if len(dirs) > 0 else None
1234 |
1235 | if path is None:
1236 | accelerator.print(
1237 | f"Checkpoint '{args.resume_from_checkpoint}' does not exist. Starting a new training run."
1238 | )
1239 | args.resume_from_checkpoint = None
1240 | initial_global_step = 0
1241 | else:
1242 | accelerator.print(f"Resuming from checkpoint {path}")
1243 | accelerator.load_state(os.path.join(args.output_dir, path))
1244 | global_step = int(path.split("-")[1])
1245 |
1246 | initial_global_step = global_step
1247 | first_epoch = global_step // num_update_steps_per_epoch
1248 | else:
1249 | initial_global_step = 0
1250 |
1251 | progress_bar = tqdm(
1252 | range(0, args.max_train_steps),
1253 | initial=initial_global_step,
1254 | desc="Steps",
1255 | # Only show the progress bar once on each machine.
1256 | disable=not accelerator.is_local_main_process,
1257 | )
1258 |
1259 | for epoch in range(first_epoch, args.num_train_epochs):
1260 | unet.train()
1261 | if args.train_text_encoder:
1262 | text_encoder.train()
1263 | for step, batch in enumerate(train_dataloader):
1264 | with accelerator.accumulate(unet):
1265 | pixel_values = batch["pixel_values"].to(dtype=weight_dtype)
1266 |
1267 | if vae is not None:
1268 | # Convert images to latent space
1269 | model_input = vae.encode(batch["pixel_values"].to(dtype=weight_dtype)).latent_dist.sample()
1270 | model_input = model_input * vae.config.scaling_factor
1271 | else:
1272 | model_input = pixel_values
1273 |
1274 | # Sample noise that we'll add to the model input
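                # With --offset_noise, a per-(sample, channel) offset (0.1 * N(0, 1), broadcast over H and W)
                # is added on top of the standard Gaussian noise; this "offset noise" trick is commonly used
                # to help the model reproduce very dark or very bright images.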
1275 | if args.offset_noise:
1276 | noise = torch.randn_like(model_input) + 0.1 * torch.randn(
1277 | model_input.shape[0], model_input.shape[1], 1, 1, device=model_input.device
1278 | )
1279 | else:
1280 | noise = torch.randn_like(model_input)
1281 | bsz, channels, height, width = model_input.shape
1282 | # Sample a random timestep for each image
1283 | timesteps = torch.randint(
1284 | 0, noise_scheduler.config.num_train_timesteps, (bsz,), device=model_input.device
1285 | )
1286 | timesteps = timesteps.long()
1287 |
1288 | # Add noise to the model input according to the noise magnitude at each timestep
1289 | # (this is the forward diffusion process)
1290 | noisy_model_input = noise_scheduler.add_noise(model_input, noise, timesteps)
1291 |
1292 | # Get the text embedding for conditioning
1293 | if args.pre_compute_text_embeddings:
1294 | encoder_hidden_states = batch["input_ids"]
1295 | else:
1296 | encoder_hidden_states = encode_prompt(
1297 | text_encoder,
1298 | batch["input_ids"],
1299 | batch["attention_mask"],
1300 | text_encoder_use_attention_mask=args.text_encoder_use_attention_mask,
1301 | )
1302 |
1303 | if accelerator.unwrap_model(unet).config.in_channels == channels * 2:
1304 | noisy_model_input = torch.cat([noisy_model_input, noisy_model_input], dim=1)
1305 |
1306 | if args.class_labels_conditioning == "timesteps":
1307 | class_labels = timesteps
1308 | else:
1309 | class_labels = None
1310 |
1311 | # Predict the noise residual
1312 | model_pred = unet(
1313 | noisy_model_input, timesteps, encoder_hidden_states, class_labels=class_labels
1314 | ).sample
1315 |
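                # Some UNets predict a learned variance alongside the noise (doubling the output channels);
                # in that case keep only the first half, the actual noise/velocity prediction.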
1316 | if model_pred.shape[1] == 6:
1317 | model_pred, _ = torch.chunk(model_pred, 2, dim=1)
1318 |
1319 | # Get the target for loss depending on the prediction type
1320 | if noise_scheduler.config.prediction_type == "epsilon":
1321 | target = noise
1322 | elif noise_scheduler.config.prediction_type == "v_prediction":
1323 | target = noise_scheduler.get_velocity(model_input, noise, timesteps)
1324 | else:
1325 | raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}")
1326 |
1327 | if args.with_prior_preservation:
1328 | # Chunk the noise and model_pred into two parts and compute the loss on each part separately.
1329 | model_pred, model_pred_prior = torch.chunk(model_pred, 2, dim=0)
1330 | target, target_prior = torch.chunk(target, 2, dim=0)
1331 | # Compute prior loss
1332 | prior_loss = F.mse_loss(model_pred_prior.float(), target_prior.float(), reduction="mean")
1333 |
1334 | # Compute instance loss
1335 | if args.snr_gamma is None:
1336 | loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean")
1337 | else:
1338 | # Compute loss-weights as per Section 3.4 of https://arxiv.org/abs/2303.09556.
1339 | # Since we predict the noise instead of x_0, the original formulation is slightly changed.
1340 | # This is discussed in Section 4.2 of the same paper.
1341 | snr = compute_snr(noise_scheduler, timesteps)
1342 | base_weight = (
1343 | torch.stack([snr, args.snr_gamma * torch.ones_like(timesteps)], dim=1).min(dim=1)[0] / snr
1344 | )
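                    # base_weight = min(SNR, snr_gamma) / SNR, i.e. Min-SNR loss weighting.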
1345 |
1346 | if noise_scheduler.config.prediction_type == "v_prediction":
1347 | # Velocity objective needs to be floored to an SNR weight of one.
1348 | mse_loss_weights = base_weight + 1
1349 | else:
1350 | # Epsilon and sample both use the same loss weights.
1351 | mse_loss_weights = base_weight
1352 | loss = F.mse_loss(model_pred.float(), target.float(), reduction="none")
1353 | loss = loss.mean(dim=list(range(1, len(loss.shape)))) * mse_loss_weights
1354 | loss = loss.mean()
1355 |
1356 | if args.with_prior_preservation:
1357 | # Add the prior loss to the instance loss.
1358 | loss = loss + args.prior_loss_weight * prior_loss
1359 |
1360 | accelerator.backward(loss)
1361 | if accelerator.sync_gradients:
1362 | params_to_clip = (
1363 | itertools.chain(unet.parameters(), text_encoder.parameters())
1364 | if args.train_text_encoder
1365 | else unet.parameters()
1366 | )
1367 | accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm)
1368 | optimizer.step()
1369 | lr_scheduler.step()
1370 | optimizer.zero_grad(set_to_none=args.set_grads_to_none)
1371 |
1372 | # Checks if the accelerator has performed an optimization step behind the scenes
1373 | if accelerator.sync_gradients:
1374 | progress_bar.update(1)
1375 | global_step += 1
1376 |
1377 | if accelerator.is_main_process:
1378 | if global_step % args.checkpointing_steps == 0:
1379 | # _before_ saving state, check if this save would set us over the `checkpoints_total_limit`
1380 | if args.checkpoints_total_limit is not None:
1381 | checkpoints = os.listdir(args.output_dir)
1382 | checkpoints = [d for d in checkpoints if d.startswith("checkpoint")]
1383 | checkpoints = sorted(checkpoints, key=lambda x: int(x.split("-")[1]))
1384 |
1385 | # before we save the new checkpoint, we need to have at _most_ `checkpoints_total_limit - 1` checkpoints
1386 | if len(checkpoints) >= args.checkpoints_total_limit:
1387 | num_to_remove = len(checkpoints) - args.checkpoints_total_limit + 1
1388 | removing_checkpoints = checkpoints[0:num_to_remove]
1389 |
1390 | logger.info(
1391 | f"{len(checkpoints)} checkpoints already exist, removing {len(removing_checkpoints)} checkpoints"
1392 | )
1393 | logger.info(f"removing checkpoints: {', '.join(removing_checkpoints)}")
1394 |
1395 | for removing_checkpoint in removing_checkpoints:
1396 | removing_checkpoint = os.path.join(args.output_dir, removing_checkpoint)
1397 | shutil.rmtree(removing_checkpoint)
1398 |
1399 | save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}")
1400 | accelerator.save_state(save_path)
1401 | logger.info(f"Saved state to {save_path}")
1402 |
1403 | images = []
1404 |
1405 | if args.validation_prompt is not None and global_step % args.validation_steps == 0:
1406 | images = log_validation(
1407 | text_encoder,
1408 | tokenizer,
1409 | unet,
1410 | vae,
1411 | args,
1412 | accelerator,
1413 | weight_dtype,
1414 | global_step,
1415 | validation_prompt_encoder_hidden_states,
1416 | validation_prompt_negative_prompt_embeds,
1417 | )
1418 |
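                    # When --validation_prompts_path is given, periodically (every --validation_steps) rebuild an
                    # inference pipeline from the current weights and score its generations on the first 25 prompts
                    # with the CLIP model loaded above (get_score).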
1419 | if args.validation_prompts_path is not None and global_step % args.validation_steps == 0:
1420 | pipeline = DiffusionPipeline.from_pretrained(
1421 | args.pretrained_model_name_or_path,
1422 | tokenizer=tokenizer,
1423 | text_encoder=accelerator.unwrap_model(text_encoder),
1424 | unet=accelerator.unwrap_model(unet),
1425 | revision=args.revision,
1426 | variant=args.variant,
1427 | torch_dtype=weight_dtype,
1428 | vae=accelerator.unwrap_model(vae),
1429 | safety_checker=None,
1430 | )
1431 | pipeline = pipeline.to(accelerator.device)
1432 | pipeline.set_progress_bar_config(disable=True)
1433 | if args.enable_xformers_memory_efficient_attention:
1434 | pipeline.enable_xformers_memory_efficient_attention()
1435 |
1436 | with open(args.validation_prompts_path,'r',encoding='utf-8') as f:
1437 | prompts = json.load(f)[:25]
1438 |
1439 | score = get_score(pipeline, clip_model, clip_processor, prompts, args.validation_batch_size, f'{args.validation_output_dir}/{global_step}')
1440 | print(f'{global_step}: {score}')
1441 |
1442 | logs = {"loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0]}
1443 | progress_bar.set_postfix(**logs)
1444 | accelerator.log(logs, step=global_step)
1445 |
1446 | if global_step >= args.max_train_steps:
1447 | break
1448 |
1449 | # Create the pipeline using the trained modules and save it.
1450 | accelerator.wait_for_everyone()
1451 | if accelerator.is_main_process:
1452 | pipeline_args = {}
1453 |
1454 | if text_encoder is not None:
1455 | pipeline_args["text_encoder"] = accelerator.unwrap_model(text_encoder)
1456 |
1457 | if args.skip_save_text_encoder:
1458 | pipeline_args["text_encoder"] = None
1459 |
1460 | pipeline = DiffusionPipeline.from_pretrained(
1461 | args.pretrained_model_name_or_path,
1462 | unet=accelerator.unwrap_model(unet),
1463 | revision=args.revision,
1464 | variant=args.variant,
1465 | **pipeline_args,
1466 | )
1467 |
1468 | # We train on the simplified learning objective. If we were previously predicting a variance, we need the scheduler to ignore it
1469 | scheduler_args = {}
1470 |
1471 | if "variance_type" in pipeline.scheduler.config:
1472 | variance_type = pipeline.scheduler.config.variance_type
1473 |
1474 | if variance_type in ["learned", "learned_range"]:
1475 | variance_type = "fixed_small"
1476 |
1477 | scheduler_args["variance_type"] = variance_type
1478 |
1479 | pipeline.scheduler = pipeline.scheduler.from_config(pipeline.scheduler.config, **scheduler_args)
1480 |
1481 | pipeline.save_pretrained(args.output_dir)
1482 |
1483 | if args.push_to_hub:
1484 | save_model_card(
1485 | repo_id,
1486 | images=images,
1487 | base_model=args.pretrained_model_name_or_path,
1488 | train_text_encoder=args.train_text_encoder,
1489 | prompt=args.instance_prompt,
1490 | repo_folder=args.output_dir,
1491 | pipeline=pipeline,
1492 | )
1493 | upload_folder(
1494 | repo_id=repo_id,
1495 | folder_path=args.output_dir,
1496 | commit_message="End of training",
1497 | ignore_patterns=["step_*", "epoch_*"],
1498 | )
1499 |
1500 | accelerator.end_training()
1501 |
1502 |
1503 | if __name__ == "__main__":
1504 | args = parse_args()
1505 | main(args)
1506 |
--------------------------------------------------------------------------------