├── .gitignore
├── LICENSE
├── README.md
├── assets
│   ├── README.md
│   ├── cat_statue
│   │   ├── 1.jpeg
│   │   ├── 2.jpeg
│   │   ├── 3.jpeg
│   │   ├── 4.jpeg
│   │   ├── 6.jpeg
│   │   └── 7.jpeg
│   ├── mug_skulls
│   │   ├── 1.jpeg
│   │   ├── 2.jpeg
│   │   ├── 3.jpeg
│   │   └── 4.jpeg
│   ├── outputs
│   │   ├── ti.png
│   │   └── xti_v1.png
│   └── paper.png
├── inference.py
├── prompt_plus
│   ├── __init__.py
│   ├── prompt_plus_pipeline_stable_diffusion.py
│   └── prompt_plus_unet_2d_condition.py
├── requirements.txt
├── scripts
│   ├── app.py
│   └── textual_inversion.py
└── train_p_plus.py
/.gitignore:
--------------------------------------------------------------------------------
1 | # build artifacts
2 |
3 | .eggs/
4 | .mypy_cache
5 | *.egg-info/
6 | build/
7 | dist/
8 | pip-wheel-metadata/
9 |
10 |
11 | # dev tools
12 |
13 | .envrc
14 | .python-version
15 | .idea
16 | .venv/
17 | .vscode/
18 | /*.iml
19 |
20 |
21 | # jupyter notebooks
22 |
23 | .ipynb_checkpoints
24 |
25 |
26 | # miscellaneous
27 |
28 | .cache/
29 | doc/_build/
30 | *.swp
31 | .DS_Store
32 |
33 |
34 | # python
35 |
36 | *.pyc
37 | *.pyo
38 | __pycache__
39 |
40 |
41 | # testing and continuous integration
42 |
43 | .coverage
44 | .pytest_cache/
45 | .benchmarks
46 |
47 | # custom
48 | *.ipynb
49 | data
50 | private
51 | wandb
52 | models
53 | *.sh
54 | xti_cat
55 | grid.png
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2023 Makoto Shing
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # prompt-plus-pytorch
2 |
3 |
4 | An implementation of [P+: Extended Textual Conditioning in Text-to-Image Generation](https://prompt-plus.github.io/) using d🧨ffusers.
5 |
6 | My summary can be found [here](https://twitter.com/mk1stats/status/1637785231729262592).
7 |
8 | 
9 |
10 | ## Current Status
11 | I still can't get better results than with Textual Inversion.
12 | The hyper-parameters are exactly the same as for Textual Inversion, except for the number of training steps, as described in Section 4.2.2 of the paper.
13 |
14 | **Textual inversion:**
15 | 
16 | **Extended Textual Inversion:**
17 | 
18 |
19 | Does this mean we need n_layers × 500 training steps in total? My current implementation trains all the per-layer embeddings jointly.
20 | > This optimization is applied independently to each cross-attention layer.
21 |
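To make the per-layer conditioning concrete, here is a minimal, self-contained sketch (illustrative only; `run_cross_attention` and the toy layers are hypothetical, not this repo's API): in standard conditioning every cross-attention layer of the U-Net receives the same text embedding, while in P+ each layer receives its own entry from a list.

```python
import torch

def run_cross_attention(layers, sample, encoder_hidden_states_list):
    # P+ idea: one text embedding per cross-attention layer instead of a shared one.
    assert len(layers) == len(encoder_hidden_states_list)
    for layer, text_embeds in zip(layers, encoder_hidden_states_list):
        sample = layer(sample, text_embeds)  # each layer gets its own conditioning
    return sample

# Toy usage with stand-in "layers".
layers = [lambda s, c: s + c.mean() for _ in range(3)]
sample = torch.zeros(1, 4)
embeds = [torch.randn(1, 77, 768) for _ in range(3)]  # one embedding per layer
print(run_cross_attention(layers, sample, embeds).shape)  # torch.Size([1, 4])
```
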
22 | ## Installation
23 | ```commandline
24 | git clone https://github.com/mkshing/prompt-plus-pytorch
25 | pip install -r requirements.txt
26 | ```
27 |
28 | ## Training
29 | ```commandline
30 | accelerate launch train_p_plus.py \
31 | --pretrained_model_name_or_path="CompVis/stable-diffusion-v1-4" \
32 | --train_data_dir="assets/cat_statue" \
33 | --learnable_property="object" \
34 | --placeholder_token="" --initializer_token="toy" \
35 | --resolution=512 \
36 | --train_batch_size=1 \
37 | --gradient_accumulation_steps=8 \
38 | --max_train_steps=500 \
39 | --learning_rate=5.0e-03 \
40 | --lr_scheduler="constant" \
41 | --lr_warmup_steps=0 \
42 | --output_dir="xti_cat" \
43 | --report_to "wandb" \
44 | --only_save_embeds \
45 | --enable_xformers_memory_efficient_attention
46 | ```
47 |
48 | ## Inference
49 |
50 | ```python
51 | from prompt_plus import PPlusStableDiffusionPipeline
52 |
53 | pipe = PPlusStableDiffusionPipeline.from_learned_embed(
54 | pretrained_model_name_or_path="CompVis/stable-diffusion-v1-4",
55 | learned_embed_name_or_path="learned-embed.bin path"
56 | )
57 | prompt = "A backpack"
58 | image = pipe(prompt, num_inference_steps=50, guidance_scale=7.5).images[0]
59 | image.save("cat-backpack.png")
60 | ```
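Note that `learned_embed_name_or_path` accepts either a local path (a `learned_embeds.bin` file or a directory containing one) or a Hugging Face Hub repo id, in which case the file is downloaded with `hf_hub_download`.
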
61 | The authors also proposed "Style Mixing" to combine two learned embeddings.
62 | ```python
63 | pipe = PPlusStableDiffusionPipeline.from_learned_embed(
64 | pretrained_model_name_or_path="CompVis/stable-diffusion-v1-4",
65 | learned_embed_name_or_path=["learned-embed 1", "learned-embed 2"],
66 | style_mixing_k_K=(5, 10),
67 | )
68 | ```
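Under the hood, style mixing merges the two per-layer embedding dictionaries by choosing, for each cross-attention layer index, which concept's embedding to keep; the exact split controlled by `style_mixing_k_K` is implemented in `load_embed_from_name_or_path`. Below is a rough sketch of that merge, where `pick_first` is a hypothetical stand-in for the k/K rule:

```python
def mix_embeddings(embeds_a: dict, embeds_b: dict, pick_first) -> dict:
    # Merge two token -> tensor dicts layer by layer (sketch, not the repo's exact rule).
    tokens = list(embeds_a.keys())       # per-layer tokens of the first concept
    values_b = list(embeds_b.values())   # per-layer tensors of the second concept
    return {
        token: embeds_a[token] if pick_first(i) else values_b[i]
        for i, token in enumerate(tokens)
    }
```
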
69 | I also made a pipeline for Textual Inversion so that it can be tested easily.
70 | ```python
71 | from prompt_plus import TextualInversionStableDiffusionPipeline
72 |
73 | pipe = TextualInversionStableDiffusionPipeline.from_learned_embed(
74 | pretrained_model_name_or_path="CompVis/stable-diffusion-v1-4",
75 | learned_embed_name_or_path="sd-concepts-library/cat-toy",
76 | )
77 | prompt = "A backpack"
78 | images = pipe(prompt, num_inference_steps=50, guidance_scale=7.5)
79 | ```
80 |
81 | If you want to run inference from the command line:
82 | ```commandline
83 | python inference.py \
84 | --pretrained_model_name_or_path "CompVis/stable-diffusion-v1-4" \
85 | --learned_embed_name_or_path "xti_cat" \
86 | --prompt "A backpack" \
87 | --float16 \
88 | --seed 1000
89 | ```
90 | ## Citation
91 |
92 | ```bibtex
93 | @article{voynov2023P+,
94 |   title={P+: Extended Textual Conditioning in Text-to-Image Generation},
95 |   author={Voynov, Andrey and Chu, Qinghao and Cohen-Or, Daniel and Aberman, Kfir},
96 |   journal={arXiv preprint arXiv:2303.09522},
97 |   year={2023},
98 |   url={https://arxiv.org/abs/2303.09522}
99 | }
100 | ```
101 |
102 | ## Reference
103 | - [diffusers Textual Inversion code](https://github.com/huggingface/diffusers/tree/main/examples/textual_inversion)
104 |
105 | ## TODO
106 | - [x] Training
107 | - [x] Inference
108 | - [x] Style Mixing
109 | - [ ] Regularization
--------------------------------------------------------------------------------
/assets/README.md:
--------------------------------------------------------------------------------
1 | `cat_statue` and `mug_skulls` are taken from the [original Textual Inversion repository](https://github.com/rinongal/textual_inversion#pretrained-models--data)
2 | and resized to 512x512 by the following code.
3 | ```python
4 | import os
5 | from PIL import Image
6 |
7 | image_dir = "image-path"
8 | for file_path in os.listdir(image_dir):
9 |     image_path = os.path.join(image_dir, file_path)
10 |     Image.open(image_path).resize((512, 512)).save(image_path)
11 | ```
--------------------------------------------------------------------------------
/assets/cat_statue/1.jpeg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/mkshing/prompt-plus-pytorch/7862c3f9b99ef2258907e1126c08967dd05ed91a/assets/cat_statue/1.jpeg
--------------------------------------------------------------------------------
/assets/cat_statue/2.jpeg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/mkshing/prompt-plus-pytorch/7862c3f9b99ef2258907e1126c08967dd05ed91a/assets/cat_statue/2.jpeg
--------------------------------------------------------------------------------
/assets/cat_statue/3.jpeg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/mkshing/prompt-plus-pytorch/7862c3f9b99ef2258907e1126c08967dd05ed91a/assets/cat_statue/3.jpeg
--------------------------------------------------------------------------------
/assets/cat_statue/4.jpeg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/mkshing/prompt-plus-pytorch/7862c3f9b99ef2258907e1126c08967dd05ed91a/assets/cat_statue/4.jpeg
--------------------------------------------------------------------------------
/assets/cat_statue/6.jpeg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/mkshing/prompt-plus-pytorch/7862c3f9b99ef2258907e1126c08967dd05ed91a/assets/cat_statue/6.jpeg
--------------------------------------------------------------------------------
/assets/cat_statue/7.jpeg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/mkshing/prompt-plus-pytorch/7862c3f9b99ef2258907e1126c08967dd05ed91a/assets/cat_statue/7.jpeg
--------------------------------------------------------------------------------
/assets/mug_skulls/1.jpeg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/mkshing/prompt-plus-pytorch/7862c3f9b99ef2258907e1126c08967dd05ed91a/assets/mug_skulls/1.jpeg
--------------------------------------------------------------------------------
/assets/mug_skulls/2.jpeg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/mkshing/prompt-plus-pytorch/7862c3f9b99ef2258907e1126c08967dd05ed91a/assets/mug_skulls/2.jpeg
--------------------------------------------------------------------------------
/assets/mug_skulls/3.jpeg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/mkshing/prompt-plus-pytorch/7862c3f9b99ef2258907e1126c08967dd05ed91a/assets/mug_skulls/3.jpeg
--------------------------------------------------------------------------------
/assets/mug_skulls/4.jpeg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/mkshing/prompt-plus-pytorch/7862c3f9b99ef2258907e1126c08967dd05ed91a/assets/mug_skulls/4.jpeg
--------------------------------------------------------------------------------
/assets/outputs/ti.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/mkshing/prompt-plus-pytorch/7862c3f9b99ef2258907e1126c08967dd05ed91a/assets/outputs/ti.png
--------------------------------------------------------------------------------
/assets/outputs/xti_v1.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/mkshing/prompt-plus-pytorch/7862c3f9b99ef2258907e1126c08967dd05ed91a/assets/outputs/xti_v1.png
--------------------------------------------------------------------------------
/assets/paper.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/mkshing/prompt-plus-pytorch/7862c3f9b99ef2258907e1126c08967dd05ed91a/assets/paper.png
--------------------------------------------------------------------------------
/inference.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | from PIL import Image
3 | import torch
4 | from diffusers import DPMSolverMultistepScheduler, StableDiffusionPipeline
5 | from diffusers.utils import is_xformers_available
6 | from prompt_plus import TextualInversionStableDiffusionPipeline, PPlusStableDiffusionPipeline
7 |
8 |
9 | def image_grid(imgs, rows, cols):
10 | assert len(imgs) == rows * cols
11 | w, h = imgs[0].size
12 | grid = Image.new('RGB', size=(cols * w, rows * h))
13 | for i, img in enumerate(imgs):
14 | grid.paste(img, box=(i % cols * w, i // cols * h))
15 | return grid
16 |
17 |
18 | def parse_args():
19 | parser = argparse.ArgumentParser()
20 | parser.add_argument("--pretrained_model_name_or_path", type=str, help="model name or path", default="runwayml/stable-diffusion-v1-5")
21 | parser.add_argument("--learned_embed_name_or_path", type=str, help="model path for learned embedding")
22 | parser.add_argument("--is_textual_inversion", action="store_true", help="Load textual inversion embeds")
23 | parser.add_argument("--original_pipe", action="store_true", help="load standard pipeline")
24 | parser.add_argument("--device", type=str, help="Device on which Stable Diffusion will be run", choices=["cpu", "cuda"], default=None)
25 | parser.add_argument("--float16", action="store_true", help="load float16")
26 | # diffusers config
27 | parser.add_argument("--prompt", type=str, nargs="?", default="a photo of *s", help="the prompt to render")
28 | parser.add_argument("--num_inference_steps", type=int, default=30, help="number of ddim sampling steps")
29 | parser.add_argument("--guidance_scale", type=float, default=7.5, help="unconditional guidance scale")
30 | parser.add_argument("--num_images_per_prompt", type=int, default=3, help="number of images per prompt")
31 | parser.add_argument("--height", type=int, default=512, help="image height, in pixel space",)
32 | parser.add_argument("--width", type=int, default=512, help="image width, in pixel space",)
33 | parser.add_argument("--seed", type=int, default=None, help="the seed (for reproducible sampling)")
34 | opt = parser.parse_args()
35 | return opt
36 |
37 |
38 | def main():
39 | args = parse_args()
40 | if args.device is None:
41 | args.device = "cuda" if torch.cuda.is_available() else "cpu"
42 | print(f"device: {args.device}")
43 |
44 | # load model
45 | if args.is_textual_inversion or not args.original_pipe:
46 | if args.is_textual_inversion:
47 | Pipeline = TextualInversionStableDiffusionPipeline
48 | else:
49 | Pipeline = PPlusStableDiffusionPipeline
50 | pipe = Pipeline.from_learned_embed(
51 | pretrained_model_name_or_path=args.pretrained_model_name_or_path,
52 | learned_embed_name_or_path=args.learned_embed_name_or_path,
53 | torch_dtype=torch.float16 if args.float16 else None,
54 | ).to(args.device)
55 | else:
56 | print("loading the original pipeline")
57 | pipe = StableDiffusionPipeline.from_pretrained(args.pretrained_model_name_or_path, torch_dtype=torch.float16 if args.float16 else None).to(args.device)
58 | pipe.scheduler = DPMSolverMultistepScheduler.from_pretrained(args.pretrained_model_name_or_path, subfolder="scheduler")
59 | if is_xformers_available():
60 | pipe.enable_xformers_memory_efficient_attention()
61 | print("loaded pipeline")
62 | # run!
63 | generator = None
64 | if args.seed:
65 | print(f"Using seed: {args.seed}")
66 | generator = torch.Generator(device=args.device).manual_seed(args.seed)
67 | images = pipe(
68 | args.prompt,
69 | num_inference_steps=args.num_inference_steps,
70 | guidance_scale=args.guidance_scale,
71 | generator=generator,
72 | num_images_per_prompt=args.num_images_per_prompt,
73 | height=args.height,
74 | width=args.width
75 | ).images
76 | grid_image = image_grid(images, 1, args.num_images_per_prompt)
77 | grid_image.save("grid.png")
78 | print("DONE!")
79 |
80 |
81 | if __name__ == '__main__':
82 | main()
83 |
84 |
--------------------------------------------------------------------------------
/prompt_plus/__init__.py:
--------------------------------------------------------------------------------
1 | from .prompt_plus_unet_2d_condition import PPlusUNet2DConditionModel
2 | from .prompt_plus_pipeline_stable_diffusion import PPlusStableDiffusionPipeline, TextualInversionStableDiffusionPipeline
3 |
--------------------------------------------------------------------------------
/prompt_plus/prompt_plus_pipeline_stable_diffusion.py:
--------------------------------------------------------------------------------
1 | import os
2 | import json
3 | from typing import Optional, List, Union, Callable, Dict, Any, Tuple
4 | import torch
5 | from transformers import CLIPTextModel, CLIPTokenizer
6 | from diffusers import StableDiffusionPipeline
7 | from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion import StableDiffusionPipelineOutput
8 | from diffusers.utils import logging
9 | from huggingface_hub import hf_hub_download
10 | from prompt_plus.prompt_plus_unet_2d_condition import PPlusUNet2DConditionModel
11 |
12 |
13 | logger = logging.get_logger(__name__)
14 |
15 |
16 | class TextualInversionStableDiffusionPipeline(StableDiffusionPipeline):
17 | @classmethod
18 | def from_learned_embed(
19 | cls,
20 | pretrained_model_name_or_path: Union[str, os.PathLike],
21 | learned_embed_name_or_path: Union[str, os.PathLike],
22 | **kwargs
23 | ):
24 | if os.path.exists(learned_embed_name_or_path):
25 | embeds_path = os.path.join(learned_embed_name_or_path, "learned_embeds.bin") if os.path.isdir(learned_embed_name_or_path) else learned_embed_name_or_path
26 | # token_path = os.path.join(model_dir, "token_identifier.txt")
27 | else:
28 | # download
29 | embeds_path = hf_hub_download(repo_id=learned_embed_name_or_path, filename="learned_embeds.bin")
30 | # token_path = hf_hub_download(repo_id=learned_embed_name_or_path, filename="token_identifier.txt")
31 |
32 | text_encoder = CLIPTextModel.from_pretrained(
33 | pretrained_model_name_or_path, subfolder="text_encoder", **kwargs
34 | )
35 | tokenizer = CLIPTokenizer.from_pretrained(pretrained_model_name_or_path, subfolder="tokenizer", **kwargs)
36 | loaded_learned_embeds = torch.load(embeds_path, map_location="cpu")
37 | # separate token and the embeds
38 | trained_token = list(loaded_learned_embeds.keys())[0]
39 | embeds = loaded_learned_embeds[trained_token]
40 |
41 | # cast to dtype of text_encoder
42 | dtype = text_encoder.get_input_embeddings().weight.dtype
43 | embeds = embeds.to(dtype)
44 |
45 | # add the token in tokenizer
46 | # token = token if token is not None else trained_token
47 | num_added_tokens = tokenizer.add_tokens(trained_token)
48 | if num_added_tokens == 0:
49 | raise ValueError(
50 | f"The tokenizer already contains the token {trained_token}. Please pass a different `token` that is not already in the tokenizer.")
51 |
52 | # resize the token embeddings
53 | text_encoder.resize_token_embeddings(len(tokenizer))
54 |
55 | # get the id for the token and assign the embeds
56 | token_id = tokenizer.convert_tokens_to_ids(trained_token)
57 | text_encoder.get_input_embeddings().weight.data[token_id] = embeds
58 | print(f"placeholder_token: {trained_token}")
59 | return super().from_pretrained(
60 | pretrained_model_name_or_path=pretrained_model_name_or_path,
61 | text_encoder=text_encoder,
62 | tokenizer=tokenizer,
63 | **kwargs
64 | )
65 |
66 |
67 | def _load_embed_from_name_or_path(learned_embed_name_or_path):
68 | if os.path.exists(learned_embed_name_or_path):
69 | embeds_path = os.path.join(learned_embed_name_or_path, "learned_embeds.bin") if os.path.isdir(
70 | learned_embed_name_or_path) else learned_embed_name_or_path
71 | # config_path = os.path.join(model_dir, "config.json")
72 | else:
73 | # download
74 | embeds_path = hf_hub_download(repo_id=learned_embed_name_or_path, filename="learned_embeds.bin")
75 | # config_path = hf_hub_download(repo_id=pretrained_model_name_or_path, filename="config.json")
76 | # with open(config_path, "r", encoding="utf-8") as f:
77 | # config = json.load(f)
78 | # load
79 | loaded_learned_embeds = torch.load(embeds_path, map_location="cpu")
80 | return loaded_learned_embeds
81 |
82 |
83 | def load_embed_from_name_or_path(learned_embed_name_or_path, style_mixing_k_K=None):
84 | if isinstance(learned_embed_name_or_path, str):
85 | assert style_mixing_k_K is None, "You passed only one learned embed, but `style_mixing_k_K` was specified!"
86 | return _load_embed_from_name_or_path(learned_embed_name_or_path)
87 | else:
88 | assert len(learned_embed_name_or_path) == 2, "Only 2 embeds are supported for now."
89 | k, K = style_mixing_k_K
90 | embeds = []
91 | for p in learned_embed_name_or_path:
92 | embeds.append(_load_embed_from_name_or_path(p))
93 | # use the first embed's tokens as the reference ordering
94 | tokens = list(embeds[0].keys())
95 | n = len(tokens)
96 | assert k < n, f"k must be lower than n={n}"
97 | assert K < n, f"K must be lower than n={n}"
98 | loaded_learned_embeds = dict()
99 | for i in range(n):
100 | if i <= k or K > i:
101 | embed_idx = 0
102 | else:
103 | embed_idx = 1
104 | embed = list(embeds[embed_idx].values())[i]
105 | loaded_learned_embeds[tokens[i]] = embed
106 | return loaded_learned_embeds
107 |
108 |
109 | class PPlusStableDiffusionPipeline(StableDiffusionPipeline):
110 | @classmethod
111 | def from_learned_embed(
112 | cls,
113 | pretrained_model_name_or_path: Union[str, os.PathLike],
114 | learned_embed_name_or_path: Optional[Union[str, os.PathLike, List[str]]] = None,
115 | style_mixing_k_K: Optional[Tuple[int]] = None,
116 | loaded_learned_embeds: Optional[Dict[str, torch.Tensor]] = None,
117 | **kwargs,
118 | ):
119 | text_encoder = CLIPTextModel.from_pretrained(
120 | pretrained_model_name_or_path, subfolder="text_encoder", **kwargs
121 | )
122 | tokenizer = CLIPTokenizer.from_pretrained(pretrained_model_name_or_path, subfolder="tokenizer", **kwargs)
123 | if loaded_learned_embeds is None:
124 | loaded_learned_embeds = load_embed_from_name_or_path(learned_embed_name_or_path, style_mixing_k_K)
125 | new_tokens = list(loaded_learned_embeds.keys())
126 | # easy validation for textual inversion
127 | assert len(new_tokens) > 1, "You might want to use TextualInversionStableDiffusionPipeline instead!"
128 | # cast to dtype of text_encoder
129 | dtype = text_encoder.get_input_embeddings().weight.dtype
130 | # resize the token embeddings
131 | text_encoder.resize_token_embeddings(len(tokenizer)+len(new_tokens))
132 |
133 | for token in new_tokens:
134 | embeds = loaded_learned_embeds[token]
135 | embeds = embeds.to(dtype)
136 | # add the token in tokenizer
137 | # token = token if token is not None else trained_token
138 | num_added_tokens = tokenizer.add_tokens(token)
139 | if num_added_tokens == 0:
140 | raise ValueError(
141 | f"The tokenizer already contains the token {token}. Please pass a different `token` that is not already in the tokenizer.")
142 | # get the id for the token and assign the embeds
143 | token_id = tokenizer.convert_tokens_to_ids(token)
144 | text_encoder.get_input_embeddings().weight.data[token_id] = embeds
145 | # store placeholder_token to text_encoder config
146 | text_encoder.config.placeholder_token = "-".join(new_tokens[0].split("-")[:-1])
147 | text_encoder.config.placeholder_tokens = new_tokens
148 | print(f"placeholder_token: {text_encoder.config.placeholder_token}")
149 | unet = PPlusUNet2DConditionModel.from_pretrained(pretrained_model_name_or_path, subfolder="unet", **kwargs)
150 | return super().from_pretrained(
151 | pretrained_model_name_or_path=pretrained_model_name_or_path,
152 | unet=unet,
153 | text_encoder=text_encoder,
154 | tokenizer=tokenizer,
155 | **kwargs
156 | )
157 |
158 | def _encode_prompt(
159 | self,
160 | prompt,
161 | device,
162 | num_images_per_prompt,
163 | do_classifier_free_guidance,
164 | negative_prompt=None,
165 | prompt_embeds: Optional[torch.FloatTensor] = None,
166 | negative_prompt_embeds: Optional[torch.FloatTensor] = None,
167 | ):
168 | assert isinstance(prompt, str), "Currently, only string `prompt` is supported!"
169 | if prompt is not None and isinstance(prompt, str):
170 | batch_size = 1
171 | elif prompt is not None and isinstance(prompt, list):
172 | batch_size = len(prompt)
173 | else:
174 | batch_size = prompt_embeds.shape[0]
175 |
176 | if prompt_embeds is None:
177 | encoder_hidden_states_list = []
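# P+: build one encoded prompt per placeholder token. Each placeholder token holds the
# embedding learned for a single cross-attention layer, so this list is later consumed
# entry-by-entry by PPlusUNet2DConditionModel (see `select_idx` in its forward()).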
178 | for token in self.text_encoder.config.placeholder_tokens:
179 | one_prompt = prompt.replace(self.text_encoder.config.placeholder_token, token)
180 | text_inputs = self.tokenizer(
181 | one_prompt,
182 | padding="max_length",
183 | max_length=self.tokenizer.model_max_length,
184 | truncation=True,
185 | return_tensors="pt",
186 | )
187 | text_input_ids = text_inputs.input_ids
188 | untruncated_ids = self.tokenizer(one_prompt, padding="longest", return_tensors="pt").input_ids
189 |
190 | if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(
191 | text_input_ids, untruncated_ids
192 | ):
193 | removed_text = self.tokenizer.batch_decode(
194 | untruncated_ids[:, self.tokenizer.model_max_length - 1: -1]
195 | )
196 | logger.warning(
197 | "The following part of your input was truncated because CLIP can only handle sequences up to"
198 | f" {self.tokenizer.model_max_length} tokens: {removed_text}"
199 | )
200 |
201 | if hasattr(self.text_encoder.config,
202 | "use_attention_mask") and self.text_encoder.config.use_attention_mask:
203 | attention_mask = text_inputs.attention_mask.to(device)
204 | else:
205 | attention_mask = None
206 |
207 | prompt_embeds = self.text_encoder(
208 | text_input_ids.to(device),
209 | attention_mask=attention_mask,
210 | )
211 | prompt_embeds = prompt_embeds[0]
212 | prompt_embeds = prompt_embeds.to(dtype=self.text_encoder.dtype, device=device)
213 |
214 | bs_embed, seq_len, _ = prompt_embeds.shape
215 | # duplicate text embeddings for each generation per prompt, using mps friendly method
216 | prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
217 | prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1)
218 |
219 | # get unconditional embeddings for classifier free guidance
220 | if do_classifier_free_guidance:
221 | uncond_tokens: List[str]
222 | if negative_prompt is None:
223 | uncond_tokens = [""] * batch_size
224 | elif type(prompt) is not type(negative_prompt):
225 | raise TypeError(
226 | f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
227 | f" {type(prompt)}."
228 | )
229 | elif isinstance(negative_prompt, str):
230 | uncond_tokens = [negative_prompt]
231 | elif batch_size != len(negative_prompt):
232 | raise ValueError(
233 | f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
234 | f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
235 | " the batch size of `prompt`."
236 | )
237 | else:
238 | uncond_tokens = negative_prompt
239 |
240 | max_length = prompt_embeds.shape[1]
241 | uncond_input = self.tokenizer(
242 | uncond_tokens,
243 | padding="max_length",
244 | max_length=max_length,
245 | truncation=True,
246 | return_tensors="pt",
247 | )
248 |
249 | if hasattr(self.text_encoder.config,
250 | "use_attention_mask") and self.text_encoder.config.use_attention_mask:
251 | attention_mask = uncond_input.attention_mask.to(device)
252 | else:
253 | attention_mask = None
254 |
255 | negative_prompt_embeds = self.text_encoder(
256 | uncond_input.input_ids.to(device),
257 | attention_mask=attention_mask,
258 | )
259 | negative_prompt_embeds = negative_prompt_embeds[0]
260 |
261 | # duplicate unconditional embeddings for each generation per prompt, using mps friendly method
262 | seq_len = negative_prompt_embeds.shape[1]
263 |
264 | negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.text_encoder.dtype, device=device)
265 |
266 | negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
267 | negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len,
268 | -1)
269 |
270 | # For classifier free guidance, we need to do two forward passes.
271 | # Here we concatenate the unconditional and text embeddings into a single batch
272 | # to avoid doing two forward passes
273 | prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])
274 |
275 | encoder_hidden_states_list.append(prompt_embeds)
276 | else:
277 | # trust the caller-provided prompt embeddings as-is
278 | encoder_hidden_states_list = prompt_embeds
279 | return encoder_hidden_states_list
280 |
281 | @torch.no_grad()
282 | def __call__(
283 | self,
284 | prompt: Union[str, List[str]] = None,
285 | height: Optional[int] = None,
286 | width: Optional[int] = None,
287 | num_inference_steps: int = 50,
288 | guidance_scale: float = 7.5,
289 | negative_prompt: Optional[Union[str, List[str]]] = None,
290 | num_images_per_prompt: Optional[int] = 1,
291 | eta: float = 0.0,
292 | generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
293 | latents: Optional[torch.FloatTensor] = None,
294 | prompt_embeds: Optional[torch.FloatTensor] = None,
295 | negative_prompt_embeds: Optional[torch.FloatTensor] = None,
296 | output_type: Optional[str] = "pil",
297 | return_dict: bool = True,
298 | callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
299 | callback_steps: int = 1,
300 | cross_attention_kwargs: Optional[Dict[str, Any]] = None,
301 | ):
302 | # 0. Default height and width to unet
303 | height = height or self.unet.config.sample_size * self.vae_scale_factor
304 | width = width or self.unet.config.sample_size * self.vae_scale_factor
305 |
306 | # 1. Check inputs. Raise error if not correct
307 | self.check_inputs(
308 | prompt, height, width, callback_steps, negative_prompt, prompt_embeds, negative_prompt_embeds
309 | )
310 |
311 | # 2. Define call parameters
312 | if prompt is not None and isinstance(prompt, str):
313 | batch_size = 1
314 | elif prompt is not None and isinstance(prompt, list):
315 | batch_size = len(prompt)
316 | else:
317 | batch_size = prompt_embeds.shape[0]
318 |
319 | device = self._execution_device
320 | # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
321 | # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
322 | # corresponds to doing no classifier free guidance.
323 | do_classifier_free_guidance = guidance_scale > 1.0
324 |
325 | # 3. Encode input prompt
326 | encoder_hidden_states_list = self._encode_prompt(
327 | prompt,
328 | device,
329 | num_images_per_prompt,
330 | do_classifier_free_guidance,
331 | negative_prompt,
332 | prompt_embeds=prompt_embeds,
333 | negative_prompt_embeds=negative_prompt_embeds,
334 | )
335 |
336 | # 4. Prepare timesteps
337 | self.scheduler.set_timesteps(num_inference_steps, device=device)
338 | timesteps = self.scheduler.timesteps
339 |
340 | # 5. Prepare latent variables
341 | num_channels_latents = self.unet.in_channels
342 | latents = self.prepare_latents(
343 | batch_size * num_images_per_prompt,
344 | num_channels_latents,
345 | height,
346 | width,
347 | encoder_hidden_states_list[0].dtype,
348 | device,
349 | generator,
350 | latents,
351 | )
352 |
353 | # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
354 | extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
355 |
356 | # 7. Denoising loop
357 | num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
358 | with self.progress_bar(total=num_inference_steps) as progress_bar:
359 | for i, t in enumerate(timesteps):
360 | # expand the latents if we are doing classifier free guidance
361 | latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
362 | latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
363 |
364 | # predict the noise residual
365 | noise_pred = self.unet(
366 | latent_model_input,
367 | t,
368 | encoder_hidden_states_list=encoder_hidden_states_list,
369 | cross_attention_kwargs=cross_attention_kwargs,
370 | ).sample
371 |
372 | # perform guidance
373 | if do_classifier_free_guidance:
374 | noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
375 | noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
376 |
377 | # compute the previous noisy sample x_t -> x_t-1
378 | latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample
379 |
380 | # call the callback, if provided
381 | if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
382 | progress_bar.update()
383 | if callback is not None and i % callback_steps == 0:
384 | callback(i, t, latents)
385 |
386 | if output_type == "latent":
387 | image = latents
388 | has_nsfw_concept = None
389 | elif output_type == "pil":
390 | # 8. Post-processing
391 | image = self.decode_latents(latents)
392 |
393 | # 9. Run safety checker
394 | image, has_nsfw_concept = self.run_safety_checker(image, device, encoder_hidden_states_list[0].dtype)
395 |
396 | # 10. Convert to PIL
397 | image = self.numpy_to_pil(image)
398 | else:
399 | # 8. Post-processing
400 | image = self.decode_latents(latents)
401 |
402 | # 9. Run safety checker
403 | image, has_nsfw_concept = self.run_safety_checker(image, device, encoder_hidden_states_list[0].dtype)
404 |
405 | # Offload last model to CPU
406 | if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None:
407 | self.final_offload_hook.offload()
408 |
409 | if not return_dict:
410 | return (image, has_nsfw_concept)
411 |
412 | return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept)
413 |
414 |
--------------------------------------------------------------------------------
/prompt_plus/prompt_plus_unet_2d_condition.py:
--------------------------------------------------------------------------------
1 | from typing import Optional, Union, Any, Dict, Tuple, List
2 | import torch
3 | from diffusers import UNet2DConditionModel
4 | from diffusers.models.unet_2d_condition import UNet2DConditionOutput
5 | from diffusers.utils import logging
6 |
7 |
8 | logger = logging.get_logger(__name__)
9 |
10 |
11 | class PPlusUNet2DConditionModel(UNet2DConditionModel):
12 | def forward(
13 | self,
14 | sample: torch.FloatTensor,
15 | timestep: Union[torch.Tensor, float, int],
16 | encoder_hidden_states: torch.Tensor = None,
17 | #########################################
18 | encoder_hidden_states_list: List[torch.Tensor] = None,
19 | #########################################
20 | class_labels: Optional[torch.Tensor] = None,
21 | timestep_cond: Optional[torch.Tensor] = None,
22 | attention_mask: Optional[torch.Tensor] = None,
23 | cross_attention_kwargs: Optional[Dict[str, Any]] = None,
24 | down_block_additional_residuals: Optional[Tuple[torch.Tensor]] = None,
25 | mid_block_additional_residual: Optional[torch.Tensor] = None,
26 | return_dict: bool = True,
27 | ):
28 | if encoder_hidden_states is None and encoder_hidden_states_list is None:
29 | raise ValueError("You must input either `encoder_hidden_states` or `encoder_hidden_states_list`!")
30 | if encoder_hidden_states_list is not None:
31 | select_idx = 0
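# `select_idx` walks through `encoder_hidden_states_list`: each cross-attention
# down/mid/up block below consumes the next per-layer text embedding in order.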
32 |
33 | # By default samples have to be at least a multiple of the overall upsampling factor.
34 | # The overall upsampling factor is equal to 2 ** (# num of upsampling layers).
35 | # However, the upsampling interpolation output size can be forced to fit any upsampling size
36 | # on the fly if necessary.
37 | default_overall_up_factor = 2 ** self.num_upsamplers
38 |
39 | # upsample size should be forwarded when sample is not a multiple of `default_overall_up_factor`
40 | forward_upsample_size = False
41 | upsample_size = None
42 |
43 | if any(s % default_overall_up_factor != 0 for s in sample.shape[-2:]):
44 | logger.info("Forward upsample size to force interpolation output size.")
45 | forward_upsample_size = True
46 |
47 | # prepare attention_mask
48 | if attention_mask is not None:
49 | attention_mask = (1 - attention_mask.to(sample.dtype)) * -10000.0
50 | attention_mask = attention_mask.unsqueeze(1)
51 |
52 | # 0. center input if necessary
53 | if self.config.center_input_sample:
54 | sample = 2 * sample - 1.0
55 |
56 | # 1. time
57 | timesteps = timestep
58 | if not torch.is_tensor(timesteps):
59 | # TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can
60 | # This would be a good case for the `match` statement (Python 3.10+)
61 | is_mps = sample.device.type == "mps"
62 | if isinstance(timestep, float):
63 | dtype = torch.float32 if is_mps else torch.float64
64 | else:
65 | dtype = torch.int32 if is_mps else torch.int64
66 | timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device)
67 | elif len(timesteps.shape) == 0:
68 | timesteps = timesteps[None].to(sample.device)
69 |
70 | # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
71 | timesteps = timesteps.expand(sample.shape[0])
72 |
73 | t_emb = self.time_proj(timesteps)
74 |
75 | # timesteps does not contain any weights and will always return f32 tensors
76 | # but time_embedding might actually be running in fp16. so we need to cast here.
77 | # there might be better ways to encapsulate this.
78 | t_emb = t_emb.to(dtype=self.dtype)
79 |
80 | emb = self.time_embedding(t_emb, timestep_cond)
81 |
82 | if self.class_embedding is not None:
83 | if class_labels is None:
84 | raise ValueError("class_labels should be provided when num_class_embeds > 0")
85 |
86 | if self.config.class_embed_type == "timestep":
87 | class_labels = self.time_proj(class_labels)
88 |
89 | class_emb = self.class_embedding(class_labels).to(dtype=self.dtype)
90 | emb = emb + class_emb
91 |
92 | # 2. pre-process
93 | sample = self.conv_in(sample)
94 |
95 | # 3. down
96 | down_block_res_samples = (sample,)
97 | for downsample_block in self.down_blocks:
98 | if hasattr(downsample_block, "has_cross_attention") and downsample_block.has_cross_attention:
99 | sample, res_samples = downsample_block(
100 | hidden_states=sample,
101 | temb=emb,
102 | encoder_hidden_states=encoder_hidden_states if encoder_hidden_states_list is None else encoder_hidden_states_list[select_idx],
103 | attention_mask=attention_mask,
104 | cross_attention_kwargs=cross_attention_kwargs,
105 | )
106 | if encoder_hidden_states_list is not None:
107 | select_idx += 1
108 | else:
109 | sample, res_samples = downsample_block(hidden_states=sample, temb=emb)
110 |
111 | down_block_res_samples += res_samples
112 |
113 | if down_block_additional_residuals is not None:
114 | new_down_block_res_samples = ()
115 |
116 | for down_block_res_sample, down_block_additional_residual in zip(
117 | down_block_res_samples, down_block_additional_residuals
118 | ):
119 | down_block_res_sample = down_block_res_sample + down_block_additional_residual
120 | new_down_block_res_samples += (down_block_res_sample,)
121 |
122 | down_block_res_samples = new_down_block_res_samples
123 |
124 | # 4. mid
125 | if self.mid_block is not None:
126 | sample = self.mid_block(
127 | sample,
128 | emb,
129 | encoder_hidden_states=encoder_hidden_states if encoder_hidden_states_list is None else
130 | encoder_hidden_states_list[select_idx],
131 | attention_mask=attention_mask,
132 | cross_attention_kwargs=cross_attention_kwargs,
133 | )
134 | if encoder_hidden_states_list is not None:
135 | select_idx += 1
136 |
137 | if mid_block_additional_residual is not None:
138 | sample = sample + mid_block_additional_residual
139 |
140 | # 5. up
141 | for i, upsample_block in enumerate(self.up_blocks):
142 | is_final_block = i == len(self.up_blocks) - 1
143 |
144 | res_samples = down_block_res_samples[-len(upsample_block.resnets):]
145 | down_block_res_samples = down_block_res_samples[: -len(upsample_block.resnets)]
146 |
147 | # if we have not reached the final block and need to forward the
148 | # upsample size, we do it here
149 | if not is_final_block and forward_upsample_size:
150 | upsample_size = down_block_res_samples[-1].shape[2:]
151 |
152 | if hasattr(upsample_block, "has_cross_attention") and upsample_block.has_cross_attention:
153 | sample = upsample_block(
154 | hidden_states=sample,
155 | temb=emb,
156 | res_hidden_states_tuple=res_samples,
157 | encoder_hidden_states=encoder_hidden_states if encoder_hidden_states_list is None else
158 | encoder_hidden_states_list[select_idx],
159 | cross_attention_kwargs=cross_attention_kwargs,
160 | upsample_size=upsample_size,
161 | attention_mask=attention_mask,
162 | )
163 | if encoder_hidden_states_list is not None:
164 | select_idx += 1
165 | else:
166 | sample = upsample_block(
167 | hidden_states=sample, temb=emb, res_hidden_states_tuple=res_samples, upsample_size=upsample_size
168 | )
169 |
170 | # 6. post-process
171 | if self.conv_norm_out:
172 | sample = self.conv_norm_out(sample)
173 | sample = self.conv_act(sample)
174 | sample = self.conv_out(sample)
175 |
176 | if not return_dict:
177 | return (sample,)
178 |
179 | return UNet2DConditionOutput(sample=sample)
180 |
181 |
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | diffusers[torch]
2 | accelerate
3 | torchvision
4 | transformers>=4.25.1
5 | ftfy
6 | tensorboard
7 | Jinja2
8 | wandb
9 | natsort
10 | safetensors
11 | datasets
12 | bitsandbytes
--------------------------------------------------------------------------------
/scripts/app.py:
--------------------------------------------------------------------------------
1 | import sys
2 | import os
3 | from functools import lru_cache
4 | import subprocess
5 | import torch
6 | import gradio as gr
7 | from diffusers.utils import is_xformers_available
8 |
9 |
10 | device = "cuda" if torch.cuda.is_available() else "cpu"
11 | print(f"device: {device}")
12 |
13 |
14 | def gitclone(url, target_dir=None, branch_arg=None):
15 | run_args = ["git", "clone"]
16 | if branch_arg:
17 | run_args.extend(["-b", branch_arg])
18 | run_args.append(url)
19 | if target_dir:
20 | run_args.append(target_dir)
21 | res = subprocess.run(run_args, stdout=subprocess.PIPE).stdout.decode("utf-8")
22 | print(res)
23 |
24 |
25 | def pipi(modulestr):
26 | res = subprocess.run(
27 | ["pip", "install", modulestr], stdout=subprocess.PIPE
28 | ).stdout.decode("utf-8")
29 | print(res)
30 |
31 |
32 | try:
33 | proj_dir = os.path.dirname(__file__)
34 | sys.path.append(proj_dir)
35 | from prompt_plus import PPlusStableDiffusionPipeline
36 | except ImportError:
37 | GITHUB_SECRET = os.environ.get("GITHUB_SECRET")
38 | gitclone("https://github.com/mkshing/prompt-plus-pytorch" if GITHUB_SECRET is None else f"https://{GITHUB_SECRET}@github.com/mkshing/prompt-plus-pytorch")
39 | from prompt_plus import PPlusStableDiffusionPipeline
40 |
41 |
42 | @lru_cache(maxsize=3)
43 | def load_pipe(pretrained_model_name_or_path, learned_embed_name_or_path):
44 | pipe = PPlusStableDiffusionPipeline.from_learned_embed(
45 | pretrained_model_name_or_path=pretrained_model_name_or_path,
46 | learned_embed_name_or_path=learned_embed_name_or_path,
47 | revision="fp16", torch_dtype=torch.float16
48 | )
49 | if is_xformers_available():
50 | pipe.enable_xformers_memory_efficient_attention()
51 | return pipe
52 |
53 |
54 | def txt2img_func(pretrained_model_name_or_path, learned_embed_name_or_path, prompt, n_samples=4, scale=7.5, steps=25, width=512, height=512, seed="random"):
55 | n_samples = int(n_samples)
56 | scale = float(scale)
57 | steps = int(steps)
58 | width = int(width)
59 | height = int(height)
60 | generator = torch.Generator(device=device)
61 | if seed == "random":
62 | seed = generator.seed()
63 | else:
64 | seed = int(seed)
65 | generator = generator.manual_seed(int(seed))
66 | pipe = load_pipe(pretrained_model_name_or_path, learned_embed_name_or_path).to(device)
67 | images = pipe(
68 | prompt,
69 | num_inference_steps=steps,
70 | guidance_scale=scale,
71 | generator=generator,
72 | num_images_per_prompt=n_samples,
73 | height=height,
74 | width=width
75 | ).images
76 | return images
77 |
78 |
79 | with gr.Blocks() as demo:
80 | gr.Markdown("# P+: Extended Textual Conditioning in Text-to-Image Generation")
81 | pretrained_model_name_or_path = gr.Textbox(label="pre-trained model name or path", value="runwayml/stable-diffusion-v1-5")
82 | learned_embed_name_or_path = gr.Textbox(label="learned embedding name or path")
83 | with gr.Row():
84 | with gr.Column():
85 | # input
86 | prompt = gr.Textbox(label="Prompt")
87 | n_samples = gr.Number(value=3, label="n_samples")
88 | cfg_scale = gr.Slider(minimum=0.0, maximum=20, value=7.5, label="cfg_scale", step=0.5)
89 | steps = gr.Number(value=30, label="steps")
90 | width = gr.Slider(minimum=128, maximum=1024, value=512, label="width", step=64)
91 | height = gr.Slider(minimum=128, maximum=1024, value=512, label="height", step=64)
92 | seed = gr.Textbox(value='random',
93 | placeholder="If you fix the seed, you get the same outputs every time. Set an integer such as 42.",
94 | label="seed")
95 |
96 | # button
97 | button = gr.Button(value="Generate!")
98 | with gr.Column():
99 | # output
100 | out_images = gr.Gallery(label="Output")
101 | button.click(
102 | txt2img_func,
103 | inputs=[pretrained_model_name_or_path, learned_embed_name_or_path, prompt, n_samples, cfg_scale, steps, width, height, seed],
104 | outputs=[out_images],
105 | api_name="txt2img"
106 | )
107 |
108 | demo.launch()
109 |
--------------------------------------------------------------------------------
/scripts/textual_inversion.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | # coding=utf-8
3 | # Copyright 2023 The HuggingFace Inc. team. All rights reserved.
4 | #
5 | # Licensed under the Apache License, Version 2.0 (the "License");
6 | # you may not use this file except in compliance with the License.
7 | # You may obtain a copy of the License at
8 | #
9 | # http://www.apache.org/licenses/LICENSE-2.0
10 | #
11 | # Unless required by applicable law or agreed to in writing, software
12 | # distributed under the License is distributed on an "AS IS" BASIS,
13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 | # See the License for the specific language governing permissions and
15 | # limitations under the License.
16 | import argparse
17 | import logging
18 | import math
19 | import os
20 | import random
21 | import warnings
22 | from pathlib import Path
23 | from typing import Optional
24 |
25 | import numpy as np
26 | import PIL
27 | import torch
28 | import torch.nn.functional as F
29 | import torch.utils.checkpoint
30 | import transformers
31 | from accelerate import Accelerator
32 | from accelerate.logging import get_logger
33 | from accelerate.utils import ProjectConfiguration, set_seed
34 | from huggingface_hub import HfFolder, Repository, create_repo, whoami
35 |
36 | # TODO: remove and import from diffusers.utils when the new version of diffusers is released
37 | from packaging import version
38 | from PIL import Image
39 | from torch.utils.data import Dataset
40 | from torchvision import transforms
41 | from tqdm.auto import tqdm
42 | from transformers import CLIPTextModel, CLIPTokenizer
43 |
44 | import diffusers
45 | from diffusers import (
46 | AutoencoderKL,
47 | DDPMScheduler,
48 | DiffusionPipeline,
49 | DPMSolverMultistepScheduler,
50 | StableDiffusionPipeline,
51 | UNet2DConditionModel,
52 | )
53 | from diffusers.optimization import get_scheduler
54 | from diffusers.utils import check_min_version, is_wandb_available
55 | from diffusers.utils.import_utils import is_xformers_available
56 |
57 |
58 | if is_wandb_available():
59 | import wandb
60 |
61 | if version.parse(version.parse(PIL.__version__).base_version) >= version.parse("9.1.0"):
62 | PIL_INTERPOLATION = {
63 | "linear": PIL.Image.Resampling.BILINEAR,
64 | "bilinear": PIL.Image.Resampling.BILINEAR,
65 | "bicubic": PIL.Image.Resampling.BICUBIC,
66 | "lanczos": PIL.Image.Resampling.LANCZOS,
67 | "nearest": PIL.Image.Resampling.NEAREST,
68 | }
69 | else:
70 | PIL_INTERPOLATION = {
71 | "linear": PIL.Image.LINEAR,
72 | "bilinear": PIL.Image.BILINEAR,
73 | "bicubic": PIL.Image.BICUBIC,
74 | "lanczos": PIL.Image.LANCZOS,
75 | "nearest": PIL.Image.NEAREST,
76 | }
77 | # ------------------------------------------------------------------------------
78 |
79 |
80 | # Will error if the minimal version of diffusers is not installed. Remove at your own risks.
81 | # check_min_version("0.15.0.dev0")
82 |
83 | logger = get_logger(__name__)
84 |
85 |
86 | def log_validation(text_encoder, tokenizer, unet, vae, args, accelerator, weight_dtype, epoch):
87 | logger.info(
88 | f"Running validation... \n Generating {args.num_validation_images} images with prompt:"
89 | f" {args.validation_prompt}."
90 | )
91 | # create pipeline (note: unet and vae are loaded again in float32)
92 | pipeline = DiffusionPipeline.from_pretrained(
93 | args.pretrained_model_name_or_path,
94 | text_encoder=accelerator.unwrap_model(text_encoder),
95 | tokenizer=tokenizer,
96 | unet=unet,
97 | vae=vae,
98 | revision=args.revision,
99 | torch_dtype=weight_dtype,
100 | )
101 | pipeline.scheduler = DPMSolverMultistepScheduler.from_config(pipeline.scheduler.config)
102 | pipeline = pipeline.to(accelerator.device)
103 | pipeline.set_progress_bar_config(disable=True)
104 |
105 | # run inference
106 | generator = None if args.seed is None else torch.Generator(device=accelerator.device).manual_seed(args.seed)
107 | images = []
108 | for _ in range(args.num_validation_images):
109 | with torch.autocast("cuda"):
110 | image = pipeline(args.validation_prompt, num_inference_steps=25, generator=generator).images[0]
111 | images.append(image)
112 |
113 | for tracker in accelerator.trackers:
114 | if tracker.name == "tensorboard":
115 | np_images = np.stack([np.asarray(img) for img in images])
116 | tracker.writer.add_images("validation", np_images, epoch, dataformats="NHWC")
117 | if tracker.name == "wandb":
118 | tracker.log(
119 | {
120 | "validation": [
121 | wandb.Image(image, caption=f"{i}: {args.validation_prompt}") for i, image in enumerate(images)
122 | ]
123 | }
124 | )
125 |
126 | del pipeline
127 | torch.cuda.empty_cache()
128 |
129 |
130 | def save_progress(text_encoder, placeholder_token_id, accelerator, args, save_path):
131 | logger.info("Saving embeddings")
132 | learned_embeds = accelerator.unwrap_model(text_encoder).get_input_embeddings().weight[placeholder_token_id]
133 | learned_embeds_dict = {args.placeholder_token: learned_embeds.detach().cpu()}
134 | torch.save(learned_embeds_dict, save_path)
135 |
136 |
137 | def parse_args():
138 | parser = argparse.ArgumentParser(description="Simple example of a training script.")
139 | parser.add_argument(
140 | "--save_steps",
141 | type=int,
142 | default=500,
143 | help="Save learned_embeds.bin every X updates steps.",
144 | )
145 | parser.add_argument(
146 | "--only_save_embeds",
147 | action="store_true",
148 | default=False,
149 | help="Save only the embeddings for the new concept.",
150 | )
151 | parser.add_argument(
152 | "--pretrained_model_name_or_path",
153 | type=str,
154 | default=None,
155 | required=True,
156 | help="Path to pretrained model or model identifier from huggingface.co/models.",
157 | )
158 | parser.add_argument(
159 | "--revision",
160 | type=str,
161 | default=None,
162 | required=False,
163 | help="Revision of pretrained model identifier from huggingface.co/models.",
164 | )
165 | parser.add_argument(
166 | "--tokenizer_name",
167 | type=str,
168 | default=None,
169 | help="Pretrained tokenizer name or path if not the same as model_name",
170 | )
171 | parser.add_argument(
172 | "--train_data_dir", type=str, default=None, required=True, help="A folder containing the training data."
173 | )
174 | parser.add_argument(
175 | "--placeholder_token",
176 | type=str,
177 | default=None,
178 | required=True,
179 | help="A token to use as a placeholder for the concept.",
180 | )
181 | parser.add_argument(
182 | "--initializer_token", type=str, default=None, required=True, help="A token to use as initializer word."
183 | )
184 | parser.add_argument("--learnable_property", type=str, default="object", help="Choose between 'object' and 'style'")
185 | parser.add_argument("--repeats", type=int, default=100, help="How many times to repeat the training data.")
186 | parser.add_argument(
187 | "--output_dir",
188 | type=str,
189 | default="text-inversion-model",
190 | help="The output directory where the model predictions and checkpoints will be written.",
191 | )
192 | parser.add_argument("--seed", type=int, default=None, help="A seed for reproducible training.")
193 | parser.add_argument(
194 | "--resolution",
195 | type=int,
196 | default=512,
197 | help=(
198 | "The resolution for input images, all the images in the train/validation dataset will be resized to this"
199 | " resolution"
200 | ),
201 | )
202 | parser.add_argument(
203 | "--center_crop", action="store_true", help="Whether to center crop images before resizing to resolution."
204 | )
205 | parser.add_argument(
206 | "--train_batch_size", type=int, default=16, help="Batch size (per device) for the training dataloader."
207 | )
208 | parser.add_argument("--num_train_epochs", type=int, default=100)
209 | parser.add_argument(
210 | "--max_train_steps",
211 | type=int,
212 | default=5000,
213 | help="Total number of training steps to perform. If provided, overrides num_train_epochs.",
214 | )
215 | parser.add_argument(
216 | "--gradient_accumulation_steps",
217 | type=int,
218 | default=1,
219 | help="Number of updates steps to accumulate before performing a backward/update pass.",
220 | )
221 | parser.add_argument(
222 | "--gradient_checkpointing",
223 | action="store_true",
224 | help="Whether or not to use gradient checkpointing to save memory at the expense of slower backward pass.",
225 | )
226 | parser.add_argument(
227 | "--learning_rate",
228 | type=float,
229 | default=1e-4,
230 | help="Initial learning rate (after the potential warmup period) to use.",
231 | )
232 | parser.add_argument(
233 | "--scale_lr",
234 | action="store_true",
235 | default=False,
236 | help="Scale the learning rate by the number of GPUs, gradient accumulation steps, and batch size.",
237 | )
238 | parser.add_argument(
239 | "--lr_scheduler",
240 | type=str,
241 | default="constant",
242 | help=(
243 | 'The scheduler type to use. Choose between ["linear", "cosine", "cosine_with_restarts", "polynomial",'
244 | ' "constant", "constant_with_warmup"]'
245 | ),
246 | )
247 | parser.add_argument(
248 | "--lr_warmup_steps", type=int, default=500, help="Number of steps for the warmup in the lr scheduler."
249 | )
250 | parser.add_argument(
251 | "--dataloader_num_workers",
252 | type=int,
253 | default=0,
254 | help=(
255 | "Number of subprocesses to use for data loading. 0 means that the data will be loaded in the main process."
256 | ),
257 | )
258 | parser.add_argument("--adam_beta1", type=float, default=0.9, help="The beta1 parameter for the Adam optimizer.")
259 | parser.add_argument("--adam_beta2", type=float, default=0.999, help="The beta2 parameter for the Adam optimizer.")
260 | parser.add_argument("--adam_weight_decay", type=float, default=1e-2, help="Weight decay to use.")
261 | parser.add_argument("--adam_epsilon", type=float, default=1e-08, help="Epsilon value for the Adam optimizer")
262 | parser.add_argument("--push_to_hub", action="store_true", help="Whether or not to push the model to the Hub.")
263 | parser.add_argument("--hub_token", type=str, default=None, help="The token to use to push to the Model Hub.")
264 | parser.add_argument(
265 | "--hub_model_id",
266 | type=str,
267 | default=None,
268 | help="The name of the repository to keep in sync with the local `output_dir`.",
269 | )
270 | parser.add_argument(
271 | "--logging_dir",
272 | type=str,
273 | default="logs",
274 | help=(
275 | "[TensorBoard](https://www.tensorflow.org/tensorboard) log directory. Will default to"
276 | " *output_dir/runs/**CURRENT_DATETIME_HOSTNAME***."
277 | ),
278 | )
279 | parser.add_argument(
280 | "--mixed_precision",
281 | type=str,
282 | default="no",
283 | choices=["no", "fp16", "bf16"],
284 | help=(
285 | "Whether to use mixed precision. Choose"
286 | "between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >= 1.10."
287 | "and an Nvidia Ampere GPU."
288 | ),
289 | )
290 | parser.add_argument(
291 | "--allow_tf32",
292 | action="store_true",
293 | help=(
294 | "Whether or not to allow TF32 on Ampere GPUs. Can be used to speed up training. For more information, see"
295 | " https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices"
296 | ),
297 | )
298 | parser.add_argument(
299 | "--report_to",
300 | type=str,
301 | default="tensorboard",
302 | help=(
303 | 'The integration to report the results and logs to. Supported platforms are `"tensorboard"`'
304 | ' (default), `"wandb"` and `"comet_ml"`. Use `"all"` to report to all integrations.'
305 | ),
306 | )
307 | parser.add_argument(
308 | "--validation_prompt",
309 | type=str,
310 | default=None,
311 | help="A prompt that is used during validation to verify that the model is learning.",
312 | )
313 | parser.add_argument(
314 | "--num_validation_images",
315 | type=int,
316 | default=4,
317 | help="Number of images that should be generated during validation with `validation_prompt`.",
318 | )
319 | parser.add_argument(
320 | "--validation_steps",
321 | type=int,
322 | default=100,
323 | help=(
324 | "Run validation every X steps. Validation consists of running the prompt"
325 | " `args.validation_prompt` multiple times: `args.num_validation_images`"
326 | " and logging the images."
327 | ),
328 | )
329 | parser.add_argument(
330 | "--validation_epochs",
331 | type=int,
332 | default=None,
333 | help=(
334 | "Deprecated in favor of validation_steps. Run validation every X epochs. Validation consists of running the prompt"
335 | " `args.validation_prompt` multiple times: `args.num_validation_images`"
336 | " and logging the images."
337 | ),
338 | )
339 | parser.add_argument("--local_rank", type=int, default=-1, help="For distributed training: local_rank")
340 | parser.add_argument(
341 | "--checkpointing_steps",
342 | type=int,
343 | default=500,
344 | help=(
345 | "Save a checkpoint of the training state every X updates. These checkpoints are only suitable for resuming"
346 | " training using `--resume_from_checkpoint`."
347 | ),
348 | )
349 | parser.add_argument(
350 | "--checkpoints_total_limit",
351 | type=int,
352 | default=None,
353 | help=(
354 | "Max number of checkpoints to store. Passed as `total_limit` to the `Accelerator` `ProjectConfiguration`."
355 | " See Accelerator::save_state https://huggingface.co/docs/accelerate/package_reference/accelerator#accelerate.Accelerator.save_state"
356 | " for more docs"
357 | ),
358 | )
359 | parser.add_argument(
360 | "--resume_from_checkpoint",
361 | type=str,
362 | default=None,
363 | help=(
364 | "Whether training should be resumed from a previous checkpoint. Use a path saved by"
365 | ' `--checkpointing_steps`, or `"latest"` to automatically select the last available checkpoint.'
366 | ),
367 | )
368 | parser.add_argument(
369 | "--enable_xformers_memory_efficient_attention", action="store_true", help="Whether or not to use xformers."
370 | )
371 |
372 | args = parser.parse_args()
373 | env_local_rank = int(os.environ.get("LOCAL_RANK", -1))
374 | if env_local_rank != -1 and env_local_rank != args.local_rank:
375 | args.local_rank = env_local_rank
376 |
377 | if args.train_data_dir is None:
378 | raise ValueError("You must specify a train data directory.")
379 |
380 | return args
381 |
382 |
383 | imagenet_templates_small = [
384 | "a photo of a {}",
385 | "a rendering of a {}",
386 | "a cropped photo of the {}",
387 | "the photo of a {}",
388 | "a photo of a clean {}",
389 | "a photo of a dirty {}",
390 | "a dark photo of the {}",
391 | "a photo of my {}",
392 | "a photo of the cool {}",
393 | "a close-up photo of a {}",
394 | "a bright photo of the {}",
395 | "a cropped photo of a {}",
396 | "a photo of the {}",
397 | "a good photo of the {}",
398 | "a photo of one {}",
399 | "a close-up photo of the {}",
400 | "a rendition of the {}",
401 | "a photo of the clean {}",
402 | "a rendition of a {}",
403 | "a photo of a nice {}",
404 | "a good photo of a {}",
405 | "a photo of the nice {}",
406 | "a photo of the small {}",
407 | "a photo of the weird {}",
408 | "a photo of the large {}",
409 | "a photo of a cool {}",
410 | "a photo of a small {}",
411 | ]
412 |
413 | imagenet_style_templates_small = [
414 | "a painting in the style of {}",
415 | "a rendering in the style of {}",
416 | "a cropped painting in the style of {}",
417 | "the painting in the style of {}",
418 | "a clean painting in the style of {}",
419 | "a dirty painting in the style of {}",
420 | "a dark painting in the style of {}",
421 | "a picture in the style of {}",
422 | "a cool painting in the style of {}",
423 | "a close-up painting in the style of {}",
424 | "a bright painting in the style of {}",
425 | "a cropped painting in the style of {}",
426 | "a good painting in the style of {}",
427 | "a close-up painting in the style of {}",
428 | "a rendition in the style of {}",
429 | "a nice painting in the style of {}",
430 | "a small painting in the style of {}",
431 | "a weird painting in the style of {}",
432 | "a large painting in the style of {}",
433 | ]
434 |
435 |
436 | class TextualInversionDataset(Dataset):
437 | def __init__(
438 | self,
439 | data_root,
440 | tokenizer,
441 | learnable_property="object", # [object, style]
442 | size=512,
443 | repeats=100,
444 | interpolation="bicubic",
445 | flip_p=0.5,
446 | set="train",
447 | placeholder_token="*",
448 | center_crop=False,
449 | ):
450 | self.data_root = data_root
451 | self.tokenizer = tokenizer
452 | self.learnable_property = learnable_property
453 | self.size = size
454 | self.placeholder_token = placeholder_token
455 | self.center_crop = center_crop
456 | self.flip_p = flip_p
457 |
458 | self.image_paths = [os.path.join(self.data_root, file_path) for file_path in os.listdir(self.data_root)]
459 |
460 | self.num_images = len(self.image_paths)
461 | self._length = self.num_images
462 |
463 | if set == "train":
464 | self._length = self.num_images * repeats
465 |
466 | self.interpolation = {
467 | "linear": PIL_INTERPOLATION["linear"],
468 | "bilinear": PIL_INTERPOLATION["bilinear"],
469 | "bicubic": PIL_INTERPOLATION["bicubic"],
470 | "lanczos": PIL_INTERPOLATION["lanczos"],
471 | }[interpolation]
472 |
473 | self.templates = imagenet_style_templates_small if learnable_property == "style" else imagenet_templates_small
474 | self.flip_transform = transforms.RandomHorizontalFlip(p=self.flip_p)
475 |
476 | def __len__(self):
477 | return self._length
478 |
479 | def __getitem__(self, i):
480 | example = {}
481 | image = Image.open(self.image_paths[i % self.num_images])
482 |
483 | if not image.mode == "RGB":
484 | image = image.convert("RGB")
485 |
486 | placeholder_string = self.placeholder_token
487 | text = random.choice(self.templates).format(placeholder_string)
488 |
489 | example["input_ids"] = self.tokenizer(
490 | text,
491 | padding="max_length",
492 | truncation=True,
493 | max_length=self.tokenizer.model_max_length,
494 | return_tensors="pt",
495 | ).input_ids[0]
496 |
497 | # default to score-sde preprocessing
498 | img = np.array(image).astype(np.uint8)
499 |
500 | if self.center_crop:
501 | crop = min(img.shape[0], img.shape[1])
502 | (
503 | h,
504 | w,
505 | ) = (
506 | img.shape[0],
507 | img.shape[1],
508 | )
509 | img = img[(h - crop) // 2 : (h + crop) // 2, (w - crop) // 2 : (w + crop) // 2]
510 |
511 | image = Image.fromarray(img)
512 | image = image.resize((self.size, self.size), resample=self.interpolation)
513 |
514 | image = self.flip_transform(image)
515 | image = np.array(image).astype(np.uint8)
516 | image = (image / 127.5 - 1.0).astype(np.float32)
517 |
518 | example["pixel_values"] = torch.from_numpy(image).permute(2, 0, 1)
519 | return example
520 |
521 |
522 | def get_full_repo_name(model_id: str, organization: Optional[str] = None, token: Optional[str] = None):
523 | if token is None:
524 | token = HfFolder.get_token()
525 | if organization is None:
526 | username = whoami(token)["name"]
527 | return f"{username}/{model_id}"
528 | else:
529 | return f"{organization}/{model_id}"
530 |
531 |
532 | def main():
533 | args = parse_args()
534 | logging_dir = os.path.join(args.output_dir, args.logging_dir)
535 |
536 | accelerator_project_config = ProjectConfiguration(total_limit=args.checkpoints_total_limit)
537 |
538 | accelerator = Accelerator(
539 | gradient_accumulation_steps=args.gradient_accumulation_steps,
540 | mixed_precision=args.mixed_precision,
541 | log_with=args.report_to,
542 | logging_dir=logging_dir,
543 | project_config=accelerator_project_config,
544 | )
545 |
546 | if args.report_to == "wandb":
547 | if not is_wandb_available():
548 | raise ImportError("Make sure to install wandb if you want to use it for logging during training.")
549 |
550 | # Make one log on every process with the configuration for debugging.
551 | logging.basicConfig(
552 | format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
553 | datefmt="%m/%d/%Y %H:%M:%S",
554 | level=logging.INFO,
555 | )
556 | logger.info(accelerator.state, main_process_only=False)
557 | if accelerator.is_local_main_process:
558 | transformers.utils.logging.set_verbosity_warning()
559 | diffusers.utils.logging.set_verbosity_info()
560 | else:
561 | transformers.utils.logging.set_verbosity_error()
562 | diffusers.utils.logging.set_verbosity_error()
563 |
564 | # If passed along, set the training seed now.
565 | if args.seed is not None:
566 | set_seed(args.seed)
567 |
568 | # Handle the repository creation
569 | if accelerator.is_main_process:
570 | if args.push_to_hub:
571 | if args.hub_model_id is None:
572 | repo_name = get_full_repo_name(Path(args.output_dir).name, token=args.hub_token)
573 | else:
574 | repo_name = args.hub_model_id
575 | create_repo(repo_name, exist_ok=True, token=args.hub_token)
576 | repo = Repository(args.output_dir, clone_from=repo_name, token=args.hub_token)
577 |
578 | with open(os.path.join(args.output_dir, ".gitignore"), "w+") as gitignore:
579 | if "step_*" not in gitignore:
580 | gitignore.write("step_*\n")
581 | if "epoch_*" not in gitignore:
582 | gitignore.write("epoch_*\n")
583 | elif args.output_dir is not None:
584 | os.makedirs(args.output_dir, exist_ok=True)
585 |
586 | # Load tokenizer
587 | if args.tokenizer_name:
588 | tokenizer = CLIPTokenizer.from_pretrained(args.tokenizer_name)
589 | elif args.pretrained_model_name_or_path:
590 | tokenizer = CLIPTokenizer.from_pretrained(args.pretrained_model_name_or_path, subfolder="tokenizer")
591 |
592 | # Load scheduler and models
593 | noise_scheduler = DDPMScheduler.from_pretrained(args.pretrained_model_name_or_path, subfolder="scheduler")
594 | text_encoder = CLIPTextModel.from_pretrained(
595 | args.pretrained_model_name_or_path, subfolder="text_encoder", revision=args.revision
596 | )
597 | vae = AutoencoderKL.from_pretrained(args.pretrained_model_name_or_path, subfolder="vae", revision=args.revision)
598 | unet = UNet2DConditionModel.from_pretrained(
599 | args.pretrained_model_name_or_path, subfolder="unet", revision=args.revision
600 | )
601 |
602 | # Add the placeholder token in tokenizer
603 | num_added_tokens = tokenizer.add_tokens(args.placeholder_token)
604 | if num_added_tokens == 0:
605 | raise ValueError(
606 | f"The tokenizer already contains the token {args.placeholder_token}. Please pass a different"
607 | " `placeholder_token` that is not already in the tokenizer."
608 | )
609 |
610 | # Convert the initializer_token, placeholder_token to ids
611 | token_ids = tokenizer.encode(args.initializer_token, add_special_tokens=False)
612 | # Check if initializer_token is a single token or a sequence of tokens
613 | if len(token_ids) > 1:
614 | raise ValueError("The initializer token must be a single token.")
615 |
616 | initializer_token_id = token_ids[0]
617 | placeholder_token_id = tokenizer.convert_tokens_to_ids(args.placeholder_token)
618 |
619 | # Resize the token embeddings as we are adding new special tokens to the tokenizer
620 | text_encoder.resize_token_embeddings(len(tokenizer))
621 |
622 | # Initialise the newly added placeholder token with the embeddings of the initializer token
623 | token_embeds = text_encoder.get_input_embeddings().weight.data
624 | token_embeds[placeholder_token_id] = token_embeds[initializer_token_id]
625 |
626 | # Freeze vae and unet
627 | vae.requires_grad_(False)
628 | unet.requires_grad_(False)
629 | # Freeze all parameters except for the token embeddings in text encoder
630 | text_encoder.text_model.encoder.requires_grad_(False)
631 | text_encoder.text_model.final_layer_norm.requires_grad_(False)
632 | text_encoder.text_model.embeddings.position_embedding.requires_grad_(False)
633 |
634 | if args.gradient_checkpointing:
635 | # Keep unet in train mode when using gradient checkpointing, since checkpointing only takes effect in train mode.
636 | # Dropout is 0 in these models, so train vs. eval mode does not change the outputs.
637 | unet.train()
638 | text_encoder.gradient_checkpointing_enable()
639 | unet.enable_gradient_checkpointing()
640 |
641 | if args.enable_xformers_memory_efficient_attention:
642 | if is_xformers_available():
643 | import xformers
644 |
645 | xformers_version = version.parse(xformers.__version__)
646 | if xformers_version == version.parse("0.0.16"):
647 | logger.warn(
648 | "xFormers 0.0.16 cannot be used for training in some GPUs. If you observe problems during training, please update xFormers to at least 0.0.17. See https://huggingface.co/docs/diffusers/main/en/optimization/xformers for more details."
649 | )
650 | unet.enable_xformers_memory_efficient_attention()
651 | else:
652 | raise ValueError("xformers is not available. Make sure it is installed correctly")
653 |
654 | # Enable TF32 for faster training on Ampere GPUs,
655 | # cf https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices
656 | if args.allow_tf32:
657 | torch.backends.cuda.matmul.allow_tf32 = True
658 |
659 | if args.scale_lr:
660 | args.learning_rate = (
661 | args.learning_rate * args.gradient_accumulation_steps * args.train_batch_size * accelerator.num_processes
662 | )
663 |
664 | # Initialize the optimizer
665 | optimizer = torch.optim.AdamW(
666 | text_encoder.get_input_embeddings().parameters(), # only optimize the embeddings
667 | lr=args.learning_rate,
668 | betas=(args.adam_beta1, args.adam_beta2),
669 | weight_decay=args.adam_weight_decay,
670 | eps=args.adam_epsilon,
671 | )
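    | # NOTE: the optimizer above nominally covers the whole token-embedding matrix; the
    | # `index_no_updates` block run after each optimizer step in the training loop copies the
    | # original weights back into every row except the placeholder token, so in practice only
    | # the new token's embedding is learned.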
672 |
673 | # Dataset and DataLoaders creation:
674 | train_dataset = TextualInversionDataset(
675 | data_root=args.train_data_dir,
676 | tokenizer=tokenizer,
677 | size=args.resolution,
678 | placeholder_token=args.placeholder_token,
679 | repeats=args.repeats,
680 | learnable_property=args.learnable_property,
681 | center_crop=args.center_crop,
682 | set="train",
683 | )
684 | train_dataloader = torch.utils.data.DataLoader(
685 | train_dataset, batch_size=args.train_batch_size, shuffle=True, num_workers=args.dataloader_num_workers
686 | )
687 | if args.validation_epochs is not None:
688 | warnings.warn(
689 | f"FutureWarning: You are doing logging with validation_epochs={args.validation_epochs}."
690 | " Deprecated validation_epochs in favor of `validation_steps`"
691 | f"Setting `args.validation_steps` to {args.validation_epochs * len(train_dataset)}",
692 | FutureWarning,
693 | stacklevel=2,
694 | )
695 | args.validation_steps = args.validation_epochs * len(train_dataset)
696 |
697 | # Scheduler and math around the number of training steps.
698 | overrode_max_train_steps = False
699 | num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
700 | if args.max_train_steps is None:
701 | args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
702 | overrode_max_train_steps = True
703 |
704 | lr_scheduler = get_scheduler(
705 | args.lr_scheduler,
706 | optimizer=optimizer,
707 | num_warmup_steps=args.lr_warmup_steps * args.gradient_accumulation_steps,
708 | num_training_steps=args.max_train_steps * args.gradient_accumulation_steps,
709 | )
710 |
711 | # Prepare everything with our `accelerator`.
712 | text_encoder, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
713 | text_encoder, optimizer, train_dataloader, lr_scheduler
714 | )
715 |
716 | # For mixed precision training we cast the unet and vae weights to half-precision
717 | # as these models are only used for inference, keeping weights in full precision is not required.
718 | weight_dtype = torch.float32
719 | if accelerator.mixed_precision == "fp16":
720 | weight_dtype = torch.float16
721 | elif accelerator.mixed_precision == "bf16":
722 | weight_dtype = torch.bfloat16
723 |
724 | # Move vae and unet to device and cast to weight_dtype
725 | unet.to(accelerator.device, dtype=weight_dtype)
726 | vae.to(accelerator.device, dtype=weight_dtype)
727 |
728 | # We need to recalculate our total training steps as the size of the training dataloader may have changed.
729 | num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
730 | if overrode_max_train_steps:
731 | args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
732 | # Afterwards we recalculate our number of training epochs
733 | args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)
734 |
735 | # We need to initialize the trackers we use, and also store our configuration.
736 | # The trackers initializes automatically on the main process.
737 | if accelerator.is_main_process:
738 | accelerator.init_trackers("p_plus_xti", config=vars(args))
739 |
740 | # Train!
741 | total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps
742 |
743 | logger.info("***** Running training *****")
744 | logger.info(f" Num examples = {len(train_dataset)}")
745 | logger.info(f" Num Epochs = {args.num_train_epochs}")
746 | logger.info(f" Instantaneous batch size per device = {args.train_batch_size}")
747 | logger.info(f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}")
748 | logger.info(f" Gradient Accumulation steps = {args.gradient_accumulation_steps}")
749 | logger.info(f" Total optimization steps = {args.max_train_steps}")
750 | global_step = 0
751 | first_epoch = 0
752 | # Potentially load in the weights and states from a previous save
753 | if args.resume_from_checkpoint:
754 | if args.resume_from_checkpoint != "latest":
755 | path = os.path.basename(args.resume_from_checkpoint)
756 | else:
757 | # Get the most recent checkpoint
758 | dirs = os.listdir(args.output_dir)
759 | dirs = [d for d in dirs if d.startswith("checkpoint")]
760 | dirs = sorted(dirs, key=lambda x: int(x.split("-")[1]))
761 | path = dirs[-1] if len(dirs) > 0 else None
762 |
763 | if path is None:
764 | accelerator.print(
765 | f"Checkpoint '{args.resume_from_checkpoint}' does not exist. Starting a new training run."
766 | )
767 | args.resume_from_checkpoint = None
768 | else:
769 | accelerator.print(f"Resuming from checkpoint {path}")
770 | accelerator.load_state(os.path.join(args.output_dir, path))
771 | global_step = int(path.split("-")[1])
772 |
773 | resume_global_step = global_step * args.gradient_accumulation_steps
774 | first_epoch = global_step // num_update_steps_per_epoch
775 | resume_step = resume_global_step % (num_update_steps_per_epoch * args.gradient_accumulation_steps)
776 |
777 | # Only show the progress bar once on each machine.
778 | progress_bar = tqdm(range(global_step, args.max_train_steps), disable=not accelerator.is_local_main_process)
779 | progress_bar.set_description("Steps")
780 |
781 | # keep original embeddings as reference
782 | orig_embeds_params = accelerator.unwrap_model(text_encoder).get_input_embeddings().weight.data.clone()
783 |
784 | for epoch in range(first_epoch, args.num_train_epochs):
785 | text_encoder.train()
786 | for step, batch in enumerate(train_dataloader):
787 | # Skip steps until we reach the resumed step
788 | if args.resume_from_checkpoint and epoch == first_epoch and step < resume_step:
789 | if step % args.gradient_accumulation_steps == 0:
790 | progress_bar.update(1)
791 | continue
792 |
793 | with accelerator.accumulate(text_encoder):
794 | # Convert images to latent space
795 | latents = vae.encode(batch["pixel_values"].to(dtype=weight_dtype)).latent_dist.sample().detach()
796 | latents = latents * vae.config.scaling_factor
797 |
798 | # Sample noise that we'll add to the latents
799 | noise = torch.randn_like(latents)
800 | bsz = latents.shape[0]
801 | # Sample a random timestep for each image
802 | timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, (bsz,), device=latents.device)
803 | timesteps = timesteps.long()
804 |
805 | # Add noise to the latents according to the noise magnitude at each timestep
806 | # (this is the forward diffusion process)
807 | noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)
808 |
809 | # Get the text embedding for conditioning
810 | encoder_hidden_states = text_encoder(batch["input_ids"])[0].to(dtype=weight_dtype)
811 |
812 | # Predict the noise residual
813 | model_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample
814 |
815 | # Get the target for loss depending on the prediction type
816 | if noise_scheduler.config.prediction_type == "epsilon":
817 | target = noise
818 | elif noise_scheduler.config.prediction_type == "v_prediction":
819 | target = noise_scheduler.get_velocity(latents, noise, timesteps)
820 | else:
821 | raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}")
822 |
823 | loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean")
824 |
825 | accelerator.backward(loss)
826 |
827 | optimizer.step()
828 | lr_scheduler.step()
829 | optimizer.zero_grad()
830 |
831 | # Let's make sure we don't update any embedding weights besides the newly added token
832 | index_no_updates = torch.arange(len(tokenizer)) != placeholder_token_id
833 | with torch.no_grad():
834 | accelerator.unwrap_model(text_encoder).get_input_embeddings().weight[
835 | index_no_updates
836 | ] = orig_embeds_params[index_no_updates]
837 |
838 | # Checks if the accelerator has performed an optimization step behind the scenes
839 | if accelerator.sync_gradients:
840 | progress_bar.update(1)
841 | global_step += 1
842 | if global_step % args.save_steps == 0:
843 | save_path = os.path.join(args.output_dir, f"learned_embeds-steps-{global_step}.bin")
844 | save_progress(text_encoder, placeholder_token_id, accelerator, args, save_path)
845 |
846 | if accelerator.is_main_process:
847 | if global_step % args.checkpointing_steps == 0:
848 | save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}")
849 | accelerator.save_state(save_path)
850 | logger.info(f"Saved state to {save_path}")
851 |
852 | if args.validation_prompt is not None and global_step % args.validation_steps == 0:
853 | log_validation(text_encoder, tokenizer, unet, vae, args, accelerator, weight_dtype, epoch)
854 |
855 | logs = {"loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0]}
856 | progress_bar.set_postfix(**logs)
857 | accelerator.log(logs, step=global_step)
858 |
859 | if global_step >= args.max_train_steps:
860 | break
861 | # Create the pipeline using the trained modules and save it.
862 | accelerator.wait_for_everyone()
863 | if accelerator.is_main_process:
864 | if args.push_to_hub and args.only_save_embeds:
865 | logger.warn("Enabling full model saving because --push_to_hub=True was specified.")
866 | save_full_model = True
867 | else:
868 | save_full_model = not args.only_save_embeds
869 | if save_full_model:
870 | pipeline = StableDiffusionPipeline.from_pretrained(
871 | args.pretrained_model_name_or_path,
872 | text_encoder=accelerator.unwrap_model(text_encoder),
873 | vae=vae,
874 | unet=unet,
875 | tokenizer=tokenizer,
876 | )
877 | pipeline.save_pretrained(args.output_dir)
878 | # Save the newly trained embeddings
879 | save_path = os.path.join(args.output_dir, "learned_embeds.bin")
880 | save_progress(text_encoder, placeholder_token_id, accelerator, args, save_path)
881 |
882 | if args.push_to_hub:
883 | repo.push_to_hub(commit_message="End of training", blocking=False, auto_lfs_prune=True)
884 |
885 | accelerator.end_training()
886 |
887 |
888 | if __name__ == "__main__":
889 | main()
890 |
--------------------------------------------------------------------------------
/train_p_plus.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import logging
3 | import math
4 | import os
5 | import json
6 | import random
7 | import warnings
8 | from pathlib import Path
9 | from typing import Optional
10 |
11 | import numpy as np
12 | import PIL
13 | import torch
14 | import torch.nn.functional as F
15 | import torch.utils.checkpoint
16 | import transformers
17 | from accelerate import Accelerator
18 | from accelerate.logging import get_logger
19 | from accelerate.utils import ProjectConfiguration, set_seed
20 | from huggingface_hub import HfFolder, Repository, create_repo, whoami
21 |
22 | # TODO: remove and import from diffusers.utils when the new version of diffusers is released
23 | from packaging import version
24 | from PIL import Image
25 | from torch.utils.data import Dataset
26 | from torchvision import transforms
27 | from tqdm.auto import tqdm
28 | from transformers import CLIPTextModel, CLIPTokenizer
29 |
30 | import diffusers
31 | from diffusers import (
32 | AutoencoderKL,
33 | DDPMScheduler,
34 | DiffusionPipeline,
35 | DPMSolverMultistepScheduler,
36 | StableDiffusionPipeline,
37 | )
38 | from diffusers.optimization import get_scheduler
39 | from diffusers.utils import check_min_version, is_wandb_available
40 | from diffusers.utils.import_utils import is_xformers_available
41 | from prompt_plus import PPlusUNet2DConditionModel, PPlusStableDiffusionPipeline
42 |
43 |
44 | if is_wandb_available():
45 | import wandb
46 |
47 | if version.parse(version.parse(PIL.__version__).base_version) >= version.parse("9.1.0"):
48 | PIL_INTERPOLATION = {
49 | "linear": PIL.Image.Resampling.BILINEAR,
50 | "bilinear": PIL.Image.Resampling.BILINEAR,
51 | "bicubic": PIL.Image.Resampling.BICUBIC,
52 | "lanczos": PIL.Image.Resampling.LANCZOS,
53 | "nearest": PIL.Image.Resampling.NEAREST,
54 | }
55 | else:
56 | PIL_INTERPOLATION = {
57 | "linear": PIL.Image.LINEAR,
58 | "bilinear": PIL.Image.BILINEAR,
59 | "bicubic": PIL.Image.BICUBIC,
60 | "lanczos": PIL.Image.LANCZOS,
61 | "nearest": PIL.Image.NEAREST,
62 | }
63 | # ------------------------------------------------------------------------------
64 |
65 |
66 | # Will error if the minimal version of diffusers is not installed. Remove at your own risks.
67 | # check_min_version("0.15.0.dev0")
68 |
69 | logger = get_logger(__name__)
70 |
71 |
72 | def log_validation(text_encoder, tokenizer, unet, vae, args, accelerator, weight_dtype, epoch):
73 | logger.info(
74 | f"Running validation... \n Generating {args.num_validation_images} images with prompt:"
75 | f" {args.validation_prompt}."
76 | )
77 | # create pipeline (note: unet and vae are loaded again in float32)
78 | pipeline = DiffusionPipeline.from_pretrained(
79 | args.pretrained_model_name_or_path,
80 | text_encoder=accelerator.unwrap_model(text_encoder),
81 | tokenizer=tokenizer,
82 | unet=unet,
83 | vae=vae,
84 | revision=args.revision,
85 | torch_dtype=weight_dtype,
86 | )
87 | pipeline.scheduler = DPMSolverMultistepScheduler.from_config(pipeline.scheduler.config)
88 | pipeline = pipeline.to(accelerator.device)
89 | pipeline.set_progress_bar_config(disable=True)
90 |
91 | # run inference
92 | generator = None if args.seed is None else torch.Generator(device=accelerator.device).manual_seed(args.seed)
93 | images = []
94 | for _ in range(args.num_validation_images):
95 | with torch.autocast("cuda"):
96 | image = pipeline(args.validation_prompt, num_inference_steps=25, generator=generator).images[0]
97 | images.append(image)
98 |
99 | for tracker in accelerator.trackers:
100 | if tracker.name == "tensorboard":
101 | np_images = np.stack([np.asarray(img) for img in images])
102 | tracker.writer.add_images("validation", np_images, epoch, dataformats="NHWC")
103 | if tracker.name == "wandb":
104 | tracker.log(
105 | {
106 | "validation": [
107 | wandb.Image(image, caption=f"{i}: {args.validation_prompt}") for i, image in enumerate(images)
108 | ]
109 | }
110 | )
111 |
112 | del pipeline
113 | torch.cuda.empty_cache()
114 |
115 |
116 | def save_progress(text_encoder, placeholder_tokens, placeholder_token_ids, accelerator, args, save_path):
117 | logger.info("Saving embeddings")
118 | learned_embeds_dict = dict()
119 | for placeholder_token, placeholder_token_id in zip(placeholder_tokens, placeholder_token_ids):
120 | learned_embeds = accelerator.unwrap_model(text_encoder).get_input_embeddings().weight[placeholder_token_id]
121 | learned_embeds_dict[placeholder_token] = learned_embeds.detach().cpu()
122 | torch.save(learned_embeds_dict, save_path)
123 | with open(os.path.join(os.path.dirname(save_path), "config.json"), "w") as f:
124 | json.dump(args.__dict__, f, indent=2)
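    | # NOTE: the saved .bin maps each per-layer placeholder token to its learned embedding, and
    | # config.json records the training arguments; the inference code presumably reads both to
    | # rebuild the per-layer prompts at inference time.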
125 |
126 |
127 | def parse_args():
128 | parser = argparse.ArgumentParser(description="Simple example of a training script.")
129 | parser.add_argument(
130 | "--save_steps",
131 | type=int,
132 | default=500,
133 | help="Save learned_embeds.bin every X updates steps.",
134 | )
135 | parser.add_argument(
136 | "--only_save_embeds",
137 | action="store_true",
138 | default=False,
139 | help="Save only the embeddings for the new concept.",
140 | )
141 | parser.add_argument(
142 | "--pretrained_model_name_or_path",
143 | type=str,
144 | default=None,
145 | required=True,
146 | help="Path to pretrained model or model identifier from huggingface.co/models.",
147 | )
148 | parser.add_argument(
149 | "--revision",
150 | type=str,
151 | default=None,
152 | required=False,
153 | help="Revision of pretrained model identifier from huggingface.co/models.",
154 | )
155 | parser.add_argument(
156 | "--tokenizer_name",
157 | type=str,
158 | default=None,
159 | help="Pretrained tokenizer name or path if not the same as model_name",
160 | )
161 | parser.add_argument(
162 | "--train_data_dir", type=str, default=None, required=True, help="A folder containing the training data."
163 | )
164 | parser.add_argument(
165 | "--placeholder_token",
166 | type=str,
167 | default=None,
168 | required=True,
169 | help="A token to use as a placeholder for the concept.",
170 | )
171 | parser.add_argument(
172 | "--initializer_token", type=str, default=None, required=True, help="A token to use as initializer word."
173 | )
174 | parser.add_argument("--learnable_property", type=str, default="object", help="Choose between 'object' and 'style'")
175 | parser.add_argument("--repeats", type=int, default=100, help="How many times to repeat the training data.")
176 | parser.add_argument(
177 | "--output_dir",
178 | type=str,
179 | default="text-inversion-model",
180 | help="The output directory where the model predictions and checkpoints will be written.",
181 | )
182 | parser.add_argument("--seed", type=int, default=None, help="A seed for reproducible training.")
183 | parser.add_argument(
184 | "--resolution",
185 | type=int,
186 | default=512,
187 | help=(
188 | "The resolution for input images, all the images in the train/validation dataset will be resized to this"
189 | " resolution"
190 | ),
191 | )
192 | parser.add_argument(
193 | "--center_crop", action="store_true", help="Whether to center crop images before resizing to resolution."
194 | )
195 | parser.add_argument(
196 | "--train_batch_size", type=int, default=16, help="Batch size (per device) for the training dataloader."
197 | )
198 | parser.add_argument("--num_train_epochs", type=int, default=100)
199 | parser.add_argument(
200 | "--max_train_steps",
201 | type=int,
202 | default=5000,
203 | help="Total number of training steps to perform. If provided, overrides num_train_epochs.",
204 | )
205 | parser.add_argument(
206 | "--gradient_accumulation_steps",
207 | type=int,
208 | default=1,
209 | help="Number of updates steps to accumulate before performing a backward/update pass.",
210 | )
211 | parser.add_argument(
212 | "--gradient_checkpointing",
213 | action="store_true",
214 | help="Whether or not to use gradient checkpointing to save memory at the expense of slower backward pass.",
215 | )
216 | parser.add_argument(
217 | "--learning_rate",
218 | type=float,
219 | default=1e-4,
220 | help="Initial learning rate (after the potential warmup period) to use.",
221 | )
222 | parser.add_argument(
223 | "--scale_lr",
224 | action="store_true",
225 | default=False,
226 | help="Scale the learning rate by the number of GPUs, gradient accumulation steps, and batch size.",
227 | )
228 | parser.add_argument(
229 | "--lr_scheduler",
230 | type=str,
231 | default="constant",
232 | help=(
233 | 'The scheduler type to use. Choose between ["linear", "cosine", "cosine_with_restarts", "polynomial",'
234 | ' "constant", "constant_with_warmup"]'
235 | ),
236 | )
237 | parser.add_argument(
238 | "--lr_warmup_steps", type=int, default=500, help="Number of steps for the warmup in the lr scheduler."
239 | )
240 | parser.add_argument(
241 | "--dataloader_num_workers",
242 | type=int,
243 | default=0,
244 | help=(
245 | "Number of subprocesses to use for data loading. 0 means that the data will be loaded in the main process."
246 | ),
247 | )
248 | parser.add_argument("--adam_beta1", type=float, default=0.9, help="The beta1 parameter for the Adam optimizer.")
249 | parser.add_argument("--adam_beta2", type=float, default=0.999, help="The beta2 parameter for the Adam optimizer.")
250 | parser.add_argument("--adam_weight_decay", type=float, default=1e-2, help="Weight decay to use.")
251 | parser.add_argument("--adam_epsilon", type=float, default=1e-08, help="Epsilon value for the Adam optimizer")
252 | parser.add_argument("--push_to_hub", action="store_true", help="Whether or not to push the model to the Hub.")
253 | parser.add_argument("--hub_token", type=str, default=None, help="The token to use to push to the Model Hub.")
254 | parser.add_argument(
255 | "--hub_model_id",
256 | type=str,
257 | default=None,
258 | help="The name of the repository to keep in sync with the local `output_dir`.",
259 | )
260 | parser.add_argument(
261 | "--logging_dir",
262 | type=str,
263 | default="logs",
264 | help=(
265 | "[TensorBoard](https://www.tensorflow.org/tensorboard) log directory. Will default to"
266 | " *output_dir/runs/**CURRENT_DATETIME_HOSTNAME***."
267 | ),
268 | )
269 | parser.add_argument(
270 | "--mixed_precision",
271 | type=str,
272 | default="no",
273 | choices=["no", "fp16", "bf16"],
274 | help=(
275 | "Whether to use mixed precision. Choose"
276 | "between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >= 1.10."
277 | "and an Nvidia Ampere GPU."
278 | ),
279 | )
280 | parser.add_argument(
281 | "--allow_tf32",
282 | action="store_true",
283 | help=(
284 | "Whether or not to allow TF32 on Ampere GPUs. Can be used to speed up training. For more information, see"
285 | " https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices"
286 | ),
287 | )
288 | parser.add_argument(
289 | "--report_to",
290 | type=str,
291 | default="tensorboard",
292 | help=(
293 | 'The integration to report the results and logs to. Supported platforms are `"tensorboard"`'
294 | ' (default), `"wandb"` and `"comet_ml"`. Use `"all"` to report to all integrations.'
295 | ),
296 | )
297 | parser.add_argument(
298 | "--validation_prompt",
299 | type=str,
300 | default=None,
301 | help="A prompt that is used during validation to verify that the model is learning.",
302 | )
303 | parser.add_argument(
304 | "--num_validation_images",
305 | type=int,
306 | default=4,
307 | help="Number of images that should be generated during validation with `validation_prompt`.",
308 | )
309 | parser.add_argument(
310 | "--validation_steps",
311 | type=int,
312 | default=100,
313 | help=(
314 | "Run validation every X steps. Validation consists of running the prompt"
315 | " `args.validation_prompt` multiple times: `args.num_validation_images`"
316 | " and logging the images."
317 | ),
318 | )
319 | parser.add_argument(
320 | "--validation_epochs",
321 | type=int,
322 | default=None,
323 | help=(
324 | "Deprecated in favor of validation_steps. Run validation every X epochs. Validation consists of running the prompt"
325 | " `args.validation_prompt` multiple times: `args.num_validation_images`"
326 | " and logging the images."
327 | ),
328 | )
329 | parser.add_argument("--local_rank", type=int, default=-1, help="For distributed training: local_rank")
330 | parser.add_argument(
331 | "--checkpointing_steps",
332 | type=int,
333 | default=500,
334 | help=(
335 | "Save a checkpoint of the training state every X updates. These checkpoints are only suitable for resuming"
336 | " training using `--resume_from_checkpoint`."
337 | ),
338 | )
339 | parser.add_argument(
340 | "--checkpoints_total_limit",
341 | type=int,
342 | default=None,
343 | help=(
344 | "Max number of checkpoints to store. Passed as `total_limit` to the `Accelerator` `ProjectConfiguration`."
345 | " See Accelerator::save_state https://huggingface.co/docs/accelerate/package_reference/accelerator#accelerate.Accelerator.save_state"
346 | " for more docs"
347 | ),
348 | )
349 | parser.add_argument(
350 | "--resume_from_checkpoint",
351 | type=str,
352 | default=None,
353 | help=(
354 | "Whether training should be resumed from a previous checkpoint. Use a path saved by"
355 | ' `--checkpointing_steps`, or `"latest"` to automatically select the last available checkpoint.'
356 | ),
357 | )
358 | parser.add_argument(
359 | "--enable_xformers_memory_efficient_attention", action="store_true", help="Whether or not to use xformers."
360 | )
361 |
362 | args = parser.parse_args()
363 | env_local_rank = int(os.environ.get("LOCAL_RANK", -1))
364 | if env_local_rank != -1 and env_local_rank != args.local_rank:
365 | args.local_rank = env_local_rank
366 |
367 | if args.train_data_dir is None:
368 | raise ValueError("You must specify a train data directory.")
369 |
370 | return args
371 |
372 |
373 | imagenet_templates_small = [
374 | "a photo of a {}",
375 | "a rendering of a {}",
376 | "a cropped photo of the {}",
377 | "the photo of a {}",
378 | "a photo of a clean {}",
379 | "a photo of a dirty {}",
380 | "a dark photo of the {}",
381 | "a photo of my {}",
382 | "a photo of the cool {}",
383 | "a close-up photo of a {}",
384 | "a bright photo of the {}",
385 | "a cropped photo of a {}",
386 | "a photo of the {}",
387 | "a good photo of the {}",
388 | "a photo of one {}",
389 | "a close-up photo of the {}",
390 | "a rendition of the {}",
391 | "a photo of the clean {}",
392 | "a rendition of a {}",
393 | "a photo of a nice {}",
394 | "a good photo of a {}",
395 | "a photo of the nice {}",
396 | "a photo of the small {}",
397 | "a photo of the weird {}",
398 | "a photo of the large {}",
399 | "a photo of a cool {}",
400 | "a photo of a small {}",
401 | ]
402 |
403 | imagenet_style_templates_small = [
404 | "a painting in the style of {}",
405 | "a rendering in the style of {}",
406 | "a cropped painting in the style of {}",
407 | "the painting in the style of {}",
408 | "a clean painting in the style of {}",
409 | "a dirty painting in the style of {}",
410 | "a dark painting in the style of {}",
411 | "a picture in the style of {}",
412 | "a cool painting in the style of {}",
413 | "a close-up painting in the style of {}",
414 | "a bright painting in the style of {}",
415 | "a cropped painting in the style of {}",
416 | "a good painting in the style of {}",
417 | "a close-up painting in the style of {}",
418 | "a rendition in the style of {}",
419 | "a nice painting in the style of {}",
420 | "a small painting in the style of {}",
421 | "a weird painting in the style of {}",
422 | "a large painting in the style of {}",
423 | ]
424 |
425 |
426 | class TextualInversionDataset(Dataset):
427 | def __init__(
428 | self,
429 | data_root,
430 | tokenizer,
431 | learnable_property="object", # [object, style]
432 | size=512,
433 | repeats=100,
434 | interpolation="bicubic",
435 | flip_p=0.5,
436 | set="train",
437 | placeholder_tokens=None,
438 | center_crop=False,
439 | ):
440 | assert isinstance(placeholder_tokens, list)
441 | self.data_root = data_root
442 | self.tokenizer = tokenizer
443 | self.learnable_property = learnable_property
444 | self.size = size
445 | self.placeholder_tokens = placeholder_tokens
446 | self.center_crop = center_crop
447 | self.flip_p = flip_p
448 |
449 | self.image_paths = [os.path.join(self.data_root, file_path) for file_path in os.listdir(self.data_root)]
450 |
451 | self.num_images = len(self.image_paths)
452 | self._length = self.num_images
453 |
454 | if set == "train":
455 | self._length = self.num_images * repeats
456 |
457 | self.interpolation = {
458 | "linear": PIL_INTERPOLATION["linear"],
459 | "bilinear": PIL_INTERPOLATION["bilinear"],
460 | "bicubic": PIL_INTERPOLATION["bicubic"],
461 | "lanczos": PIL_INTERPOLATION["lanczos"],
462 | }[interpolation]
463 |
464 | self.templates = imagenet_style_templates_small if learnable_property == "style" else imagenet_templates_small
465 | self.flip_transform = transforms.RandomHorizontalFlip(p=self.flip_p)
466 |
467 | def __len__(self):
468 | return self._length
469 |
470 | def __getitem__(self, i):
471 | example = {}
472 | image = Image.open(self.image_paths[i % self.num_images])
473 |
474 | if not image.mode == "RGB":
475 | image = image.convert("RGB")
476 |
477 | template = random.choice(self.templates)
478 | text = [template.format(placeholder_string) for placeholder_string in self.placeholder_tokens]
479 | example["input_ids"] = self.tokenizer(
480 | text,
481 | padding="max_length",
482 | truncation=True,
483 | max_length=self.tokenizer.model_max_length,
484 | return_tensors="pt",
485 | ).input_ids
486 | # (num_new_tokens, seq_length)
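    | # Each per-layer placeholder token is formatted into the same randomly chosen template, so
    | # row i of `input_ids` is the prompt carrying the i-th layer's token; the training loop
    | # encodes each row separately.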
487 |
488 | # default to score-sde preprocessing
489 | img = np.array(image).astype(np.uint8)
490 |
491 | if self.center_crop:
492 | crop = min(img.shape[0], img.shape[1])
493 | (
494 | h,
495 | w,
496 | ) = (
497 | img.shape[0],
498 | img.shape[1],
499 | )
500 | img = img[(h - crop) // 2 : (h + crop) // 2, (w - crop) // 2 : (w + crop) // 2]
501 |
502 | image = Image.fromarray(img)
503 | image = image.resize((self.size, self.size), resample=self.interpolation)
504 |
505 | image = self.flip_transform(image)
506 | image = np.array(image).astype(np.uint8)
507 | image = (image / 127.5 - 1.0).astype(np.float32)
508 |
509 | example["pixel_values"] = torch.from_numpy(image).permute(2, 0, 1)
510 | return example
511 |
512 |
513 | def get_full_repo_name(model_id: str, organization: Optional[str] = None, token: Optional[str] = None):
514 | if token is None:
515 | token = HfFolder.get_token()
516 | if organization is None:
517 | username = whoami(token)["name"]
518 | return f"{username}/{model_id}"
519 | else:
520 | return f"{organization}/{model_id}"
521 |
522 |
523 | def main():
524 | args = parse_args()
525 | logging_dir = os.path.join(args.output_dir, args.logging_dir)
526 |
527 | accelerator_project_config = ProjectConfiguration(total_limit=args.checkpoints_total_limit)
528 |
529 | accelerator = Accelerator(
530 | gradient_accumulation_steps=args.gradient_accumulation_steps,
531 | mixed_precision=args.mixed_precision,
532 | log_with=args.report_to,
533 | logging_dir=logging_dir,
534 | project_config=accelerator_project_config,
535 | )
536 |
537 | if args.report_to == "wandb":
538 | if not is_wandb_available():
539 | raise ImportError("Make sure to install wandb if you want to use it for logging during training.")
540 |
541 | # Make one log on every process with the configuration for debugging.
542 | logging.basicConfig(
543 | format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
544 | datefmt="%m/%d/%Y %H:%M:%S",
545 | level=logging.INFO,
546 | )
547 | logger.info(accelerator.state, main_process_only=False)
548 | if accelerator.is_local_main_process:
549 | transformers.utils.logging.set_verbosity_warning()
550 | diffusers.utils.logging.set_verbosity_info()
551 | else:
552 | transformers.utils.logging.set_verbosity_error()
553 | diffusers.utils.logging.set_verbosity_error()
554 |
555 | # If passed along, set the training seed now.
556 | if args.seed is not None:
557 | set_seed(args.seed)
558 |
559 | # Handle the repository creation
560 | if accelerator.is_main_process:
561 | if args.push_to_hub:
562 | if args.hub_model_id is None:
563 | repo_name = get_full_repo_name(Path(args.output_dir).name, token=args.hub_token)
564 | else:
565 | repo_name = args.hub_model_id
566 | create_repo(repo_name, exist_ok=True, token=args.hub_token)
567 | repo = Repository(args.output_dir, clone_from=repo_name, token=args.hub_token)
568 |
569 | with open(os.path.join(args.output_dir, ".gitignore"), "w+") as gitignore:
570 | if "step_*" not in gitignore:
571 | gitignore.write("step_*\n")
572 | if "epoch_*" not in gitignore:
573 | gitignore.write("epoch_*\n")
574 | elif args.output_dir is not None:
575 | os.makedirs(args.output_dir, exist_ok=True)
576 |
577 | # Load tokenizer
578 | if args.tokenizer_name:
579 | tokenizer = CLIPTokenizer.from_pretrained(args.tokenizer_name)
580 | elif args.pretrained_model_name_or_path:
581 | tokenizer = CLIPTokenizer.from_pretrained(args.pretrained_model_name_or_path, subfolder="tokenizer")
582 |
583 | # Load scheduler and models
584 | noise_scheduler = DDPMScheduler.from_pretrained(args.pretrained_model_name_or_path, subfolder="scheduler")
585 | text_encoder = CLIPTextModel.from_pretrained(
586 | args.pretrained_model_name_or_path, subfolder="text_encoder", revision=args.revision
587 | )
588 | vae = AutoencoderKL.from_pretrained(args.pretrained_model_name_or_path, subfolder="vae", revision=args.revision)
589 | unet = PPlusUNet2DConditionModel.from_pretrained(
590 | args.pretrained_model_name_or_path, subfolder="unet", revision=args.revision
591 | )
592 |
593 | # Convert the initializer_token, placeholder_token to ids
594 | token_ids = tokenizer.encode(args.initializer_token, add_special_tokens=False)
595 | # Check if initializer_token is a single token or a sequence of tokens
596 | if len(token_ids) > 1:
597 | raise ValueError("The initializer token must be a single token.")
598 |
599 | initializer_token_id = token_ids[0]
600 |
601 | # TODO: more flexible (16 cross attention layers for stable diffusion)
602 | num_cross_attn_layers = 16
603 | # Resize the token embeddings as we are adding new special tokens to the tokenizer
604 | text_encoder.resize_token_embeddings(len(tokenizer) + num_cross_attn_layers)
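    | # NOTE: the Stable Diffusion v1.x UNet has 16 cross-attention layers, and P+ learns one
    | # placeholder embedding per layer; the embedding matrix is resized up front to leave room
    | # for the 16 tokens that are registered in the loop below.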
605 | # Initialise the newly added placeholder token with the embeddings of the initializer token
606 | token_embeds = text_encoder.get_input_embeddings().weight.data
607 | # Add the placeholder token in tokenizer
608 | placeholder_tokens = []
609 | placeholder_token_ids = []
610 | for i in range(num_cross_attn_layers):
611 | placeholder_token = f"{args.placeholder_token}-{i}"
612 | num_added_tokens = tokenizer.add_tokens(placeholder_token)
613 | if num_added_tokens == 0:
614 | raise ValueError(
615 | f"The tokenizer already contains the token {args.placeholder_token}. Please pass a different"
616 | " `placeholder_token` that is not already in the tokenizer."
617 | )
618 | placeholder_token_id = tokenizer.convert_tokens_to_ids(placeholder_token)
619 | token_embeds[placeholder_token_id] = token_embeds[initializer_token_id]
620 | placeholder_tokens.append(placeholder_token)
621 | placeholder_token_ids.append(placeholder_token_id)
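    | # The loop above registers "{args.placeholder_token}-0" ... "{args.placeholder_token}-15",
    | # one token per cross-attention layer, each initialised from the initializer token's embedding.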
622 |
623 | # Freeze vae and unet
624 | vae.requires_grad_(False)
625 | unet.requires_grad_(False)
626 | # Freeze all parameters except for the token embeddings in text encoder
627 | text_encoder.text_model.encoder.requires_grad_(False)
628 | text_encoder.text_model.final_layer_norm.requires_grad_(False)
629 | text_encoder.text_model.embeddings.position_embedding.requires_grad_(False)
630 |
631 | if args.gradient_checkpointing:
632 | # Keep unet in train mode when using gradient checkpointing, since checkpointing only takes effect in train mode.
633 | # Dropout is 0 in these models, so train vs. eval mode does not change the outputs.
634 | unet.train()
635 | text_encoder.gradient_checkpointing_enable()
636 | unet.enable_gradient_checkpointing()
637 |
638 | if args.enable_xformers_memory_efficient_attention:
639 | if is_xformers_available():
640 | import xformers
641 |
642 | xformers_version = version.parse(xformers.__version__)
643 | if xformers_version == version.parse("0.0.16"):
644 | logger.warn(
645 | "xFormers 0.0.16 cannot be used for training in some GPUs. If you observe problems during training, please update xFormers to at least 0.0.17. See https://huggingface.co/docs/diffusers/main/en/optimization/xformers for more details."
646 | )
647 | unet.enable_xformers_memory_efficient_attention()
648 | else:
649 | raise ValueError("xformers is not available. Make sure it is installed correctly")
650 |
651 | # Enable TF32 for faster training on Ampere GPUs,
652 | # cf https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices
653 | if args.allow_tf32:
654 | torch.backends.cuda.matmul.allow_tf32 = True
655 |
656 | if args.scale_lr:
657 | args.learning_rate = (
658 | args.learning_rate * args.gradient_accumulation_steps * args.train_batch_size * accelerator.num_processes
659 | )
660 |
661 | # Initialize the optimizer
662 | optimizer = torch.optim.AdamW(
663 | text_encoder.get_input_embeddings().parameters(), # only optimize the embeddings
664 | lr=args.learning_rate,
665 | betas=(args.adam_beta1, args.adam_beta2),
666 | weight_decay=args.adam_weight_decay,
667 | eps=args.adam_epsilon,
668 | )
669 |
670 | # Dataset and DataLoaders creation:
671 | train_dataset = TextualInversionDataset(
672 | data_root=args.train_data_dir,
673 | tokenizer=tokenizer,
674 | size=args.resolution,
675 | placeholder_tokens=placeholder_tokens,
676 | repeats=args.repeats,
677 | learnable_property=args.learnable_property,
678 | center_crop=args.center_crop,
679 | set="train",
680 | )
681 | train_dataloader = torch.utils.data.DataLoader(
682 | train_dataset, batch_size=args.train_batch_size, shuffle=True, num_workers=args.dataloader_num_workers
683 | )
684 | if args.validation_epochs is not None:
685 | warnings.warn(
686 | f"FutureWarning: You are doing logging with validation_epochs={args.validation_epochs}."
687 | " Deprecated validation_epochs in favor of `validation_steps`"
688 | f"Setting `args.validation_steps` to {args.validation_epochs * len(train_dataset)}",
689 | FutureWarning,
690 | stacklevel=2,
691 | )
692 | args.validation_steps = args.validation_epochs * len(train_dataset)
693 |
694 | # Scheduler and math around the number of training steps.
695 | overrode_max_train_steps = False
696 | num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
697 | if args.max_train_steps is None:
698 | args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
699 | overrode_max_train_steps = True
700 |
701 | lr_scheduler = get_scheduler(
702 | args.lr_scheduler,
703 | optimizer=optimizer,
704 | num_warmup_steps=args.lr_warmup_steps * args.gradient_accumulation_steps,
705 | num_training_steps=args.max_train_steps * args.gradient_accumulation_steps,
706 | )
707 |
708 | # Prepare everything with our `accelerator`.
709 | text_encoder, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
710 | text_encoder, optimizer, train_dataloader, lr_scheduler
711 | )
712 |
713 | # For mixed precision training we cast the unet and vae weights to half-precision
714 | # as these models are only used for inference, keeping weights in full precision is not required.
715 | weight_dtype = torch.float32
716 | if accelerator.mixed_precision == "fp16":
717 | weight_dtype = torch.float16
718 | elif accelerator.mixed_precision == "bf16":
719 | weight_dtype = torch.bfloat16
720 |
721 | # Move vae and unet to device and cast to weight_dtype
722 | unet.to(accelerator.device, dtype=weight_dtype)
723 | vae.to(accelerator.device, dtype=weight_dtype)
724 |
725 | # We need to recalculate our total training steps as the size of the training dataloader may have changed.
726 | num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
727 | if overrode_max_train_steps:
728 | args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
729 | # Afterwards we recalculate our number of training epochs
730 | args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)
731 |
732 | # We need to initialize the trackers we use, and also store our configuration.
733 | # The trackers initializes automatically on the main process.
734 | if accelerator.is_main_process:
735 | accelerator.init_trackers("p_plus_xti", config=vars(args))
736 |
737 | # Train!
738 | total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps
739 |
740 | logger.info("***** Running training *****")
741 | logger.info(f" Num examples = {len(train_dataset)}")
742 | logger.info(f" Num Epochs = {args.num_train_epochs}")
743 | logger.info(f" Instantaneous batch size per device = {args.train_batch_size}")
744 | logger.info(f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}")
745 | logger.info(f" Gradient Accumulation steps = {args.gradient_accumulation_steps}")
746 | logger.info(f" Total optimization steps = {args.max_train_steps}")
747 | global_step = 0
748 | first_epoch = 0
749 | # Potentially load in the weights and states from a previous save
750 | if args.resume_from_checkpoint:
751 | if args.resume_from_checkpoint != "latest":
752 | path = os.path.basename(args.resume_from_checkpoint)
753 | else:
754 | # Get the most recent checkpoint
755 | dirs = os.listdir(args.output_dir)
756 | dirs = [d for d in dirs if d.startswith("checkpoint")]
757 | dirs = sorted(dirs, key=lambda x: int(x.split("-")[1]))
758 | path = dirs[-1] if len(dirs) > 0 else None
759 |
760 | if path is None:
761 | accelerator.print(
762 | f"Checkpoint '{args.resume_from_checkpoint}' does not exist. Starting a new training run."
763 | )
764 | args.resume_from_checkpoint = None
765 | else:
766 | accelerator.print(f"Resuming from checkpoint {path}")
767 | accelerator.load_state(os.path.join(args.output_dir, path))
768 | global_step = int(path.split("-")[1])
769 |
770 | resume_global_step = global_step * args.gradient_accumulation_steps
771 | first_epoch = global_step // num_update_steps_per_epoch
772 | resume_step = resume_global_step % (num_update_steps_per_epoch * args.gradient_accumulation_steps)
773 |
774 | # Only show the progress bar once on each machine.
775 | progress_bar = tqdm(range(global_step, args.max_train_steps), disable=not accelerator.is_local_main_process)
776 | progress_bar.set_description("Steps")
777 |
778 | # keep original embeddings as reference
779 | orig_embeds_params = accelerator.unwrap_model(text_encoder).get_input_embeddings().weight.data.clone()
780 |
781 | for epoch in range(first_epoch, args.num_train_epochs):
782 | text_encoder.train()
783 | for step, batch in enumerate(train_dataloader):
784 | # Skip steps until we reach the resumed step
785 | if args.resume_from_checkpoint and epoch == first_epoch and step < resume_step:
786 | if step % args.gradient_accumulation_steps == 0:
787 | progress_bar.update(1)
788 | continue
789 |
790 | with accelerator.accumulate(text_encoder):
791 | # Convert images to latent space
792 | latents = vae.encode(batch["pixel_values"].to(dtype=weight_dtype)).latent_dist.sample().detach()
793 | latents = latents * vae.config.scaling_factor
794 |
795 | # Sample noise that we'll add to the latents
796 | noise = torch.randn_like(latents)
797 | bsz = latents.shape[0]
798 | # Sample a random timestep for each image
799 | timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, (bsz,), device=latents.device)
800 | timesteps = timesteps.long()
801 |
802 | # Add noise to the latents according to the noise magnitude at each timestep
803 | # (this is the forward diffusion process)
804 | noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)
805 |
806 | # Get the text embedding for conditioning
807 | encoder_hidden_states_list = []
808 | num_new_tokens = batch["input_ids"].size(1)
809 | for i in range(num_new_tokens):
810 | encoder_hidden_states = text_encoder(batch["input_ids"][:, i, :])[0].to(dtype=weight_dtype)
811 | encoder_hidden_states_list.append(encoder_hidden_states)
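    | # Every per-layer prompt gets its own text-encoder forward pass, so a single training step
    | # runs the text encoder num_new_tokens (= 16) times; all per-layer embeddings are optimised jointly.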
812 |
813 | # Predict the noise residual
814 | model_pred = unet(noisy_latents, timesteps, encoder_hidden_states_list=encoder_hidden_states_list).sample
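    | # NOTE: PPlusUNet2DConditionModel takes a list of hidden states rather than a single tensor;
    | # presumably entry i conditions the i-th cross-attention layer (see
    | # prompt_plus/prompt_plus_unet_2d_condition.py for the actual routing).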
815 |
816 | # Get the target for loss depending on the prediction type
817 | if noise_scheduler.config.prediction_type == "epsilon":
818 | target = noise
819 | elif noise_scheduler.config.prediction_type == "v_prediction":
820 | target = noise_scheduler.get_velocity(latents, noise, timesteps)
821 | else:
822 | raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}")
823 |
824 | loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean")
825 |
826 | accelerator.backward(loss)
827 |
828 | optimizer.step()
829 | lr_scheduler.step()
830 | optimizer.zero_grad()
831 |
832 | # Let's make sure we don't update any embedding weights besides the newly added tokens
833 | vocab = torch.arange(len(tokenizer))
834 | index_no_updates = torch.all(
835 | torch.stack([torch.ne(vocab, placeholder_token_id) for placeholder_token_id in placeholder_token_ids], dim=0), dim=0
836 | )
837 | with torch.no_grad():
838 | accelerator.unwrap_model(text_encoder).get_input_embeddings().weight[
839 | index_no_updates
840 | ] = orig_embeds_params[index_no_updates]
841 |
842 | # Checks if the accelerator has performed an optimization step behind the scenes
843 | if accelerator.sync_gradients:
844 | progress_bar.update(1)
845 | global_step += 1
846 | if global_step % args.save_steps == 0:
847 | save_path = os.path.join(args.output_dir, f"learned_embeds-steps-{global_step}.bin")
848 | save_progress(text_encoder, placeholder_tokens, placeholder_token_ids, accelerator, args, save_path)
849 |
850 | if accelerator.is_main_process:
851 | if global_step % args.checkpointing_steps == 0:
852 | save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}")
853 | accelerator.save_state(save_path)
854 | logger.info(f"Saved state to {save_path}")
855 |
856 | if args.validation_prompt is not None and global_step % args.validation_steps == 0:
857 | log_validation(text_encoder, tokenizer, unet, vae, args, accelerator, weight_dtype, epoch)
858 |
859 | logs = {"loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0]}
860 | progress_bar.set_postfix(**logs)
861 | accelerator.log(logs, step=global_step)
862 |
863 | if global_step >= args.max_train_steps:
864 | break
865 | # Create the pipeline using the trained modules and save it.
866 | accelerator.wait_for_everyone()
867 | if accelerator.is_main_process:
868 | if args.push_to_hub and args.only_save_embeds:
869 | logger.warn("Enabling full model saving because --push_to_hub=True was specified.")
870 | save_full_model = True
871 | else:
872 | save_full_model = not args.only_save_embeds
873 | if save_full_model:
874 | pipeline = StableDiffusionPipeline.from_pretrained(
875 | args.pretrained_model_name_or_path,
876 | text_encoder=accelerator.unwrap_model(text_encoder),
877 | vae=vae,
878 | unet=unet,
879 | tokenizer=tokenizer,
880 | )
881 | pipeline.save_pretrained(args.output_dir)
882 | # Save the newly trained embeddings
883 | save_path = os.path.join(args.output_dir, "learned_embeds.bin")
884 | save_progress(text_encoder, placeholder_tokens, placeholder_token_ids, accelerator, args, save_path)
885 | if args.push_to_hub:
886 | repo.push_to_hub(commit_message="End of training", blocking=False, auto_lfs_prune=True)
887 |
888 | accelerator.end_training()
889 |
890 |
891 | if __name__ == "__main__":
892 | main()
893 |
--------------------------------------------------------------------------------