├── imgs
│   ├── en_bird.jpg
│   ├── en_plant.jpg
│   ├── ja_bird.jpg
│   ├── ja_plant.jpg
│   ├── zh_bird.jpg
│   ├── zh_plant.jpg
│   ├── en_blue_dragon.jpg
│   ├── ja_blue_dragon.jpg
│   └── zh_blue_dragon.jpg
├── requirements.txt
├── install.sh
├── install.md
├── run_en_model.py
├── run_ja_model.py
├── README.md
├── README_EN.md
├── run_zh_model.py
├── train_zh_model.py
├── train_ja_model.py
└── train_en_model.py
/imgs/en_bird.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/svjack/Stable-Diffusion-Pokemon/HEAD/imgs/en_bird.jpg
--------------------------------------------------------------------------------
/imgs/en_plant.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/svjack/Stable-Diffusion-Pokemon/HEAD/imgs/en_plant.jpg
--------------------------------------------------------------------------------
/imgs/ja_bird.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/svjack/Stable-Diffusion-Pokemon/HEAD/imgs/ja_bird.jpg
--------------------------------------------------------------------------------
/imgs/ja_plant.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/svjack/Stable-Diffusion-Pokemon/HEAD/imgs/ja_plant.jpg
--------------------------------------------------------------------------------
/imgs/zh_bird.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/svjack/Stable-Diffusion-Pokemon/HEAD/imgs/zh_bird.jpg
--------------------------------------------------------------------------------
/imgs/zh_plant.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/svjack/Stable-Diffusion-Pokemon/HEAD/imgs/zh_plant.jpg
--------------------------------------------------------------------------------
/imgs/en_blue_dragon.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/svjack/Stable-Diffusion-Pokemon/HEAD/imgs/en_blue_dragon.jpg
--------------------------------------------------------------------------------
/imgs/ja_blue_dragon.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/svjack/Stable-Diffusion-Pokemon/HEAD/imgs/ja_blue_dragon.jpg
--------------------------------------------------------------------------------
/imgs/zh_blue_dragon.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/svjack/Stable-Diffusion-Pokemon/HEAD/imgs/zh_blue_dragon.jpg
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | diffusers==0.4.1
2 | accelerate
3 | torchvision
4 | transformers>=4.21.0
5 | ftfy
6 | tensorboard
7 | modelcards
8 | jieba
9 | pandas
10 | datasets
11 |
--------------------------------------------------------------------------------
/install.sh:
--------------------------------------------------------------------------------
1 | pip install git+https://github.com/rinnakk/japanese-stable-diffusion
2 | pip install -r requirements.txt
3 |
4 | sudo apt-get install git-lfs
5 | git clone https://huggingface.co/rinna/japanese-stable-diffusion
6 | git clone https://huggingface.co/CompVis/stable-diffusion-v1-4
7 |
8 | #### pretrained download
9 | git clone https://huggingface.co/svjack/Stable-Diffusion-Pokemon-en
10 | git clone https://huggingface.co/svjack/Stable-Diffusion-Pokemon-ja
11 | git clone https://huggingface.co/svjack/Stable-Diffusion-Pokemon-zh
12 |
--------------------------------------------------------------------------------
/install.md:
--------------------------------------------------------------------------------
1 | pip install git+https://github.com/rinnakk/japanese-stable-diffusion
2 | #### pip install diffusers
3 | huggingface-cli login
4 |
5 | sudo apt-get install git-lfs
6 | git clone https://huggingface.co/rinna/japanese-stable-diffusion
7 |
8 | #### from text_to_image_train.py requirements: https://github.com/huggingface/diffusers/tree/main/examples/text_to_image
9 | git clone https://github.com/huggingface/diffusers
10 | cd diffusers/examples/text_to_image
11 | pip install -r requirements.txt
12 |
13 | diffusers==0.4.1
14 | accelerate
15 | torchvision
16 | transformers>=4.21.0
17 | ftfy
18 | tensorboard
19 | modelcards
20 | jieba
21 | pandas
22 | datasets
23 |
--------------------------------------------------------------------------------
/run_en_model.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import pandas as pd
3 |
4 | from torch import autocast
5 | from diffusers import LMSDiscreteScheduler, StableDiffusionPipeline
6 |
7 | scheduler = LMSDiscreteScheduler(beta_start=0.00085, beta_end=0.012,
8 | beta_schedule="scaled_linear", num_train_timesteps=1000)
9 |
10 | #pretrained_model_name_or_path = "en_model_26000"
11 | pretrained_model_name_or_path = "svjack/Stable-Diffusion-Pokemon-en"
12 | pipe = StableDiffusionPipeline.from_pretrained(pretrained_model_name_or_path,
13 | scheduler=scheduler, use_auth_token=True)
14 |
15 | pipe = pipe.to("cuda")
16 |
17 | #### disable safety_checker
18 | pipe.safety_checker = lambda images, clip_input: (images, False)
19 |
20 | imgs = pipe("A cartoon character with a potted plant on his head",
21 | num_inference_steps = 100
22 | )
23 | imgs.images[0]
24 |
25 | imgs = pipe("cartoon bird",
26 | num_inference_steps = 100
27 | )
28 | imgs.images[0]
29 |
30 | imgs = pipe("yellow ball",
31 | num_inference_steps = 100
32 | )
33 | imgs.images[0]
34 |
35 | imgs = pipe("blue dragon illustration",
36 | num_inference_steps = 100
37 | )
38 | imgs.images[0]
39 |
40 | ###### person "Zhuge Liang"
41 | ###### penis
42 |
--------------------------------------------------------------------------------
/run_ja_model.py:
--------------------------------------------------------------------------------
1 | from japanese_stable_diffusion import JapaneseStableDiffusionPipeline
2 | import torch
3 | import pandas as pd
4 |
5 | from torch import autocast
6 | from diffusers import LMSDiscreteScheduler
7 |
8 | scheduler = LMSDiscreteScheduler(beta_start=0.00085, beta_end=0.012,
9 | beta_schedule="scaled_linear", num_train_timesteps=1000)
10 |
11 | #pretrained_model_name_or_path = "jap_model_26000"
12 |
13 | #### sudo apt-get install git-lfs
14 | #### git clone https://huggingface.co/svjack/Stable-Diffusion-Pokemon-ja
15 | pretrained_model_name_or_path = "Stable-Diffusion-Pokemon-ja"
16 |
17 | pipe = JapaneseStableDiffusionPipeline.from_pretrained(pretrained_model_name_or_path,
18 | scheduler=scheduler, use_auth_token=True)
19 |
20 | pipe = pipe.to("cuda")
21 |
22 | #### disable safety_checker
23 | pipe.safety_checker = lambda images, clip_input: (images, False)
24 |
25 | imgs = pipe("鉢植えの植物を頭に載せた漫画のキャラクター",
26 | num_inference_steps = 100
27 | )
28 | imgs.images[0]
29 |
30 | imgs = pipe("漫画の鳥",
31 | num_inference_steps = 100
32 | )
33 | imgs.images[0]
34 |
35 | imgs = pipe("黄色いボール",
36 | num_inference_steps = 100
37 | )
38 | imgs.images[0]
39 |
40 | imgs = pipe("ブルードラゴンのイラスト",
41 | num_inference_steps = 100
42 | )
43 | imgs.images[0]
44 |
45 | ##### 緑のピエロ
46 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 | # Stable-Diffusion-Pokemon
5 |
6 |
7 | 在宝可梦数据集(Pokemon-Blip-Captions)的英语、日语、中文版本上微调 Stable Diffusion 的示例
8 |
9 |
10 |
11 |
12 | [In English](README_EN.md)
13 |
14 | ### 简要引述
15 | [Stable Diffusion](https://stability.ai/blog/stable-diffusion-public-release)是现在一流的文本转图片生成模型。
16 | 现如今,借助于提供多模态预训练扩散模型的[diffusers](https://github.com/huggingface/diffusers)工程,人们可以自定义自己的条件(以文本提示作为条件)或非条件图像生成模型。
17 | 这个工程的目标是复现diffusers提供的基于[lambdalabs/pokemon-blip-captions](https://huggingface.co/datasets/lambdalabs/pokemon-blip-captions)的[text to image example](https://github.com/huggingface/diffusers/tree/main/examples/text_to_image)微调任务,并将该任务迁移到日文及中文数据集上进行实践。
18 | 通过比较结果,可以对Stable Diffusion在不同语言上的微调任务给出指导。
19 | 所有代码都依据官方的[train_text_to_image.py](https://github.com/huggingface/diffusers/blob/main/examples/text_to_image/train_text_to_image.py)进行修改,使其在中文和日文领域同样有效。
20 | 得到的三种语言的训练模型如下:[English](https://huggingface.co/svjack/Stable-Diffusion-Pokemon-en) , [Japanese](https://huggingface.co/svjack/Stable-Diffusion-Pokemon-ja) 及 [Chinese](https://huggingface.co/svjack/Stable-Diffusion-Pokemon-zh).
21 |
22 | ### 安装和运行
23 | 运行install.sh将会安装所有的依赖并下载所有需要的模型(请确保您已注册huggingface账号并持有您的[token](https://huggingface.co/docs/hub/security-tokens))
24 | 下载后可尝试运行[run_en_model.py](run_en_model.py), [run_ja_model.py](run_ja_model.py) 及 [run_zh_model.py](run_zh_model.py)
25 |
26 | ### 数据集准备
27 | 为了在日文及中文领域进行微调,我们需要将[lambdalabs/pokemon-blip-captions](https://huggingface.co/datasets/lambdalabs/pokemon-blip-captions)翻译为日文及中文。我已经使用 [DeepL](https://www.deepl.com/translator) 对其进行翻译并上传至 huggingface dataset hub,分别位于 [svjack/pokemon-blip-captions-en-ja](https://huggingface.co/datasets/svjack/pokemon-blip-captions-en-ja) 和 [svjack/pokemon-blip-captions-en-zh](https://huggingface.co/datasets/svjack/pokemon-blip-captions-en-zh)。
28 |
29 |
30 | ### 微调预训练模型
31 | 英文版本仅仅是将[train_text_to_image.py](https://github.com/huggingface/diffusers/blob/main/examples/text_to_image/train_text_to_image.py)脚本中accelerate的调用通过notebook_launcher函数由脚本模式改为notebook模式,位于[train_en_model.py](train_en_model.py)
32 |
33 | 日文版本使用[rinnakk/japanese-stable-diffusion](https://github.com/rinnakk/japanese-stable-diffusion)替换预训练模型,位于[train_ja_model.py](train_ja_model.py)
34 |
35 | 中文版本使用[IDEA-CCNL/Taiyi-CLIP-Roberta-102M-Chinese](https://huggingface.co/IDEA-CCNL/Taiyi-CLIP-Roberta-102M-Chinese) 替换预训练分词器以及文本编码器,并将BertForTokenClassification的logit输出进行padding来代替CLIPTextModel,位于[train_zh_model.py](train_zh_model.py)。
36 |
37 | 为了在推断阶段使得所有结果可见,我关闭了safety_checker。
38 |
39 | ### 生成器结果比较
40 |
41 | | Prompt | English | Japanese | Chinese |
42 | | ------ | ------- | -------- | ------- |
43 | | A cartoon character with a potted plant on his head<br>鉢植えの植物を頭に載せた漫画のキャラクター<br>一个头上戴着盆栽的卡通人物 | ![](imgs/en_plant.jpg) | ![](imgs/ja_plant.jpg) | ![](imgs/zh_plant.jpg) |
44 | | cartoon bird<br>漫画の鳥<br>卡通鸟 | ![](imgs/en_bird.jpg) | ![](imgs/ja_bird.jpg) | ![](imgs/zh_bird.jpg) |
45 | | blue dragon illustration<br>ブルードラゴンのイラスト<br>蓝色的龙图 | ![](imgs/en_blue_dragon.jpg) | ![](imgs/ja_blue_dragon.jpg) | ![](imgs/zh_blue_dragon.jpg) |
46 |
73 | ### 讨论
74 | 在英文、日文、中文下的预训练模型分别训练了26000、26000 及 20000 步。其中日文版本表现最好,中文版本位列第三。
75 | 对于训练结果的解释是这样的:[rinnakk/japanese-stable-diffusion](https://github.com/rinnakk/japanese-stable-diffusion)由于是日文的原生模型,所以含有很多宝可梦的特征;[Stable Diffusion](https://stability.ai/blog/stable-diffusion-public-release)在英文领域微调得很好;[IDEA-CCNL/Taiyi-CLIP-Roberta-102M-Chinese](https://huggingface.co/IDEA-CCNL/Taiyi-CLIP-Roberta-102M-Chinese) 如模型卡片所说,在中文上能够起到基本的文本表征作用。
76 |
77 |
78 | ## Contact
79 |
80 |
83 | svjack - svjackbt@gmail.com - ehangzhou@outlook.com
84 |
85 |
88 | Project Link:[https://github.com/svjack/Stable-Diffusion-Pokemon](https://github.com/svjack/Stable-Diffusion-Pokemon)
89 |
90 |
91 |
92 | ## Acknowledgements
93 |
106 | * [Stable Diffusion](https://stability.ai/blog/stable-diffusion-public-release)
107 | * [diffusers](https://github.com/huggingface/diffusers)
108 | * [DeepL](https://www.deepl.com/translator)
109 | * [rinnakk/japanese-stable-diffusion](https://github.com/rinnakk/japanese-stable-diffusion)
110 | * [IDEA-CCNL/Taiyi-CLIP-Roberta-102M-Chinese](https://huggingface.co/IDEA-CCNL/Taiyi-CLIP-Roberta-102M-Chinese)
111 | * [svjack](https://huggingface.co/svjack)
112 |
--------------------------------------------------------------------------------
/README_EN.md:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 | # Stable-Diffusion-Pokemon
5 |
6 |
7 | A demo of fine-tuning Stable Diffusion on Pokemon-Blip-Captions in English, Japanese and Chinese corpora
8 |
9 |
10 |
11 |
12 | [中文简介](README.md)
13 |
14 | ### Brief introduction
15 | [Stable Diffusion](https://stability.ai/blog/stable-diffusion-public-release) is a state-of-the-art text-to-image generation model.
16 | Nowadays, with the help of [diffusers](https://github.com/huggingface/diffusers), which provides pretrained diffusion models across multiple modalities, people can customize their own conditional (prompt-based) or unconditional image generators.
17 | This project focuses on running the [text to image example](https://github.com/huggingface/diffusers/tree/main/examples/text_to_image) provided by diffusers, based on [lambdalabs/pokemon-blip-captions](https://huggingface.co/datasets/lambdalabs/pokemon-blip-captions), and migrating the task to the Japanese and Chinese domains
18 | in both the model and the data dimensions. Comparing the results may give a guideline about fine-tuning Stable Diffusion in different languages.
19 | All scripts are edited versions of the official [train_text_to_image.py](https://github.com/huggingface/diffusers/blob/main/examples/text_to_image/train_text_to_image.py), adapted to work in the Japanese and Chinese domains.
20 | Three pretrained models are provided: [English](https://huggingface.co/svjack/Stable-Diffusion-Pokemon-en) , [Japanese](https://huggingface.co/svjack/Stable-Diffusion-Pokemon-ja) and [Chinese](https://huggingface.co/svjack/Stable-Diffusion-Pokemon-zh).
21 |
22 | ### Installation and Running
23 | Running install.sh will install all dependencies and download all the models needed (make sure you are logged in to your huggingface account and have your [token](https://huggingface.co/docs/hub/security-tokens)).
24 | After the download, you can try [run_en_model.py](run_en_model.py), [run_ja_model.py](run_ja_model.py) and [run_zh_model.py](run_zh_model.py) by yourself.
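
For example, the English pipeline can be loaded straight from the hub (a minimal sketch mirroring run_en_model.py; it assumes a CUDA device and the dependencies from requirements.txt):

```python
from diffusers import LMSDiscreteScheduler, StableDiffusionPipeline

# Same scheduler settings as in run_en_model.py
scheduler = LMSDiscreteScheduler(beta_start=0.00085, beta_end=0.012,
                                 beta_schedule="scaled_linear", num_train_timesteps=1000)

pipe = StableDiffusionPipeline.from_pretrained("svjack/Stable-Diffusion-Pokemon-en",
                                               scheduler=scheduler, use_auth_token=True)
pipe = pipe.to("cuda")

# Generate one sample and save it to disk
image = pipe("blue dragon illustration", num_inference_steps=100).images[0]
image.save("blue_dragon.png")
```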
25 |
26 | ### Dataset preparation
27 | To fine-tune in the Japanese and Chinese domains, all we need is [lambdalabs/pokemon-blip-captions](https://huggingface.co/datasets/lambdalabs/pokemon-blip-captions) translated into Japanese and Chinese. I have translated it with the help of [DeepL](https://www.deepl.com/translator) and uploaded the results to the huggingface dataset hub, located at [svjack/pokemon-blip-captions-en-ja](https://huggingface.co/datasets/svjack/pokemon-blip-captions-en-ja) and [svjack/pokemon-blip-captions-en-zh](https://huggingface.co/datasets/svjack/pokemon-blip-captions-en-zh).
28 |
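The translated captions can be inspected quickly (a minimal sketch; train_zh_model.py reads the `image` and `zh_text` columns of the Chinese dataset):

```python
from datasets import load_dataset

# Chinese/English caption pairs used by train_zh_model.py
dataset = load_dataset("svjack/pokemon-blip-captions-en-zh")
print(dataset["train"].column_names)
print(dataset["train"][0]["zh_text"])
```
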
29 | ### Fine tuning pretrained models
30 | The English version, located in [train_en_model.py](train_en_model.py), is simply a copy of [train_text_to_image.py](https://github.com/huggingface/diffusers/blob/main/examples/text_to_image/train_text_to_image.py) that switches the accelerate launch from script mode to notebook mode via the function
31 | ```python
32 | notebook_launcher
33 | ```
34 |
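The launch pattern looks roughly like this (a sketch only; in train_zh_model.py the launched function is the full `train_loop` and the argument tuple carries `args, unet, noise_scheduler, optimizer, train_dataloader, lr_scheduler`):

```python
from accelerate import notebook_launcher

def train_loop(message):
    # Stand-in for the real training loop defined in the train_*_model.py scripts
    print(f"training started: {message}")

# Equivalent of `accelerate launch train_text_to_image.py ...`, but callable from a notebook
notebook_launcher(train_loop, ("hello",), num_processes=1)
```
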
35 | The Japanese version, located in [train_ja_model.py](train_ja_model.py),
36 | replaces the pretrained model with [rinnakk/japanese-stable-diffusion](https://github.com/rinnakk/japanese-stable-diffusion).
37 |
38 | The Chinese version, located in [train_zh_model.py](train_zh_model.py),
39 | replaces the pretrained tokenizer and text_encoder with [IDEA-CCNL/Taiyi-CLIP-Roberta-102M-Chinese](https://huggingface.co/IDEA-CCNL/Taiyi-CLIP-Roberta-102M-Chinese) and feeds the zero-padded logit output of BertForTokenClassification to the downstream network in place of CLIPTextModel.
40 |
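The padding trick itself is small; the sketch below mirrors the lines used in train_zh_model.py and run_zh_model.py, where the Chinese text encoder yields 512-dimensional features while the UNet cross-attention expects 768:

```python
import torch

# Suppose `encoder_hidden_states` came from the Chinese text encoder: (batch, seq_len, 512)
encoder_hidden_states = torch.randn(2, 77, 512)

# Zero-pad the last axis so the UNet receives the 768-dim embeddings it was trained with
b, l, d = encoder_hidden_states.shape
pad = torch.zeros((b, l, 768 - d), dtype=encoder_hidden_states.dtype)
encoder_hidden_states = torch.cat([encoder_hidden_states, pad], dim=-1)

print(encoder_hidden_states.shape)  # torch.Size([2, 77, 768])
```
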
41 | To take a look at all outputs, I disable the safety_checker so that none of the generated images are blacked out during inference.
42 |
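The disabling itself is the one-liner already used in the run_*_model.py scripts, where `pipe` is the loaded pipeline:

```python
# Replace the checker with a no-op so every generated image is returned as-is
pipe.safety_checker = lambda images, clip_input: (images, False)
```
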
43 | ### Generator Results comparison
44 |
45 | | Prompt | English | Japanese | Chinese |
46 | | ------ | ------- | -------- | ------- |
47 | | A cartoon character with a potted plant on his head<br>鉢植えの植物を頭に載せた漫画のキャラクター<br>一个头上戴着盆栽的卡通人物 | ![](imgs/en_plant.jpg) | ![](imgs/ja_plant.jpg) | ![](imgs/zh_plant.jpg) |
48 | | cartoon bird<br>漫画の鳥<br>卡通鸟 | ![](imgs/en_bird.jpg) | ![](imgs/ja_bird.jpg) | ![](imgs/zh_bird.jpg) |
49 | | blue dragon illustration<br>ブルードラゴンのイラスト<br>蓝色的龙图 | ![](imgs/en_blue_dragon.jpg) | ![](imgs/ja_blue_dragon.jpg) | ![](imgs/zh_blue_dragon.jpg) |
50 |
77 | ### Discussion
78 | The pretrained models in English, Japanese and Chinese were trained for 26000, 26000 and 20000 steps respectively. The Japanese version outperforms the others, and the Chinese version comes third. One interpretation is that
79 | [rinnakk/japanese-stable-diffusion](https://github.com/rinnakk/japanese-stable-diffusion), being a native Japanese model, already captures many Pokemon-related features; [Stable Diffusion](https://stability.ai/blog/stable-diffusion-public-release) fine-tunes favourably in its native English domain; and [IDEA-CCNL/Taiyi-CLIP-Roberta-102M-Chinese](https://huggingface.co/IDEA-CCNL/Taiyi-CLIP-Roberta-102M-Chinese), as its model card states, only provides basic Chinese text representations.
80 |
81 |
82 | ## Contact
83 |
84 |
87 | svjack - svjackbt@gmail.com - ehangzhou@outlook.com
88 |
89 |
92 | Project Link:[https://github.com/svjack/Stable-Diffusion-Pokemon](https://github.com/svjack/Stable-Diffusion-Pokemon)
93 |
94 |
95 |
96 | ## Acknowledgements
97 |
110 | * [Stable Diffusion](https://stability.ai/blog/stable-diffusion-public-release)
111 | * [diffusers](https://github.com/huggingface/diffusers)
112 | * [DeepL](https://www.deepl.com/translator)
113 | * [rinnakk/japanese-stable-diffusion](https://github.com/rinnakk/japanese-stable-diffusion)
114 | * [IDEA-CCNL/Taiyi-CLIP-Roberta-102M-Chinese](https://huggingface.co/IDEA-CCNL/Taiyi-CLIP-Roberta-102M-Chinese)
115 | * [svjack](https://huggingface.co/svjack)
116 |
--------------------------------------------------------------------------------
/run_zh_model.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import pandas as pd
3 |
4 | from torch import autocast
5 | from diffusers import LMSDiscreteScheduler
6 |
7 | import torch
8 | from transformers import BertForSequenceClassification, BertConfig, BertTokenizer, BertForTokenClassification
9 | from transformers import CLIPProcessor, CLIPModel
10 | import numpy as np
11 |
12 | from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion import *
13 | from japanese_stable_diffusion.pipeline_stable_diffusion import *
14 |
15 | class StableDiffusionPipelineWrapper(StableDiffusionPipeline):
16 |
17 | @torch.no_grad()
18 | def __call__(
19 | self,
20 | prompt: Union[str, List[str]],
21 | height: int = 512,
22 | width: int = 512,
23 | num_inference_steps: int = 50,
24 | guidance_scale: float = 7.5,
25 | negative_prompt: Optional[Union[str, List[str]]] = None,
26 | num_images_per_prompt: Optional[int] = 1,
27 | eta: float = 0.0,
28 | generator: Optional[torch.Generator] = None,
29 | latents: Optional[torch.FloatTensor] = None,
30 | output_type: Optional[str] = "pil",
31 | return_dict: bool = True,
32 | callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
33 | callback_steps: Optional[int] = 1,
34 | **kwargs,
35 | ):
36 | if isinstance(prompt, str):
37 | batch_size = 1
38 | elif isinstance(prompt, list):
39 | batch_size = len(prompt)
40 | else:
41 | raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
42 |
43 | if height % 8 != 0 or width % 8 != 0:
44 | raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
45 |
46 | if (callback_steps is None) or (
47 | callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0)
48 | ):
49 | raise ValueError(
50 | f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
51 | f" {type(callback_steps)}."
52 | )
53 |
54 | # get prompt text embeddings
55 | text_inputs = self.tokenizer(
56 | prompt,
57 | padding="max_length",
58 | max_length=self.tokenizer.model_max_length,
59 | return_tensors="pt",
60 | )
61 | text_input_ids = text_inputs.input_ids
62 |
63 | if text_input_ids.shape[-1] > self.tokenizer.model_max_length:
64 | removed_text = self.tokenizer.batch_decode(text_input_ids[:, self.tokenizer.model_max_length :])
65 | logger.warning(
66 | "The following part of your input was truncated because CLIP can only handle sequences up to"
67 | f" {self.tokenizer.model_max_length} tokens: {removed_text}"
68 | )
69 | text_input_ids = text_input_ids[:, : self.tokenizer.model_max_length]
70 | text_embeddings = self.text_encoder(text_input_ids.to(self.device))[0]
71 |
72 | # duplicate text embeddings for each generation per prompt, using mps friendly method
73 | bs_embed, seq_len, _ = text_embeddings.shape
74 | text_embeddings = text_embeddings.repeat(1, num_images_per_prompt, 1)
75 | text_embeddings = text_embeddings.view(bs_embed * num_images_per_prompt, seq_len, -1)
76 |
77 | # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
78 | # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
79 | # corresponds to doing no classifier free guidance.
80 | do_classifier_free_guidance = guidance_scale > 1.0
81 | # get unconditional embeddings for classifier free guidance
82 | if do_classifier_free_guidance:
83 | uncond_tokens: List[str]
84 | if negative_prompt is None:
85 | uncond_tokens = [""]
86 | elif type(prompt) is not type(negative_prompt):
87 | raise TypeError(
88 | f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
89 | f" {type(prompt)}."
90 | )
91 | elif isinstance(negative_prompt, str):
92 | uncond_tokens = [negative_prompt]
93 | elif batch_size != len(negative_prompt):
94 | raise ValueError(
95 | f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
96 | f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
97 | " the batch size of `prompt`."
98 | )
99 | else:
100 | uncond_tokens = negative_prompt
101 |
102 | max_length = text_input_ids.shape[-1]
103 | uncond_input = self.tokenizer(
104 | uncond_tokens,
105 | padding="max_length",
106 | max_length=max_length,
107 | truncation=True,
108 | return_tensors="pt",
109 | )
110 | uncond_embeddings = self.text_encoder(uncond_input.input_ids.to(self.device))[0]
111 |
112 | # duplicate unconditional embeddings for each generation per prompt, using mps friendly method
113 | seq_len = uncond_embeddings.shape[1]
114 | uncond_embeddings = uncond_embeddings.repeat(batch_size, num_images_per_prompt, 1)
115 | uncond_embeddings = uncond_embeddings.view(batch_size * num_images_per_prompt, seq_len, -1)
116 |
117 | # For classifier free guidance, we need to do two forward passes.
118 | # Here we concatenate the unconditional and text embeddings into a single batch
119 | # to avoid doing two forward passes
120 | text_embeddings = torch.cat([uncond_embeddings, text_embeddings])
121 |
122 | # get the initial random noise unless the user supplied it
123 |
124 | # Unlike in other pipelines, latents need to be generated in the target device
125 | # for 1-to-1 results reproducibility with the CompVis implementation.
126 | # However this currently doesn't work in `mps`.
127 | latents_shape = (batch_size * num_images_per_prompt, self.unet.in_channels, height // 8, width // 8)
128 | latents_dtype = text_embeddings.dtype
129 | if latents is None:
130 | if self.device.type == "mps":
131 | # randn does not work reproducibly on mps
132 | latents = torch.randn(latents_shape, generator=generator, device="cpu", dtype=latents_dtype).to(
133 | self.device
134 | )
135 | else:
136 | latents = torch.randn(latents_shape, generator=generator, device=self.device, dtype=latents_dtype)
137 | else:
138 | if latents.shape != latents_shape:
139 | raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {latents_shape}")
140 | latents = latents.to(self.device)
141 |
142 | # set timesteps
143 | self.scheduler.set_timesteps(num_inference_steps)
144 |
145 | # Some schedulers like PNDM have timesteps as arrays
146 | # It's more optimized to move all timesteps to correct device beforehand
147 | timesteps_tensor = self.scheduler.timesteps.to(self.device)
148 |
149 | # scale the initial noise by the standard deviation required by the scheduler
150 | latents = latents * self.scheduler.init_noise_sigma
151 |
152 | # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
153 | # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
154 | # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
155 | # and should be between [0, 1]
156 | accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
157 | extra_step_kwargs = {}
158 | if accepts_eta:
159 | extra_step_kwargs["eta"] = eta
160 |
161 | for i, t in enumerate(self.progress_bar(timesteps_tensor)):
162 | # expand the latents if we are doing classifier free guidance
163 | latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
164 | latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
165 |
166 | # predict the noise residual
167 |                 # The Chinese text encoder outputs 512-dim features, while the UNet cross-attention
168 |                 # expects 768-dim embeddings; zero-pad once on the first step and reuse afterwards.
169 | eh_shape = text_embeddings.shape
170 | if i == 0:
171 | eh_pad = torch.zeros((eh_shape[0], eh_shape[1], 768 - 512))
172 | eh_pad = eh_pad.to(self.device)
173 | text_embeddings = torch.concat([text_embeddings, eh_pad], -1)
174 |
175 | #print("after :" ,text_embeddings.shape)
176 | noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=text_embeddings).sample
177 |
178 | # perform guidance
179 | if do_classifier_free_guidance:
180 | noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
181 | noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
182 |
183 | # compute the previous noisy sample x_t -> x_t-1
184 | latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample
185 |
186 | # call the callback, if provided
187 | if callback is not None and i % callback_steps == 0:
188 | callback(i, t, latents)
189 |
190 | latents = 1 / 0.18215 * latents
191 | image = self.vae.decode(latents).sample
192 |
193 | image = (image / 2 + 0.5).clamp(0, 1)
194 |
195 | # we always cast to float32 as this does not cause significant overhead and is compatible with bfloa16
196 | image = image.cpu().permute(0, 2, 3, 1).float().numpy()
197 |
198 | if self.safety_checker is not None:
199 | safety_checker_input = self.feature_extractor(self.numpy_to_pil(image), return_tensors="pt").to(
200 | self.device
201 | )
202 | image, has_nsfw_concept = self.safety_checker(
203 | images=image, clip_input=safety_checker_input.pixel_values.to(text_embeddings.dtype)
204 | )
205 | else:
206 | has_nsfw_concept = None
207 |
208 | if output_type == "pil":
209 | image = self.numpy_to_pil(image)
210 |
211 | if not return_dict:
212 | return (image, has_nsfw_concept)
213 |
214 | return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept)
215 |
216 | if __name__ == "__main__":
217 | scheduler = LMSDiscreteScheduler(beta_start=0.00085, beta_end=0.012,
218 | beta_schedule="scaled_linear", num_train_timesteps=1000)
219 |
220 | #pretrained_model_name_or_path = "zh_model_20000"
221 | #### sudo apt-get install git-lfs
222 | #### git clone https://huggingface.co/svjack/Stable-Diffusion-Pokemon-zh
223 | pretrained_model_name_or_path = "Stable-Diffusion-Pokemon-zh"
224 |
225 | tokenizer = BertTokenizer.from_pretrained(pretrained_model_name_or_path, subfolder = "tokenizer")
226 | text_encoder = BertForTokenClassification.from_pretrained(pretrained_model_name_or_path, subfolder = "text_encoder")
227 |
228 | vae = AutoencoderKL.from_pretrained(pretrained_model_name_or_path, subfolder="vae")
229 | unet = UNet2DConditionModel.from_pretrained(pretrained_model_name_or_path, subfolder="unet")
230 |
231 | tokenizer.model_max_length = 77
232 | pipeline_wrap = StableDiffusionPipelineWrapper(
233 | text_encoder=text_encoder,
234 | vae=vae,
235 | unet=unet,
236 | tokenizer=tokenizer,
237 | scheduler=scheduler,
238 | safety_checker=StableDiffusionSafetyChecker.from_pretrained("CompVis/stable-diffusion-safety-checker"),
239 | feature_extractor=CLIPFeatureExtractor.from_pretrained("openai/clip-vit-base-patch32"),
240 | )
241 | pipeline_wrap.safety_checker = lambda images, clip_input: (images, False)
242 | pipeline_wrap = pipeline_wrap.to("cuda")
243 |
244 | imgs = pipeline_wrap("一个头上戴着盆栽的卡通人物",
245 | num_inference_steps = 100
246 | )
247 | imgs.images[0]
248 |
249 | imgs = pipeline_wrap("卡通鸟",
250 | num_inference_steps = 100
251 | )
252 | imgs.images[0]
253 |
254 | imgs = pipeline_wrap("黄色的球",
255 | num_inference_steps = 100
256 | )
257 | imgs.images[0]
258 |
259 | imgs = pipeline_wrap("蓝色的龙图",
260 | num_inference_steps = 100
261 | )
262 | imgs.images[0]
263 |
--------------------------------------------------------------------------------
/train_zh_model.py:
--------------------------------------------------------------------------------
1 | #requires_grad_
2 | REQUIRES_GRAD = False
3 |
4 | import pandas as pd
5 | from collections import namedtuple
6 |
7 | import os
8 | os.environ['KMP_DUPLICATE_LIB_OK']='True'
9 |
10 | import argparse
11 | import logging
12 | import math
13 | import os
14 | import random
15 | from pathlib import Path
16 | from typing import Iterable, Optional
17 |
18 | import numpy as np
19 | import torch
20 | import torch.nn.functional as F
21 | import torch.utils.checkpoint
22 |
23 | from accelerate import Accelerator
24 | from accelerate.logging import get_logger
25 | from accelerate.utils import set_seed
26 | from datasets import load_dataset
27 | from diffusers import AutoencoderKL, DDPMScheduler, PNDMScheduler, StableDiffusionPipeline, UNet2DConditionModel
28 | from diffusers.optimization import get_scheduler
29 | from diffusers.pipelines.stable_diffusion import StableDiffusionSafetyChecker
30 | from huggingface_hub import HfFolder, Repository, whoami
31 | from torchvision import transforms
32 | from tqdm.auto import tqdm
33 | from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer
34 |
35 |
36 | logger = get_logger(__name__)
37 |
38 |
39 | def parse_args():
40 | parser = argparse.ArgumentParser(description="Simple example of a training script.")
41 | parser.add_argument(
42 | "--pretrained_model_name_or_path",
43 | type=str,
44 | default=None,
45 | required=True,
46 | help="Path to pretrained model or model identifier from huggingface.co/models.",
47 | )
48 | parser.add_argument(
49 | "--dataset_name",
50 | type=str,
51 | default=None,
52 | help=(
53 | "The name of the Dataset (from the HuggingFace hub) to train on (could be your own, possibly private,"
54 | " dataset). It can also be a path pointing to a local copy of a dataset in your filesystem,"
55 | " or to a folder containing files that 🤗 Datasets can understand."
56 | ),
57 | )
58 | parser.add_argument(
59 | "--dataset_config_name",
60 | type=str,
61 | default=None,
62 | help="The config of the Dataset, leave as None if there's only one config.",
63 | )
64 | parser.add_argument(
65 | "--train_data_dir",
66 | type=str,
67 | default=None,
68 | help=(
69 | "A folder containing the training data. Folder contents must follow the structure described in"
70 | " https://huggingface.co/docs/datasets/image_dataset#imagefolder. In particular, a `metadata.jsonl` file"
71 | " must exist to provide the captions for the images. Ignored if `dataset_name` is specified."
72 | ),
73 | )
74 | parser.add_argument(
75 | "--image_column", type=str, default="image", help="The column of the dataset containing an image."
76 | )
77 | parser.add_argument(
78 | "--caption_column",
79 | type=str,
80 | default="text",
81 | help="The column of the dataset containing a caption or a list of captions.",
82 | )
83 | parser.add_argument(
84 | "--max_train_samples",
85 | type=int,
86 | default=None,
87 | help=(
88 | "For debugging purposes or quicker training, truncate the number of training examples to this "
89 | "value if set."
90 | ),
91 | )
92 | parser.add_argument(
93 | "--output_dir",
94 | type=str,
95 | default="sd-model-finetuned",
96 | help="The output directory where the model predictions and checkpoints will be written.",
97 | )
98 | parser.add_argument(
99 | "--cache_dir",
100 | type=str,
101 | default=None,
102 | help="The directory where the downloaded models and datasets will be stored.",
103 | )
104 | parser.add_argument("--seed", type=int, default=None, help="A seed for reproducible training.")
105 | parser.add_argument(
106 | "--resolution",
107 | type=int,
108 | default=512,
109 | help=(
110 | "The resolution for input images, all the images in the train/validation dataset will be resized to this"
111 | " resolution"
112 | ),
113 | )
114 | parser.add_argument(
115 | "--center_crop",
116 | action="store_true",
117 | help="Whether to center crop images before resizing to resolution (if not set, random crop will be used)",
118 | )
119 | parser.add_argument(
120 | "--random_flip",
121 | action="store_true",
122 | help="whether to randomly flip images horizontally",
123 | )
124 | parser.add_argument(
125 | "--train_batch_size", type=int, default=16, help="Batch size (per device) for the training dataloader."
126 | )
127 | parser.add_argument("--num_train_epochs", type=int, default=100)
128 | parser.add_argument(
129 | "--max_train_steps",
130 | type=int,
131 | default=None,
132 | help="Total number of training steps to perform. If provided, overrides num_train_epochs.",
133 | )
134 | parser.add_argument(
135 | "--gradient_accumulation_steps",
136 | type=int,
137 | default=1,
138 | help="Number of updates steps to accumulate before performing a backward/update pass.",
139 | )
140 | parser.add_argument(
141 | "--gradient_checkpointing",
142 | action="store_true",
143 | help="Whether or not to use gradient checkpointing to save memory at the expense of slower backward pass.",
144 | )
145 | parser.add_argument(
146 | "--learning_rate",
147 | type=float,
148 | default=1e-4,
149 | help="Initial learning rate (after the potential warmup period) to use.",
150 | )
151 | parser.add_argument(
152 | "--scale_lr",
153 | action="store_true",
154 | default=False,
155 | help="Scale the learning rate by the number of GPUs, gradient accumulation steps, and batch size.",
156 | )
157 | parser.add_argument(
158 | "--lr_scheduler",
159 | type=str,
160 | default="constant",
161 | help=(
162 | 'The scheduler type to use. Choose between ["linear", "cosine", "cosine_with_restarts", "polynomial",'
163 | ' "constant", "constant_with_warmup"]'
164 | ),
165 | )
166 | parser.add_argument(
167 | "--lr_warmup_steps", type=int, default=500, help="Number of steps for the warmup in the lr scheduler."
168 | )
169 | parser.add_argument(
170 | "--use_8bit_adam", action="store_true", help="Whether or not to use 8-bit Adam from bitsandbytes."
171 | )
172 | parser.add_argument("--use_ema", action="store_true", help="Whether to use EMA model.")
173 | parser.add_argument("--adam_beta1", type=float, default=0.9, help="The beta1 parameter for the Adam optimizer.")
174 | parser.add_argument("--adam_beta2", type=float, default=0.999, help="The beta2 parameter for the Adam optimizer.")
175 | parser.add_argument("--adam_weight_decay", type=float, default=1e-2, help="Weight decay to use.")
176 | parser.add_argument("--adam_epsilon", type=float, default=1e-08, help="Epsilon value for the Adam optimizer")
177 | parser.add_argument("--max_grad_norm", default=1.0, type=float, help="Max gradient norm.")
178 | parser.add_argument("--push_to_hub", action="store_true", help="Whether or not to push the model to the Hub.")
179 | parser.add_argument("--hub_token", type=str, default=None, help="The token to use to push to the Model Hub.")
180 | parser.add_argument(
181 | "--hub_model_id",
182 | type=str,
183 | default=None,
184 | help="The name of the repository to keep in sync with the local `output_dir`.",
185 | )
186 | parser.add_argument(
187 | "--logging_dir",
188 | type=str,
189 | default="logs",
190 | help=(
191 | "[TensorBoard](https://www.tensorflow.org/tensorboard) log directory. Will default to"
192 | " *output_dir/runs/**CURRENT_DATETIME_HOSTNAME***."
193 | ),
194 | )
195 | parser.add_argument(
196 | "--mixed_precision",
197 | type=str,
198 | default="no",
199 | choices=["no", "fp16", "bf16"],
200 | help=(
201 | "Whether to use mixed precision. Choose"
202 | "between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >= 1.10."
203 | "and an Nvidia Ampere GPU."
204 | ),
205 | )
206 | parser.add_argument(
207 | "--report_to",
208 | type=str,
209 | default="tensorboard",
210 | help=(
211 | 'The integration to report the results and logs to. Supported platforms are `"tensorboard"`,'
212 | ' `"wandb"` and `"comet_ml"`. Use `"all"` (default) to report to all integrations.'
213 | "Only applicable when `--with_tracking` is passed."
214 | ),
215 | )
216 | parser.add_argument("--local_rank", type=int, default=-1, help="For distributed training: local_rank")
217 |
218 | #args = parser.parse_args()
219 | return parser
220 | env_local_rank = int(os.environ.get("LOCAL_RANK", -1))
221 | if env_local_rank != -1 and env_local_rank != args.local_rank:
222 | args.local_rank = env_local_rank
223 |
224 | # Sanity checks
225 | if args.dataset_name is None and args.train_data_dir is None:
226 | raise ValueError("Need either a dataset name or a training folder.")
227 |
228 | return args
229 |
230 |
231 | def parse_parser_add_arg(parser, as_named_tuple = False):
232 | args_df = pd.DataFrame(
233 | pd.Series(parser.__dict__["_actions"]).map(
234 | lambda x:x.__dict__
235 | ).values.tolist())
236 | args_df = args_df.explode("option_strings")
237 | args_df["option_strings"] = args_df["option_strings"].map(
238 | lambda x: x[2:] if x.startswith("--") else x
239 | ).map(
240 | lambda x: x[1:] if x.startswith("-") else x
241 | )
242 | args_df = args_df[["option_strings", "default"]]
243 | args = dict(args_df.values.tolist())
244 | if as_named_tuple:
245 | args_parser_namedtuple = namedtuple("args_config", args)
246 | return args_parser_namedtuple(**args)
247 | return args_df
248 |
249 | def transform_named_tuple_to_dict(N_tuple):
250 | return dict(map(
251 | lambda x: (x, getattr(N_tuple, x))
252 | ,filter(lambda x: not x.startswith("_") ,dir(N_tuple))
253 | ))
254 |
255 | def transform_dict_to_named_tuple(dict_, name = "NamedTuple"):
256 | args_parser_namedtuple = namedtuple(name, dict_)
257 | return args_parser_namedtuple(**dict_)
258 |
259 | def setattr_gen_option_cls(src_obj):
260 | assert isinstance(src_obj, tuple) or isinstance(src_obj, dict)
261 | if isinstance(src_obj, tuple):
262 | src_obj_ = transform_named_tuple_to_dict(src_obj)
263 | else:
264 | src_obj_ = src_obj
265 | assert isinstance(src_obj_, dict)
266 | class Option(object):
267 | pass
268 | option = Option()
269 | for k, v in src_obj_.items():
270 | setattr(option, k, v)
271 | return option
272 |
273 | args = parse_args()
274 | args = parse_parser_add_arg(args, as_named_tuple = True)
275 | '''
276 | export MODEL_NAME="stable-diffusion-v1-4/"
277 | export dataset_name="lambdalabs/pokemon-blip-captions"
278 |
279 | accelerate launch train_text_to_image_ori.py \
280 | --pretrained_model_name_or_path=$MODEL_NAME \
281 | --dataset_name=$dataset_name \
282 | --use_ema \
283 | --resolution=32 --center_crop --random_flip \
284 | --train_batch_size=1 \
285 | --gradient_accumulation_steps=4 \
286 | --gradient_checkpointing \
287 | --max_train_steps=15000 \
288 | --learning_rate=1e-05 \
289 | --max_grad_norm=1 \
290 | --lr_scheduler="constant" --lr_warmup_steps=0 \
291 | --output_dir="sd-pokemon-model"
292 | '''
293 | args_dict = transform_named_tuple_to_dict(args)
294 | args_dict["pretrained_model_name_or_path"] = "stable-diffusion-v1-4/"
295 | #args_dict["pretrained_model_name_or_path"] = "japanese-stable-diffusion/"
296 | args_dict["dataset_name"] = "svjack/pokemon-blip-captions-en-zh"
297 | args_dict["use_ema"] = True
298 | ###args_dict["use_ema"] = False
299 | args_dict["resolution"] = 256
300 | args_dict["center_crop"] = True
301 | args_dict["random_flip"] = True
302 | args_dict["train_batch_size"] = 1
303 | args_dict["gradient_accumulation_steps"] = 4
304 | args_dict["train_batch_size"] = 4
305 | args_dict["gradient_checkpointing"] = True
306 | #### to 15000
307 | args_dict["max_train_steps"] = 50000
308 | args_dict["learning_rate"] = 1e-05
309 | args_dict["max_grad_norm"] = 1
310 | args_dict["lr_scheduler"] = "constant"
311 | args_dict["lr_warmup_steps"] = 0
312 | args_dict["output_dir"] = "sd-pokemon-model"
313 | args_dict["caption_column"] = "zh_text"
314 | args_dict["mixed_precision"] = "no"
315 | args = transform_dict_to_named_tuple(args_dict)
316 |
317 | def get_full_repo_name(model_id: str, organization: Optional[str] = None, token: Optional[str] = None):
318 | if token is None:
319 | token = HfFolder.get_token()
320 | if organization is None:
321 | username = whoami(token)["name"]
322 | return f"{username}/{model_id}"
323 | else:
324 | return f"{organization}/{model_id}"
325 |
326 | dataset_name_mapping = {
327 | "svjack/pokemon-blip-captions-en-zh": ("image", "zh_text"),
328 | }
329 |
330 |
331 | class EMAModel:
332 | """
333 | Exponential Moving Average of models weights
334 | """
335 |
336 | def __init__(self, parameters: Iterable[torch.nn.Parameter], decay=0.9999):
337 | parameters = list(parameters)
338 | self.shadow_params = [p.clone().detach() for p in parameters]
339 |
340 | self.decay = decay
341 | self.optimization_step = 0
342 |
343 | def get_decay(self, optimization_step):
344 | """
345 | Compute the decay factor for the exponential moving average.
346 | """
347 | value = (1 + optimization_step) / (10 + optimization_step)
348 | return 1 - min(self.decay, value)
349 |
350 | @torch.no_grad()
351 | def step(self, parameters):
352 | parameters = list(parameters)
353 |
354 | self.optimization_step += 1
355 | self.decay = self.get_decay(self.optimization_step)
356 |
357 | for s_param, param in zip(self.shadow_params, parameters):
358 | if param.requires_grad:
359 | tmp = self.decay * (s_param - param)
360 | s_param.sub_(tmp)
361 | else:
362 | s_param.copy_(param)
363 |
364 | torch.cuda.empty_cache()
365 |
366 | def copy_to(self, parameters: Iterable[torch.nn.Parameter]) -> None:
367 | """
368 | Copy current averaged parameters into given collection of parameters.
369 |
370 | Args:
371 | parameters: Iterable of `torch.nn.Parameter`; the parameters to be
372 | updated with the stored moving averages. If `None`, the
373 | parameters with which this `ExponentialMovingAverage` was
374 | initialized will be used.
375 | """
376 | parameters = list(parameters)
377 | for s_param, param in zip(self.shadow_params, parameters):
378 | param.data.copy_(s_param.data)
379 |
380 | def to(self, device=None, dtype=None) -> None:
381 | r"""Move internal buffers of the ExponentialMovingAverage to `device`.
382 |
383 | Args:
384 | device: like `device` argument to `torch.Tensor.to`
385 | """
386 | # .to() on the tensors handles None correctly
387 | self.shadow_params = [
388 | p.to(device=device, dtype=dtype) if p.is_floating_point() else p.to(device=device)
389 | for p in self.shadow_params
390 | ]
391 |
392 | import torch
393 | from transformers import BertForSequenceClassification, BertConfig, BertTokenizer, BertForTokenClassification
394 | from transformers import CLIPProcessor, CLIPModel
395 | import numpy as np
396 |
397 | tokenizer = BertTokenizer.from_pretrained("IDEA-CCNL/Taiyi-CLIP-Roberta-102M-Chinese")
398 | text_encoder = BertForTokenClassification.from_pretrained("IDEA-CCNL/Taiyi-CLIP-Roberta-102M-Chinese")
399 |
400 | vae = AutoencoderKL.from_pretrained(args.pretrained_model_name_or_path, subfolder="vae")
401 | unet = UNet2DConditionModel.from_pretrained(args.pretrained_model_name_or_path, subfolder="unet")
402 |
403 | vae.requires_grad_(REQUIRES_GRAD)
404 | text_encoder.requires_grad_(REQUIRES_GRAD)
405 |
406 | if args.gradient_checkpointing:
407 | unet.enable_gradient_checkpointing()
408 |
409 | if args.scale_lr:
410 | args.learning_rate = (
411 | args.learning_rate * args.gradient_accumulation_steps * args.train_batch_size * accelerator.num_processes
412 | )
413 |
414 | # Initialize the optimizer
415 | if args.use_8bit_adam:
416 | try:
417 | import bitsandbytes as bnb
418 | except ImportError:
419 | raise ImportError(
420 | "Please install bitsandbytes to use 8-bit Adam. You can do so by running `pip install bitsandbytes`"
421 | )
422 |
423 | optimizer_cls = bnb.optim.AdamW8bit
424 | else:
425 | optimizer_cls = torch.optim.AdamW
426 |
427 | optimizer = optimizer_cls(
428 | unet.parameters(),
429 | lr=args.learning_rate,
430 | betas=(args.adam_beta1, args.adam_beta2),
431 | weight_decay=args.adam_weight_decay,
432 | eps=args.adam_epsilon,
433 | )
434 | noise_scheduler = DDPMScheduler(
435 | beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", num_train_timesteps=1000,
436 | #tensor_format="pt"
437 | )
438 |
439 |
440 | dataset = load_dataset(
441 | args.dataset_name,
442 | args.dataset_config_name,
443 | cache_dir=args.cache_dir,
444 | )
445 | column_names = dataset["train"].column_names
446 | dataset_columns = dataset_name_mapping.get(args.dataset_name, None)
447 | if args.image_column is None:
448 | image_column = dataset_columns[0] if dataset_columns is not None else column_names[0]
449 | else:
450 | image_column = args.image_column
451 | if image_column not in column_names:
452 | raise ValueError(
453 | f"--image_column' value '{args.image_column}' needs to be one of: {', '.join(column_names)}"
454 | )
455 | if args.caption_column is None:
456 | caption_column = dataset_columns[1] if dataset_columns is not None else column_names[1]
457 | else:
458 | caption_column = args.caption_column
459 | if caption_column not in column_names:
460 | raise ValueError(
461 | f"--caption_column' value '{args.caption_column}' needs to be one of: {', '.join(column_names)}"
462 | )
463 |
464 |
465 | # Preprocessing the datasets.
466 | # We need to tokenize input captions and transform the images.
467 | def tokenize_captions(examples, is_train=True):
468 | captions = []
469 | for caption in examples[caption_column]:
470 | if isinstance(caption, str):
471 | captions.append(caption)
472 | elif isinstance(caption, (list, np.ndarray)):
473 | # take a random caption if there are multiple
474 | captions.append(random.choice(caption) if is_train else caption[0])
475 | else:
476 | raise ValueError(
477 | f"Caption column `{caption_column}` should contain either strings or lists of strings."
478 | )
479 | inputs = tokenizer(captions, max_length=tokenizer.model_max_length, padding="do_not_pad", truncation=True)
480 | input_ids = inputs.input_ids
481 | return input_ids
482 |
483 |
484 | train_transforms = transforms.Compose(
485 | [
486 | transforms.Resize((args.resolution, args.resolution), interpolation=transforms.InterpolationMode.BILINEAR),
487 | transforms.CenterCrop(args.resolution) if args.center_crop else transforms.RandomCrop(args.resolution),
488 | transforms.RandomHorizontalFlip() if args.random_flip else transforms.Lambda(lambda x: x),
489 | transforms.ToTensor(),
490 | transforms.Normalize([0.5], [0.5]),
491 | ]
492 | )
493 |
494 | def preprocess_train(examples):
495 | images = [image.convert("RGB") for image in examples[image_column]]
496 | examples["pixel_values"] = [train_transforms(image) for image in images]
497 | examples["input_ids"] = tokenize_captions(examples)
498 | return examples
499 |
500 | if args.max_train_samples is not None:
501 | dataset["train"] = dataset["train"].shuffle(seed=args.seed).select(range(args.max_train_samples))
502 | train_dataset = dataset["train"].with_transform(preprocess_train)
503 |
504 | def collate_fn(examples):
505 | pixel_values = torch.stack([example["pixel_values"] for example in examples])
506 | pixel_values = pixel_values.to(memory_format=torch.contiguous_format).float()
507 | input_ids = [example["input_ids"] for example in examples]
508 | padded_tokens = tokenizer.pad({"input_ids": input_ids}, padding=True, return_tensors="pt")
509 | return {
510 | "pixel_values": pixel_values,
511 | "input_ids": padded_tokens.input_ids,
512 | "attention_mask": padded_tokens.attention_mask,
513 | }
514 |
515 | train_dataloader = torch.utils.data.DataLoader(
516 | train_dataset, shuffle=True, collate_fn=collate_fn, batch_size=args.train_batch_size
517 | )
518 |
519 |
520 | overrode_max_train_steps = False
521 | num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
522 | if args.max_train_steps is None:
523 | args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
524 | overrode_max_train_steps = True
525 |
526 | lr_scheduler = get_scheduler(
527 | args.lr_scheduler,
528 | optimizer=optimizer,
529 | num_warmup_steps=args.lr_warmup_steps * args.gradient_accumulation_steps,
530 | num_training_steps=args.max_train_steps * args.gradient_accumulation_steps,
531 | )
532 |
533 | def train_loop(config, model, noise_scheduler, optimizer, train_dataloader, lr_scheduler):
534 | # Initialize accelerator and tensorboard logging
535 | args = config
536 | unet = model
537 |
538 | logging_dir = os.path.join(args.output_dir, args.logging_dir)
539 | accelerator = Accelerator(
540 | gradient_accumulation_steps=args.gradient_accumulation_steps,
541 | mixed_precision=args.mixed_precision,
542 | log_with=args.report_to,
543 | logging_dir=logging_dir,
544 | )
545 | logging.basicConfig(
546 | format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
547 | datefmt="%m/%d/%Y %H:%M:%S",
548 | level=logging.INFO,
549 | )
550 | if args.seed is not None:
551 | set_seed(args.seed)
552 | if accelerator.is_main_process:
553 | os.makedirs(args.output_dir, exist_ok=True)
554 |
555 | weight_dtype = torch.float32
556 | if args.mixed_precision == "fp16":
557 | weight_dtype = torch.float16
558 | elif args.mixed_precision == "bf16":
559 | weight_dtype = torch.bfloat16
560 |
561 | text_encoder.to(accelerator.device, dtype=weight_dtype)
562 | vae.to(accelerator.device, dtype=weight_dtype)
563 |
564 | if args.use_ema:
565 | ema_unet = EMAModel(unet.parameters())
566 | ema_unet.to(accelerator.device, dtype=weight_dtype)
567 |
568 | num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
569 | if overrode_max_train_steps:
570 | args_dict = transform_named_tuple_to_dict(args)
571 | args_dict["max_train_steps"] = args.num_train_epochs * num_update_steps_per_epoch
572 | args = transform_dict_to_named_tuple(args_dict)
573 |
574 |
575 | args_dict = transform_named_tuple_to_dict(args)
576 | args_dict["num_train_epochs"] = math.ceil(args.max_train_steps / num_update_steps_per_epoch)
577 | args = transform_dict_to_named_tuple(args_dict)
578 |
579 | if accelerator.is_main_process:
580 | if config.push_to_hub:
581 | #repo = init_git_repo(config, at_init=True)
582 | pass
583 | accelerator.init_trackers("train_example")
584 |
585 | total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps
586 |
587 | logger.info("***** Running training *****")
588 | logger.info(f" Num examples = {len(train_dataset)}")
589 | logger.info(f" Num Epochs = {args.num_train_epochs}")
590 | logger.info(f" Instantaneous batch size per device = {args.train_batch_size}")
591 | logger.info(f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}")
592 | logger.info(f" Gradient Accumulation steps = {args.gradient_accumulation_steps}")
593 | logger.info(f" Total optimization steps = {args.max_train_steps}")
594 |
595 | # Prepare everything
596 | # There is no specific order to remember, you just need to unpack the
597 | # objects in the same order you gave them to the prepare method.
598 | '''
599 | model, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
600 | model, optimizer, train_dataloader, lr_scheduler
601 | )
602 | '''
603 | unet, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
604 | unet, optimizer, train_dataloader, lr_scheduler
605 | )
606 | progress_bar = tqdm(range(args.max_train_steps), disable=not accelerator.is_local_main_process)
607 | progress_bar.set_description("Steps")
608 | global_step = 0
609 |
610 | for epoch in range(args.num_train_epochs):
611 | #unet.train()
612 | train_loss = 0.0
613 | for step, batch in enumerate(train_dataloader):
614 | with accelerator.accumulate(unet):
615 | # Convert images to latent space
616 | latents = vae.encode(batch["pixel_values"].to(weight_dtype)).latent_dist.sample()
617 | latents = latents * 0.18215
618 |
619 | # Sample noise that we'll add to the latents
620 | noise = torch.randn_like(latents)
621 | bsz = latents.shape[0]
622 | # Sample a random timestep for each image
623 | timesteps = torch.randint(0, noise_scheduler.num_train_timesteps, (bsz,), device=latents.device)
624 | timesteps = timesteps.long()
625 |
626 | # Add noise to the latents according to the noise magnitude at each timestep
627 | # (this is the forward diffusion process)
628 | noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)
629 |
630 | # Get the text embedding for conditioning
631 | encoder_hidden_states = text_encoder(batch["input_ids"])[0]
632 |                 # Pad the 512-dim Bert logits with zeros to the 768-dim width expected by the UNet cross-attention
633 | eh_shape = encoder_hidden_states.shape
634 | eh_pad = torch.zeros((eh_shape[0], eh_shape[1], 768 - 512))
635 | eh_pad = eh_pad.to(accelerator.device)
636 | encoder_hidden_states = torch.concat([encoder_hidden_states, eh_pad], -1)
637 |
638 | # Predict the noise residual and compute loss
639 | noise_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample
640 |
641 | loss = F.mse_loss(noise_pred.float(), noise.float(), reduction="mean")
642 | #loss = F.mse_loss(noise.float(), noise.float(), reduction="mean")
643 |
644 | # Gather the losses across all processes for logging (if we use distributed training).
645 | avg_loss = accelerator.gather(loss.repeat(args.train_batch_size)).mean()
646 | train_loss += avg_loss.item() / args.gradient_accumulation_steps
647 |
648 | # Backpropagate
649 | accelerator.backward(loss)
650 | if accelerator.sync_gradients:
651 | accelerator.clip_grad_norm_(unet.parameters(), args.max_grad_norm)
652 | optimizer.step()
653 | lr_scheduler.step()
654 | optimizer.zero_grad()
655 |
656 | # Checks if the accelerator has performed an optimization step behind the scenes
657 | if accelerator.sync_gradients:
658 | if args.use_ema:
659 | ema_unet.step(unet.parameters())
660 | progress_bar.update(1)
661 | global_step += 1
662 | accelerator.log({"train_loss": train_loss}, step=global_step)
663 | train_loss = 0.0
664 |
665 | logs = {"step_loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0]}
666 | progress_bar.set_postfix(**logs)
667 |
668 | if global_step >= args.max_train_steps:
669 | break
670 |
671 | from accelerate import notebook_launcher
672 |
673 | ####args = (config, model, noise_scheduler, optimizer, train_dataloader, lr_scheduler)
674 | args_ = (args, unet, noise_scheduler, optimizer, train_dataloader, lr_scheduler)
675 | notebook_launcher(train_loop, args_, num_processes=1)
676 |
677 | #### save to local
678 | import os
679 | save_path = "zh_model_20000"
680 | if not os.path.exists(save_path):
681 | os.mkdir(save_path)
682 |
683 | tokenizer.save_pretrained(os.path.join(save_path, "tokenizer"))
684 | text_encoder.save_pretrained(os.path.join(save_path, "text_encoder"))
685 | vae.save_pretrained(os.path.join(save_path, "vae"))
686 | unet.save_pretrained(os.path.join(save_path, "unet"))
687 |
--------------------------------------------------------------------------------
/train_ja_model.py:
--------------------------------------------------------------------------------
1 | import pandas as pd
2 | from collections import namedtuple
3 |
4 | import os
5 | os.environ['KMP_DUPLICATE_LIB_OK']='True'
6 |
7 | import argparse
8 | import logging
9 | import math
10 | import os
11 | import random
12 | from pathlib import Path
13 | from typing import Iterable, Optional
14 |
15 | import numpy as np
16 | import torch
17 | import torch.nn.functional as F
18 | import torch.utils.checkpoint
19 |
20 | from accelerate import Accelerator
21 | from accelerate.logging import get_logger
22 | from accelerate.utils import set_seed
23 | from datasets import load_dataset
24 | from diffusers import AutoencoderKL, DDPMScheduler, PNDMScheduler, StableDiffusionPipeline, UNet2DConditionModel
25 | from diffusers.optimization import get_scheduler
26 | from diffusers.pipelines.stable_diffusion import StableDiffusionSafetyChecker
27 | from huggingface_hub import HfFolder, Repository, whoami
28 | from torchvision import transforms
29 | from tqdm.auto import tqdm
30 | from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer
31 |
32 |
33 | logger = get_logger(__name__)
34 |
35 |
36 | def parse_args():
37 | parser = argparse.ArgumentParser(description="Simple example of a training script.")
38 | parser.add_argument(
39 | "--pretrained_model_name_or_path",
40 | type=str,
41 | default=None,
42 | required=True,
43 | help="Path to pretrained model or model identifier from huggingface.co/models.",
44 | )
45 | parser.add_argument(
46 | "--dataset_name",
47 | type=str,
48 | default=None,
49 | help=(
50 | "The name of the Dataset (from the HuggingFace hub) to train on (could be your own, possibly private,"
51 | " dataset). It can also be a path pointing to a local copy of a dataset in your filesystem,"
52 | " or to a folder containing files that 🤗 Datasets can understand."
53 | ),
54 | )
55 | parser.add_argument(
56 | "--dataset_config_name",
57 | type=str,
58 | default=None,
59 | help="The config of the Dataset, leave as None if there's only one config.",
60 | )
61 | parser.add_argument(
62 | "--train_data_dir",
63 | type=str,
64 | default=None,
65 | help=(
66 | "A folder containing the training data. Folder contents must follow the structure described in"
67 | " https://huggingface.co/docs/datasets/image_dataset#imagefolder. In particular, a `metadata.jsonl` file"
68 | " must exist to provide the captions for the images. Ignored if `dataset_name` is specified."
69 | ),
70 | )
71 | parser.add_argument(
72 | "--image_column", type=str, default="image", help="The column of the dataset containing an image."
73 | )
74 | parser.add_argument(
75 | "--caption_column",
76 | type=str,
77 | default="text",
78 | help="The column of the dataset containing a caption or a list of captions.",
79 | )
80 | parser.add_argument(
81 | "--max_train_samples",
82 | type=int,
83 | default=None,
84 | help=(
85 | "For debugging purposes or quicker training, truncate the number of training examples to this "
86 | "value if set."
87 | ),
88 | )
89 | parser.add_argument(
90 | "--output_dir",
91 | type=str,
92 | default="sd-model-finetuned",
93 | help="The output directory where the model predictions and checkpoints will be written.",
94 | )
95 | parser.add_argument(
96 | "--cache_dir",
97 | type=str,
98 | default=None,
99 | help="The directory where the downloaded models and datasets will be stored.",
100 | )
101 | parser.add_argument("--seed", type=int, default=None, help="A seed for reproducible training.")
102 | parser.add_argument(
103 | "--resolution",
104 | type=int,
105 | default=512,
106 | help=(
107 |             "The resolution for input images; all the images in the train/validation dataset will be resized to this"
108 |             " resolution."
109 | ),
110 | )
111 | parser.add_argument(
112 | "--center_crop",
113 | action="store_true",
114 | help="Whether to center crop images before resizing to resolution (if not set, random crop will be used)",
115 | )
116 | parser.add_argument(
117 | "--random_flip",
118 | action="store_true",
119 | help="whether to randomly flip images horizontally",
120 | )
121 | parser.add_argument(
122 | "--train_batch_size", type=int, default=16, help="Batch size (per device) for the training dataloader."
123 | )
124 | parser.add_argument("--num_train_epochs", type=int, default=100)
125 | parser.add_argument(
126 | "--max_train_steps",
127 | type=int,
128 | default=None,
129 | help="Total number of training steps to perform. If provided, overrides num_train_epochs.",
130 | )
131 | parser.add_argument(
132 | "--gradient_accumulation_steps",
133 | type=int,
134 | default=1,
135 |         help="Number of update steps to accumulate before performing a backward/update pass.",
136 | )
137 | parser.add_argument(
138 | "--gradient_checkpointing",
139 | action="store_true",
140 | help="Whether or not to use gradient checkpointing to save memory at the expense of slower backward pass.",
141 | )
142 | parser.add_argument(
143 | "--learning_rate",
144 | type=float,
145 | default=1e-4,
146 | help="Initial learning rate (after the potential warmup period) to use.",
147 | )
148 | parser.add_argument(
149 | "--scale_lr",
150 | action="store_true",
151 | default=False,
152 | help="Scale the learning rate by the number of GPUs, gradient accumulation steps, and batch size.",
153 | )
154 | parser.add_argument(
155 | "--lr_scheduler",
156 | type=str,
157 | default="constant",
158 | help=(
159 | 'The scheduler type to use. Choose between ["linear", "cosine", "cosine_with_restarts", "polynomial",'
160 | ' "constant", "constant_with_warmup"]'
161 | ),
162 | )
163 | parser.add_argument(
164 | "--lr_warmup_steps", type=int, default=500, help="Number of steps for the warmup in the lr scheduler."
165 | )
166 | parser.add_argument(
167 | "--use_8bit_adam", action="store_true", help="Whether or not to use 8-bit Adam from bitsandbytes."
168 | )
169 | parser.add_argument("--use_ema", action="store_true", help="Whether to use EMA model.")
170 | parser.add_argument("--adam_beta1", type=float, default=0.9, help="The beta1 parameter for the Adam optimizer.")
171 | parser.add_argument("--adam_beta2", type=float, default=0.999, help="The beta2 parameter for the Adam optimizer.")
172 | parser.add_argument("--adam_weight_decay", type=float, default=1e-2, help="Weight decay to use.")
173 | parser.add_argument("--adam_epsilon", type=float, default=1e-08, help="Epsilon value for the Adam optimizer")
174 | parser.add_argument("--max_grad_norm", default=1.0, type=float, help="Max gradient norm.")
175 | parser.add_argument("--push_to_hub", action="store_true", help="Whether or not to push the model to the Hub.")
176 | parser.add_argument("--hub_token", type=str, default=None, help="The token to use to push to the Model Hub.")
177 | parser.add_argument(
178 | "--hub_model_id",
179 | type=str,
180 | default=None,
181 | help="The name of the repository to keep in sync with the local `output_dir`.",
182 | )
183 | parser.add_argument(
184 | "--logging_dir",
185 | type=str,
186 | default="logs",
187 | help=(
188 | "[TensorBoard](https://www.tensorflow.org/tensorboard) log directory. Will default to"
189 | " *output_dir/runs/**CURRENT_DATETIME_HOSTNAME***."
190 | ),
191 | )
192 | parser.add_argument(
193 | "--mixed_precision",
194 | type=str,
195 | default="no",
196 | choices=["no", "fp16", "bf16"],
197 | help=(
198 |             "Whether to use mixed precision. Choose"
199 |             " between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >= 1.10"
200 |             " and an Nvidia Ampere GPU."
201 | ),
202 | )
203 | parser.add_argument(
204 | "--report_to",
205 | type=str,
206 | default="tensorboard",
207 | help=(
208 | 'The integration to report the results and logs to. Supported platforms are `"tensorboard"`,'
209 |             ' `"wandb"` and `"comet_ml"`. Use `"all"` to report to all integrations.'
210 |             " Only applicable when `--with_tracking` is passed."
211 | ),
212 | )
213 | parser.add_argument("--local_rank", type=int, default=-1, help="For distributed training: local_rank")
214 |
215 | #args = parser.parse_args()
216 | return parser
217 | env_local_rank = int(os.environ.get("LOCAL_RANK", -1))
218 | if env_local_rank != -1 and env_local_rank != args.local_rank:
219 | args.local_rank = env_local_rank
220 |
221 | # Sanity checks
222 | if args.dataset_name is None and args.train_data_dir is None:
223 | raise ValueError("Need either a dataset name or a training folder.")
224 |
225 | return args
226 |
227 |
228 | def parse_parser_add_arg(parser, as_named_tuple = False):
229 | args_df = pd.DataFrame(
230 | pd.Series(parser.__dict__["_actions"]).map(
231 | lambda x:x.__dict__
232 | ).values.tolist())
233 | args_df = args_df.explode("option_strings")
234 | args_df["option_strings"] = args_df["option_strings"].map(
235 | lambda x: x[2:] if x.startswith("--") else x
236 | ).map(
237 | lambda x: x[1:] if x.startswith("-") else x
238 | )
239 | args_df = args_df[["option_strings", "default"]]
240 | args = dict(args_df.values.tolist())
241 | if as_named_tuple:
242 | args_parser_namedtuple = namedtuple("args_config", args)
243 | return args_parser_namedtuple(**args)
244 | return args_df
245 |
246 | def transform_named_tuple_to_dict(N_tuple):
247 | return dict(map(
248 | lambda x: (x, getattr(N_tuple, x))
249 | ,filter(lambda x: not x.startswith("_") ,dir(N_tuple))
250 | ))
251 |
252 | def transform_dict_to_named_tuple(dict_, name = "NamedTuple"):
253 | args_parser_namedtuple = namedtuple(name, dict_)
254 | return args_parser_namedtuple(**dict_)
255 |
256 | def setattr_gen_option_cls(src_obj):
257 | assert isinstance(src_obj, tuple) or isinstance(src_obj, dict)
258 | if isinstance(src_obj, tuple):
259 | src_obj_ = transform_named_tuple_to_dict(src_obj)
260 | else:
261 | src_obj_ = src_obj
262 | assert isinstance(src_obj_, dict)
263 | class Option(object):
264 | pass
265 | option = Option()
266 | for k, v in src_obj_.items():
267 | setattr(option, k, v)
268 | return option
269 |
270 | args = parse_args()
271 | args = parse_parser_add_arg(args, as_named_tuple = True)
272 | '''
273 | export MODEL_NAME="stable-diffusion-v1-4/"
274 | export dataset_name="lambdalabs/pokemon-blip-captions"
275 |
276 | accelerate launch train_text_to_image_ori.py \
277 | --pretrained_model_name_or_path=$MODEL_NAME \
278 | --dataset_name=$dataset_name \
279 | --use_ema \
280 | --resolution=32 --center_crop --random_flip \
281 | --train_batch_size=1 \
282 | --gradient_accumulation_steps=4 \
283 | --gradient_checkpointing \
284 | --max_train_steps=15000 \
285 | --learning_rate=1e-05 \
286 | --max_grad_norm=1 \
287 | --lr_scheduler="constant" --lr_warmup_steps=0 \
288 | --output_dir="sd-pokemon-model"
289 | '''
290 | args_dict = transform_named_tuple_to_dict(args)
291 | #args_dict["pretrained_model_name_or_path"] = "stable-diffusion-v1-4/"
292 | args_dict["pretrained_model_name_or_path"] = "japanese-stable-diffusion/"
293 | args_dict["dataset_name"] = "svjack/pokemon-blip-captions-en-ja"
294 | args_dict["use_ema"] = True
295 | ###args_dict["use_ema"] = False
296 | args_dict["resolution"] = 256
297 | args_dict["center_crop"] = True
298 | args_dict["random_flip"] = True
299 | #args_dict["train_batch_size"] = 1  # superseded by the value of 4 set two lines below
300 | args_dict["gradient_accumulation_steps"] = 4
301 | args_dict["train_batch_size"] = 4
302 | args_dict["gradient_checkpointing"] = True
303 | #### to 15000
304 | args_dict["max_train_steps"] = 50000
305 | args_dict["learning_rate"] = 1e-05
306 | args_dict["max_grad_norm"] = 1
307 | args_dict["lr_scheduler"] = "constant"
308 | args_dict["lr_warmup_steps"] = 0
309 | args_dict["output_dir"] = "sd-pokemon-model"
310 | args_dict["caption_column"] = "ja_text"
311 | args_dict["mixed_precision"] = "no"
312 | args = transform_dict_to_named_tuple(args_dict)
313 |
314 | def get_full_repo_name(model_id: str, organization: Optional[str] = None, token: Optional[str] = None):
315 | if token is None:
316 | token = HfFolder.get_token()
317 | if organization is None:
318 | username = whoami(token)["name"]
319 | return f"{username}/{model_id}"
320 | else:
321 | return f"{organization}/{model_id}"
322 |
323 | dataset_name_mapping = {
324 | "svjack/pokemon-blip-captions-en-ja": ("image", "ja_text"),
325 | }
326 |
327 |
328 | class EMAModel:
329 | """
330 |     Exponential Moving Average of model weights
331 | """
332 |
333 | def __init__(self, parameters: Iterable[torch.nn.Parameter], decay=0.9999):
334 | parameters = list(parameters)
335 | self.shadow_params = [p.clone().detach() for p in parameters]
336 |
337 | self.decay = decay
338 | self.optimization_step = 0
339 |
340 | def get_decay(self, optimization_step):
341 | """
342 | Compute the decay factor for the exponential moving average.
343 | """
344 | value = (1 + optimization_step) / (10 + optimization_step)
345 | return 1 - min(self.decay, value)
346 |
347 | @torch.no_grad()
348 | def step(self, parameters):
349 | parameters = list(parameters)
350 |
351 | self.optimization_step += 1
352 | self.decay = self.get_decay(self.optimization_step)
353 |
354 | for s_param, param in zip(self.shadow_params, parameters):
355 | if param.requires_grad:
356 | tmp = self.decay * (s_param - param)
357 | s_param.sub_(tmp)
358 | else:
359 | s_param.copy_(param)
360 |
361 | torch.cuda.empty_cache()
362 |
363 | def copy_to(self, parameters: Iterable[torch.nn.Parameter]) -> None:
364 | """
365 | Copy current averaged parameters into given collection of parameters.
366 |
367 | Args:
368 | parameters: Iterable of `torch.nn.Parameter`; the parameters to be
369 | updated with the stored moving averages. If `None`, the
370 | parameters with which this `ExponentialMovingAverage` was
371 | initialized will be used.
372 | """
373 | parameters = list(parameters)
374 | for s_param, param in zip(self.shadow_params, parameters):
375 | param.data.copy_(s_param.data)
376 |
377 | def to(self, device=None, dtype=None) -> None:
378 | r"""Move internal buffers of the ExponentialMovingAverage to `device`.
379 |
380 | Args:
381 | device: like `device` argument to `torch.Tensor.to`
382 | """
383 | # .to() on the tensors handles None correctly
384 | self.shadow_params = [
385 | p.to(device=device, dtype=dtype) if p.is_floating_point() else p.to(device=device)
386 | for p in self.shadow_params
387 | ]
388 |
389 | #### model prepare
390 | from japanese_stable_diffusion import JapaneseStableDiffusionPipeline
391 | import torch
392 | import pandas as pd
393 |
394 | from torch import autocast
395 | from diffusers import LMSDiscreteScheduler
396 | scheduler = LMSDiscreteScheduler(beta_start=0.00085, beta_end=0.012,
397 | beta_schedule="scaled_linear", num_train_timesteps=1000)
398 |
399 | # pretrained_model_name_or_path = "jap_to_zh_35000"
400 | pretrained_model_name_or_path = "japanese-stable-diffusion/"
401 | pipe = JapaneseStableDiffusionPipeline.from_pretrained(pretrained_model_name_or_path,
402 | scheduler=scheduler, use_auth_token=True)
403 | tokenizer, text_encoder, vae, unet = pipe.tokenizer, pipe.text_encoder, pipe.vae, pipe.unet
404 |
405 | vae.requires_grad_(False)
406 | text_encoder.requires_grad_(False)
407 |
408 | if args.gradient_checkpointing:
409 | unet.enable_gradient_checkpointing()
410 |
411 | if args.scale_lr:
412 | args.learning_rate = (
413 | args.learning_rate * args.gradient_accumulation_steps * args.train_batch_size * accelerator.num_processes
414 | )
415 |
416 | # Initialize the optimizer
417 | if args.use_8bit_adam:
418 | try:
419 | import bitsandbytes as bnb
420 | except ImportError:
421 | raise ImportError(
422 | "Please install bitsandbytes to use 8-bit Adam. You can do so by running `pip install bitsandbytes`"
423 | )
424 |
425 | optimizer_cls = bnb.optim.AdamW8bit
426 | else:
427 | optimizer_cls = torch.optim.AdamW
428 |
429 | optimizer = optimizer_cls(
430 | unet.parameters(),
431 | lr=args.learning_rate,
432 | betas=(args.adam_beta1, args.adam_beta2),
433 | weight_decay=args.adam_weight_decay,
434 | eps=args.adam_epsilon,
435 | )
436 | noise_scheduler = DDPMScheduler(
437 | beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", num_train_timesteps=1000,
438 | #tensor_format="pt"
439 | )
440 |
441 |
442 | dataset = load_dataset(
443 | args.dataset_name,
444 | args.dataset_config_name,
445 | cache_dir=args.cache_dir,
446 | )
447 | column_names = dataset["train"].column_names
448 | dataset_columns = dataset_name_mapping.get(args.dataset_name, None)
449 | if args.image_column is None:
450 | image_column = dataset_columns[0] if dataset_columns is not None else column_names[0]
451 | else:
452 | image_column = args.image_column
453 | if image_column not in column_names:
454 | raise ValueError(
455 |         f"'--image_column' value '{args.image_column}' needs to be one of: {', '.join(column_names)}"
456 | )
457 | if args.caption_column is None:
458 | caption_column = dataset_columns[1] if dataset_columns is not None else column_names[1]
459 | else:
460 | caption_column = args.caption_column
461 | if caption_column not in column_names:
462 | raise ValueError(
463 |         f"'--caption_column' value '{args.caption_column}' needs to be one of: {', '.join(column_names)}"
464 | )
465 |
466 |
467 | # Preprocessing the datasets.
468 | # We need to tokenize input captions and transform the images.
469 | def tokenize_captions(examples, is_train=True):
470 | captions = []
471 | for caption in examples[caption_column]:
472 | if isinstance(caption, str):
473 | captions.append(caption)
474 | elif isinstance(caption, (list, np.ndarray)):
475 | # take a random caption if there are multiple
476 | captions.append(random.choice(caption) if is_train else caption[0])
477 | else:
478 | raise ValueError(
479 | f"Caption column `{caption_column}` should contain either strings or lists of strings."
480 | )
481 | inputs = tokenizer(captions, max_length=tokenizer.model_max_length, padding="do_not_pad", truncation=True)
482 | input_ids = inputs.input_ids
483 | return input_ids
484 |
485 |
486 | train_transforms = transforms.Compose(
487 | [
488 | transforms.Resize((args.resolution, args.resolution), interpolation=transforms.InterpolationMode.BILINEAR),
489 | transforms.CenterCrop(args.resolution) if args.center_crop else transforms.RandomCrop(args.resolution),
490 | transforms.RandomHorizontalFlip() if args.random_flip else transforms.Lambda(lambda x: x),
491 | transforms.ToTensor(),
492 | transforms.Normalize([0.5], [0.5]),
493 | ]
494 | )
495 |
496 | def preprocess_train(examples):
497 | images = [image.convert("RGB") for image in examples[image_column]]
498 | examples["pixel_values"] = [train_transforms(image) for image in images]
499 | examples["input_ids"] = tokenize_captions(examples)
500 | return examples
501 |
502 | if args.max_train_samples is not None:
503 | dataset["train"] = dataset["train"].shuffle(seed=args.seed).select(range(args.max_train_samples))
504 | train_dataset = dataset["train"].with_transform(preprocess_train)
505 |
506 | def collate_fn(examples):
507 | pixel_values = torch.stack([example["pixel_values"] for example in examples])
508 | pixel_values = pixel_values.to(memory_format=torch.contiguous_format).float()
509 | input_ids = [example["input_ids"] for example in examples]
510 | padded_tokens = tokenizer.pad({"input_ids": input_ids}, padding=True, return_tensors="pt")
511 | return {
512 | "pixel_values": pixel_values,
513 | "input_ids": padded_tokens.input_ids,
514 | "attention_mask": padded_tokens.attention_mask,
515 | }
516 |
517 | train_dataloader = torch.utils.data.DataLoader(
518 | train_dataset, shuffle=True, collate_fn=collate_fn, batch_size=args.train_batch_size
519 | )
520 |
521 |
522 | overrode_max_train_steps = False
523 | num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
524 | if args.max_train_steps is None:
525 | args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
526 | overrode_max_train_steps = True
527 |
528 | lr_scheduler = get_scheduler(
529 | args.lr_scheduler,
530 | optimizer=optimizer,
531 | num_warmup_steps=args.lr_warmup_steps * args.gradient_accumulation_steps,
532 | num_training_steps=args.max_train_steps * args.gradient_accumulation_steps,
533 | )
534 |
535 | def train_loop(config, model, noise_scheduler, optimizer, train_dataloader, lr_scheduler):
536 | # Initialize accelerator and tensorboard logging
537 | args = config
538 | unet = model
539 |
540 | logging_dir = os.path.join(args.output_dir, args.logging_dir)
541 | accelerator = Accelerator(
542 | gradient_accumulation_steps=args.gradient_accumulation_steps,
543 | mixed_precision=args.mixed_precision,
544 | log_with=args.report_to,
545 | logging_dir=logging_dir,
546 | )
547 | logging.basicConfig(
548 | format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
549 | datefmt="%m/%d/%Y %H:%M:%S",
550 | level=logging.INFO,
551 | )
552 | if args.seed is not None:
553 | set_seed(args.seed)
554 | if accelerator.is_main_process:
555 | os.makedirs(args.output_dir, exist_ok=True)
556 |
557 | weight_dtype = torch.float32
558 | if args.mixed_precision == "fp16":
559 | weight_dtype = torch.float16
560 | elif args.mixed_precision == "bf16":
561 | weight_dtype = torch.bfloat16
562 |
563 | text_encoder.to(accelerator.device, dtype=weight_dtype)
564 | vae.to(accelerator.device, dtype=weight_dtype)
565 |
566 | if args.use_ema:
567 | ema_unet = EMAModel(unet.parameters())
568 | ema_unet.to(accelerator.device, dtype=weight_dtype)
569 |
570 | num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
571 | if overrode_max_train_steps:
572 | args_dict = transform_named_tuple_to_dict(args)
573 | args_dict["max_train_steps"] = args.num_train_epochs * num_update_steps_per_epoch
574 | args = transform_dict_to_named_tuple(args_dict)
575 |
576 |
577 | args_dict = transform_named_tuple_to_dict(args)
578 | args_dict["num_train_epochs"] = math.ceil(args.max_train_steps / num_update_steps_per_epoch)
579 | args = transform_dict_to_named_tuple(args_dict)
580 |
581 | if accelerator.is_main_process:
582 | if config.push_to_hub:
583 | #repo = init_git_repo(config, at_init=True)
584 | pass
585 | accelerator.init_trackers("train_example")
586 |
587 | total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps
588 |
589 | logger.info("***** Running training *****")
590 | logger.info(f" Num examples = {len(train_dataset)}")
591 | logger.info(f" Num Epochs = {args.num_train_epochs}")
592 | logger.info(f" Instantaneous batch size per device = {args.train_batch_size}")
593 | logger.info(f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}")
594 | logger.info(f" Gradient Accumulation steps = {args.gradient_accumulation_steps}")
595 | logger.info(f" Total optimization steps = {args.max_train_steps}")
596 |
597 | # Prepare everything
598 | # There is no specific order to remember, you just need to unpack the
599 | # objects in the same order you gave them to the prepare method.
600 | '''
601 | model, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
602 | model, optimizer, train_dataloader, lr_scheduler
603 | )
604 | '''
605 | unet, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
606 | unet, optimizer, train_dataloader, lr_scheduler
607 | )
608 | progress_bar = tqdm(range(args.max_train_steps), disable=not accelerator.is_local_main_process)
609 | progress_bar.set_description("Steps")
610 | global_step = 0
611 |
612 | for epoch in range(args.num_train_epochs):
613 | #unet.train()
614 | train_loss = 0.0
615 | for step, batch in enumerate(train_dataloader):
616 | with accelerator.accumulate(unet):
617 | # Convert images to latent space
618 | latents = vae.encode(batch["pixel_values"].to(weight_dtype)).latent_dist.sample()
619 | latents = latents * 0.18215
620 |
621 | # Sample noise that we'll add to the latents
622 | noise = torch.randn_like(latents)
623 | bsz = latents.shape[0]
624 | # Sample a random timestep for each image
625 | timesteps = torch.randint(0, noise_scheduler.num_train_timesteps, (bsz,), device=latents.device)
626 | timesteps = timesteps.long()
627 |
628 | # Add noise to the latents according to the noise magnitude at each timestep
629 | # (this is the forward diffusion process)
630 | noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)
631 |
632 | # Get the text embedding for conditioning
633 | encoder_hidden_states = text_encoder(batch["input_ids"])[0]
634 |
635 | # Predict the noise residual and compute loss
636 | noise_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample
637 |
638 | loss = F.mse_loss(noise_pred.float(), noise.float(), reduction="mean")
639 | #loss = F.mse_loss(noise.float(), noise.float(), reduction="mean")
640 |
641 | # Gather the losses across all processes for logging (if we use distributed training).
642 | avg_loss = accelerator.gather(loss.repeat(args.train_batch_size)).mean()
643 | train_loss += avg_loss.item() / args.gradient_accumulation_steps
644 |
645 | # Backpropagate
646 | accelerator.backward(loss)
647 | if accelerator.sync_gradients:
648 | accelerator.clip_grad_norm_(unet.parameters(), args.max_grad_norm)
649 | optimizer.step()
650 | lr_scheduler.step()
651 | optimizer.zero_grad()
652 |
653 | # Checks if the accelerator has performed an optimization step behind the scenes
654 | if accelerator.sync_gradients:
655 | if args.use_ema:
656 | ema_unet.step(unet.parameters())
657 | progress_bar.update(1)
658 | global_step += 1
659 | accelerator.log({"train_loss": train_loss}, step=global_step)
660 | train_loss = 0.0
661 |
662 | logs = {"step_loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0]}
663 | progress_bar.set_postfix(**logs)
664 |
665 | if global_step >= args.max_train_steps:
666 | break
667 |
668 | from accelerate import notebook_launcher
669 |
670 | #### train it.
671 | args_ = (args, unet, noise_scheduler, optimizer, train_dataloader, lr_scheduler)
672 | notebook_launcher(train_loop, args_, num_processes=1)
673 |
674 | #### save to local
675 | pipeline = StableDiffusionPipeline(
676 | text_encoder=text_encoder,
677 | vae=vae,
678 | unet=unet,
679 | tokenizer=tokenizer,
680 | scheduler=scheduler,
681 | safety_checker=StableDiffusionSafetyChecker.from_pretrained("CompVis/stable-diffusion-safety-checker"),
682 | feature_extractor=CLIPFeatureExtractor.from_pretrained("openai/clip-vit-base-patch32"),
683 | )
684 | pipeline.save_pretrained("jap_model_50000")
685 | pipeline.safety_checker = lambda images, clip_input: (images, False)
686 | pipeline = pipeline.to("cuda")
687 |
688 | imgs = pipeline("鉢植えの植物を頭に載せた漫画のキャラクター",  # "a cartoon character with a potted plant on its head"
689 | num_inference_steps = 100
690 | )
691 | imgs.images[0]
692 |
693 | imgs = pipeline("漫画の鳥",  # "cartoon bird"
694 | num_inference_steps = 100
695 | )
696 | imgs.images[0]
697 |
698 | imgs = pipeline("黄色いボール",  # "yellow ball"
699 | num_inference_steps = 100
700 | )
701 | imgs.images[0]
702 |
703 | imgs = pipeline("ブルードラゴンのイラスト",  # "blue dragon illustration"
704 | num_inference_steps = 100
705 | )
706 | imgs.images[0]
707 |
--------------------------------------------------------------------------------
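
Note: when use_ema is set, train_ja_model.py keeps an exponential moving average of the UNet weights inside train_loop, but the StableDiffusionPipeline saved afterwards is assembled from the raw (non-averaged) unet. If the averaged weights are the ones to ship, the EMAModel.copy_to helper defined in the script can write them back into the UNet first. The sketch below assumes ema_unet has been made visible outside train_loop (for example via a global statement inside that function):

# assumption: ema_unet was exposed outside train_loop (e.g. `global ema_unet`)
if args.use_ema:
    ema_unet.copy_to(unet.parameters())  # overwrite the UNet weights with the EMA shadow params
# then build and save the StableDiffusionPipeline exactly as in the script above
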
/train_en_model.py:
--------------------------------------------------------------------------------
1 | import pandas as pd
2 | from collections import namedtuple
3 |
4 | import os
5 | os.environ['KMP_DUPLICATE_LIB_OK']='True'
6 |
7 | import argparse
8 | import logging
9 | import math
10 | import os
11 | import random
12 | from pathlib import Path
13 | from typing import Iterable, Optional
14 |
15 | import numpy as np
16 | import torch
17 | import torch.nn.functional as F
18 | import torch.utils.checkpoint
19 |
20 | from accelerate import Accelerator
21 | from accelerate.logging import get_logger
22 | from accelerate.utils import set_seed
23 | from datasets import load_dataset
24 | from diffusers import AutoencoderKL, DDPMScheduler, PNDMScheduler, StableDiffusionPipeline, UNet2DConditionModel
25 | from diffusers.optimization import get_scheduler
26 | from diffusers.pipelines.stable_diffusion import StableDiffusionSafetyChecker
27 | from huggingface_hub import HfFolder, Repository, whoami
28 | from torchvision import transforms
29 | from tqdm.auto import tqdm
30 | from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer
31 |
32 |
33 | logger = get_logger(__name__)
34 |
35 |
36 | def parse_args():
37 | parser = argparse.ArgumentParser(description="Simple example of a training script.")
38 | parser.add_argument(
39 | "--pretrained_model_name_or_path",
40 | type=str,
41 | default=None,
42 | required=True,
43 | help="Path to pretrained model or model identifier from huggingface.co/models.",
44 | )
45 | parser.add_argument(
46 | "--dataset_name",
47 | type=str,
48 | default=None,
49 | help=(
50 | "The name of the Dataset (from the HuggingFace hub) to train on (could be your own, possibly private,"
51 | " dataset). It can also be a path pointing to a local copy of a dataset in your filesystem,"
52 | " or to a folder containing files that 🤗 Datasets can understand."
53 | ),
54 | )
55 | parser.add_argument(
56 | "--dataset_config_name",
57 | type=str,
58 | default=None,
59 | help="The config of the Dataset, leave as None if there's only one config.",
60 | )
61 | parser.add_argument(
62 | "--train_data_dir",
63 | type=str,
64 | default=None,
65 | help=(
66 | "A folder containing the training data. Folder contents must follow the structure described in"
67 | " https://huggingface.co/docs/datasets/image_dataset#imagefolder. In particular, a `metadata.jsonl` file"
68 | " must exist to provide the captions for the images. Ignored if `dataset_name` is specified."
69 | ),
70 | )
71 | parser.add_argument(
72 | "--image_column", type=str, default="image", help="The column of the dataset containing an image."
73 | )
74 | parser.add_argument(
75 | "--caption_column",
76 | type=str,
77 | default="text",
78 | help="The column of the dataset containing a caption or a list of captions.",
79 | )
80 | parser.add_argument(
81 | "--max_train_samples",
82 | type=int,
83 | default=None,
84 | help=(
85 | "For debugging purposes or quicker training, truncate the number of training examples to this "
86 | "value if set."
87 | ),
88 | )
89 | parser.add_argument(
90 | "--output_dir",
91 | type=str,
92 | default="sd-model-finetuned",
93 | help="The output directory where the model predictions and checkpoints will be written.",
94 | )
95 | parser.add_argument(
96 | "--cache_dir",
97 | type=str,
98 | default=None,
99 | help="The directory where the downloaded models and datasets will be stored.",
100 | )
101 | parser.add_argument("--seed", type=int, default=None, help="A seed for reproducible training.")
102 | parser.add_argument(
103 | "--resolution",
104 | type=int,
105 | default=512,
106 | help=(
107 |             "The resolution for input images; all the images in the train/validation dataset will be resized to this"
108 |             " resolution."
109 | ),
110 | )
111 | parser.add_argument(
112 | "--center_crop",
113 | action="store_true",
114 | help="Whether to center crop images before resizing to resolution (if not set, random crop will be used)",
115 | )
116 | parser.add_argument(
117 | "--random_flip",
118 | action="store_true",
119 | help="whether to randomly flip images horizontally",
120 | )
121 | parser.add_argument(
122 | "--train_batch_size", type=int, default=16, help="Batch size (per device) for the training dataloader."
123 | )
124 | parser.add_argument("--num_train_epochs", type=int, default=100)
125 | parser.add_argument(
126 | "--max_train_steps",
127 | type=int,
128 | default=None,
129 | help="Total number of training steps to perform. If provided, overrides num_train_epochs.",
130 | )
131 | parser.add_argument(
132 | "--gradient_accumulation_steps",
133 | type=int,
134 | default=1,
135 |         help="Number of update steps to accumulate before performing a backward/update pass.",
136 | )
137 | parser.add_argument(
138 | "--gradient_checkpointing",
139 | action="store_true",
140 | help="Whether or not to use gradient checkpointing to save memory at the expense of slower backward pass.",
141 | )
142 | parser.add_argument(
143 | "--learning_rate",
144 | type=float,
145 | default=1e-4,
146 | help="Initial learning rate (after the potential warmup period) to use.",
147 | )
148 | parser.add_argument(
149 | "--scale_lr",
150 | action="store_true",
151 | default=False,
152 | help="Scale the learning rate by the number of GPUs, gradient accumulation steps, and batch size.",
153 | )
154 | parser.add_argument(
155 | "--lr_scheduler",
156 | type=str,
157 | default="constant",
158 | help=(
159 | 'The scheduler type to use. Choose between ["linear", "cosine", "cosine_with_restarts", "polynomial",'
160 | ' "constant", "constant_with_warmup"]'
161 | ),
162 | )
163 | parser.add_argument(
164 | "--lr_warmup_steps", type=int, default=500, help="Number of steps for the warmup in the lr scheduler."
165 | )
166 | parser.add_argument(
167 | "--use_8bit_adam", action="store_true", help="Whether or not to use 8-bit Adam from bitsandbytes."
168 | )
169 | parser.add_argument("--use_ema", action="store_true", help="Whether to use EMA model.")
170 | parser.add_argument("--adam_beta1", type=float, default=0.9, help="The beta1 parameter for the Adam optimizer.")
171 | parser.add_argument("--adam_beta2", type=float, default=0.999, help="The beta2 parameter for the Adam optimizer.")
172 | parser.add_argument("--adam_weight_decay", type=float, default=1e-2, help="Weight decay to use.")
173 | parser.add_argument("--adam_epsilon", type=float, default=1e-08, help="Epsilon value for the Adam optimizer")
174 | parser.add_argument("--max_grad_norm", default=1.0, type=float, help="Max gradient norm.")
175 | parser.add_argument("--push_to_hub", action="store_true", help="Whether or not to push the model to the Hub.")
176 | parser.add_argument("--hub_token", type=str, default=None, help="The token to use to push to the Model Hub.")
177 | parser.add_argument(
178 | "--hub_model_id",
179 | type=str,
180 | default=None,
181 | help="The name of the repository to keep in sync with the local `output_dir`.",
182 | )
183 | parser.add_argument(
184 | "--logging_dir",
185 | type=str,
186 | default="logs",
187 | help=(
188 | "[TensorBoard](https://www.tensorflow.org/tensorboard) log directory. Will default to"
189 | " *output_dir/runs/**CURRENT_DATETIME_HOSTNAME***."
190 | ),
191 | )
192 | parser.add_argument(
193 | "--mixed_precision",
194 | type=str,
195 | default="no",
196 | choices=["no", "fp16", "bf16"],
197 | help=(
198 |             "Whether to use mixed precision. Choose"
199 |             " between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >= 1.10"
200 |             " and an Nvidia Ampere GPU."
201 | ),
202 | )
203 | parser.add_argument(
204 | "--report_to",
205 | type=str,
206 | default="tensorboard",
207 | help=(
208 | 'The integration to report the results and logs to. Supported platforms are `"tensorboard"`,'
209 |             ' `"wandb"` and `"comet_ml"`. Use `"all"` to report to all integrations.'
210 |             " Only applicable when `--with_tracking` is passed."
211 | ),
212 | )
213 | parser.add_argument("--local_rank", type=int, default=-1, help="For distributed training: local_rank")
214 |
215 | #args = parser.parse_args()
216 | return parser
217 | env_local_rank = int(os.environ.get("LOCAL_RANK", -1))
218 | if env_local_rank != -1 and env_local_rank != args.local_rank:
219 | args.local_rank = env_local_rank
220 |
221 | # Sanity checks
222 | if args.dataset_name is None and args.train_data_dir is None:
223 | raise ValueError("Need either a dataset name or a training folder.")
224 |
225 | return args
226 |
227 |
228 | def parse_parser_add_arg(parser, as_named_tuple = False):
229 | args_df = pd.DataFrame(
230 | pd.Series(parser.__dict__["_actions"]).map(
231 | lambda x:x.__dict__
232 | ).values.tolist())
233 | args_df = args_df.explode("option_strings")
234 | args_df["option_strings"] = args_df["option_strings"].map(
235 | lambda x: x[2:] if x.startswith("--") else x
236 | ).map(
237 | lambda x: x[1:] if x.startswith("-") else x
238 | )
239 | args_df = args_df[["option_strings", "default"]]
240 | args = dict(args_df.values.tolist())
241 | if as_named_tuple:
242 | args_parser_namedtuple = namedtuple("args_config", args)
243 | return args_parser_namedtuple(**args)
244 | return args_df
245 |
246 | def transform_named_tuple_to_dict(N_tuple):
247 | return dict(map(
248 | lambda x: (x, getattr(N_tuple, x))
249 | ,filter(lambda x: not x.startswith("_") ,dir(N_tuple))
250 | ))
251 |
252 | def transform_dict_to_named_tuple(dict_, name = "NamedTuple"):
253 | args_parser_namedtuple = namedtuple(name, dict_)
254 | return args_parser_namedtuple(**dict_)
255 |
256 | def setattr_gen_option_cls(src_obj):
257 | assert isinstance(src_obj, tuple) or isinstance(src_obj, dict)
258 | if isinstance(src_obj, tuple):
259 | src_obj_ = transform_named_tuple_to_dict(src_obj)
260 | else:
261 | src_obj_ = src_obj
262 | assert isinstance(src_obj_, dict)
263 | class Option(object):
264 | pass
265 | option = Option()
266 | for k, v in src_obj_.items():
267 | setattr(option, k, v)
268 | return option
269 |
270 | args = parse_args()
271 | args = parse_parser_add_arg(args, as_named_tuple = True)
272 | '''
273 | export MODEL_NAME="stable-diffusion-v1-4/"
274 | export dataset_name="lambdalabs/pokemon-blip-captions"
275 |
276 | accelerate launch train_text_to_image_ori.py \
277 | --pretrained_model_name_or_path=$MODEL_NAME \
278 | --dataset_name=$dataset_name \
279 | --use_ema \
280 | --resolution=32 --center_crop --random_flip \
281 | --train_batch_size=1 \
282 | --gradient_accumulation_steps=4 \
283 | --gradient_checkpointing \
284 | --max_train_steps=15000 \
285 | --learning_rate=1e-05 \
286 | --max_grad_norm=1 \
287 | --lr_scheduler="constant" --lr_warmup_steps=0 \
288 | --output_dir="sd-pokemon-model"
289 | '''
290 | args_dict = transform_named_tuple_to_dict(args)
291 | args_dict["pretrained_model_name_or_path"] = "stable-diffusion-v1-4/"
292 | #args_dict["pretrained_model_name_or_path"] = "japanese-stable-diffusion/"
293 | args_dict["dataset_name"] = "svjack/pokemon-blip-captions-en-ja"
294 | args_dict["use_ema"] = True
295 | ###args_dict["use_ema"] = False
296 | args_dict["resolution"] = 256
297 | args_dict["center_crop"] = True
298 | args_dict["random_flip"] = True
299 | #args_dict["train_batch_size"] = 1  # superseded by the value of 4 set two lines below
300 | args_dict["gradient_accumulation_steps"] = 4
301 | args_dict["train_batch_size"] = 4
302 | args_dict["gradient_checkpointing"] = True
303 | #### to 15000
304 | args_dict["max_train_steps"] = 50000
305 | args_dict["learning_rate"] = 1e-05
306 | args_dict["max_grad_norm"] = 1
307 | args_dict["lr_scheduler"] = "constant"
308 | args_dict["lr_warmup_steps"] = 0
309 | args_dict["output_dir"] = "sd-pokemon-model"
310 | args_dict["caption_column"] = "en_text"
311 | args_dict["mixed_precision"] = "no"
312 | args = transform_dict_to_named_tuple(args_dict)
313 |
314 | def get_full_repo_name(model_id: str, organization: Optional[str] = None, token: Optional[str] = None):
315 | if token is None:
316 | token = HfFolder.get_token()
317 | if organization is None:
318 | username = whoami(token)["name"]
319 | return f"{username}/{model_id}"
320 | else:
321 | return f"{organization}/{model_id}"
322 |
323 | dataset_name_mapping = {
324 | "svjack/pokemon-blip-captions-en-ja": ("image", "en_text"),
325 | }
326 |
327 |
328 | class EMAModel:
329 | """
330 |     Exponential Moving Average of model weights
331 | """
332 |
333 | def __init__(self, parameters: Iterable[torch.nn.Parameter], decay=0.9999):
334 | parameters = list(parameters)
335 | self.shadow_params = [p.clone().detach() for p in parameters]
336 |
337 | self.decay = decay
338 | self.optimization_step = 0
339 |
340 | def get_decay(self, optimization_step):
341 | """
342 | Compute the decay factor for the exponential moving average.
343 | """
344 | value = (1 + optimization_step) / (10 + optimization_step)
345 | return 1 - min(self.decay, value)
346 |
347 | @torch.no_grad()
348 | def step(self, parameters):
349 | parameters = list(parameters)
350 |
351 | self.optimization_step += 1
352 | self.decay = self.get_decay(self.optimization_step)
353 |
354 | for s_param, param in zip(self.shadow_params, parameters):
355 | if param.requires_grad:
356 | tmp = self.decay * (s_param - param)
357 | s_param.sub_(tmp)
358 | else:
359 | s_param.copy_(param)
360 |
361 | torch.cuda.empty_cache()
362 |
363 | def copy_to(self, parameters: Iterable[torch.nn.Parameter]) -> None:
364 | """
365 | Copy current averaged parameters into given collection of parameters.
366 |
367 | Args:
368 | parameters: Iterable of `torch.nn.Parameter`; the parameters to be
369 | updated with the stored moving averages. If `None`, the
370 | parameters with which this `ExponentialMovingAverage` was
371 | initialized will be used.
372 | """
373 | parameters = list(parameters)
374 | for s_param, param in zip(self.shadow_params, parameters):
375 | param.data.copy_(s_param.data)
376 |
377 | def to(self, device=None, dtype=None) -> None:
378 | r"""Move internal buffers of the ExponentialMovingAverage to `device`.
379 |
380 | Args:
381 | device: like `device` argument to `torch.Tensor.to`
382 | """
383 | # .to() on the tensors handles None correctly
384 | self.shadow_params = [
385 | p.to(device=device, dtype=dtype) if p.is_floating_point() else p.to(device=device)
386 | for p in self.shadow_params
387 | ]
388 |
389 | #### model prepare
390 | '''
391 | from japanese_stable_diffusion import JapaneseStableDiffusionPipeline
392 | import torch
393 | import pandas as pd
394 |
395 | from torch import autocast
396 | from diffusers import LMSDiscreteScheduler
397 | scheduler = LMSDiscreteScheduler(beta_start=0.00085, beta_end=0.012,
398 | beta_schedule="scaled_linear", num_train_timesteps=1000)
399 |
400 | # pretrained_model_name_or_path = "jap_to_zh_35000"
401 | pretrained_model_name_or_path = "japanese-stable-diffusion/"
402 | pipe = JapaneseStableDiffusionPipeline.from_pretrained(pretrained_model_name_or_path,
403 | scheduler=scheduler, use_auth_token=True)
404 | tokenizer, text_encoder, vae, unet = pipe.tokenizer, pipe.text_encoder, pipe.vae, pipe.unet
405 | '''
406 | tokenizer = CLIPTokenizer.from_pretrained(args.pretrained_model_name_or_path, subfolder="tokenizer")
407 | text_encoder = CLIPTextModel.from_pretrained(args.pretrained_model_name_or_path, subfolder="text_encoder")
408 | vae = AutoencoderKL.from_pretrained(args.pretrained_model_name_or_path, subfolder="vae")
409 | unet = UNet2DConditionModel.from_pretrained(args.pretrained_model_name_or_path, subfolder="unet")
410 |
411 | vae.requires_grad_(False)
412 | text_encoder.requires_grad_(False)
413 |
414 | if args.gradient_checkpointing:
415 | unet.enable_gradient_checkpointing()
416 |
417 | if args.scale_lr:
418 | args.learning_rate = (
419 | args.learning_rate * args.gradient_accumulation_steps * args.train_batch_size * accelerator.num_processes
420 | )
421 |
422 | # Initialize the optimizer
423 | if args.use_8bit_adam:
424 | try:
425 | import bitsandbytes as bnb
426 | except ImportError:
427 | raise ImportError(
428 | "Please install bitsandbytes to use 8-bit Adam. You can do so by running `pip install bitsandbytes`"
429 | )
430 |
431 | optimizer_cls = bnb.optim.AdamW8bit
432 | else:
433 | optimizer_cls = torch.optim.AdamW
434 |
435 | optimizer = optimizer_cls(
436 | unet.parameters(),
437 | lr=args.learning_rate,
438 | betas=(args.adam_beta1, args.adam_beta2),
439 | weight_decay=args.adam_weight_decay,
440 | eps=args.adam_epsilon,
441 | )
442 | noise_scheduler = DDPMScheduler(
443 | beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", num_train_timesteps=1000,
444 | #tensor_format="pt"
445 | )
446 |
447 |
448 | dataset = load_dataset(
449 | args.dataset_name,
450 | args.dataset_config_name,
451 | cache_dir=args.cache_dir,
452 | )
453 | column_names = dataset["train"].column_names
454 | dataset_columns = dataset_name_mapping.get(args.dataset_name, None)
455 | if args.image_column is None:
456 | image_column = dataset_columns[0] if dataset_columns is not None else column_names[0]
457 | else:
458 | image_column = args.image_column
459 | if image_column not in column_names:
460 | raise ValueError(
461 |         f"'--image_column' value '{args.image_column}' needs to be one of: {', '.join(column_names)}"
462 | )
463 | if args.caption_column is None:
464 | caption_column = dataset_columns[1] if dataset_columns is not None else column_names[1]
465 | else:
466 | caption_column = args.caption_column
467 | if caption_column not in column_names:
468 | raise ValueError(
469 |         f"'--caption_column' value '{args.caption_column}' needs to be one of: {', '.join(column_names)}"
470 | )
471 |
472 |
473 | # Preprocessing the datasets.
474 | # We need to tokenize input captions and transform the images.
475 | def tokenize_captions(examples, is_train=True):
476 | captions = []
477 | for caption in examples[caption_column]:
478 | if isinstance(caption, str):
479 | captions.append(caption)
480 | elif isinstance(caption, (list, np.ndarray)):
481 | # take a random caption if there are multiple
482 | captions.append(random.choice(caption) if is_train else caption[0])
483 | else:
484 | raise ValueError(
485 | f"Caption column `{caption_column}` should contain either strings or lists of strings."
486 | )
487 | inputs = tokenizer(captions, max_length=tokenizer.model_max_length, padding="do_not_pad", truncation=True)
488 | input_ids = inputs.input_ids
489 | return input_ids
490 |
491 |
492 | train_transforms = transforms.Compose(
493 | [
494 | transforms.Resize((args.resolution, args.resolution), interpolation=transforms.InterpolationMode.BILINEAR),
495 | transforms.CenterCrop(args.resolution) if args.center_crop else transforms.RandomCrop(args.resolution),
496 | transforms.RandomHorizontalFlip() if args.random_flip else transforms.Lambda(lambda x: x),
497 | transforms.ToTensor(),
498 | transforms.Normalize([0.5], [0.5]),
499 | ]
500 | )
501 |
502 | def preprocess_train(examples):
503 | images = [image.convert("RGB") for image in examples[image_column]]
504 | examples["pixel_values"] = [train_transforms(image) for image in images]
505 | examples["input_ids"] = tokenize_captions(examples)
506 | return examples
507 |
508 | if args.max_train_samples is not None:
509 | dataset["train"] = dataset["train"].shuffle(seed=args.seed).select(range(args.max_train_samples))
510 | train_dataset = dataset["train"].with_transform(preprocess_train)
511 |
512 | def collate_fn(examples):
513 | pixel_values = torch.stack([example["pixel_values"] for example in examples])
514 | pixel_values = pixel_values.to(memory_format=torch.contiguous_format).float()
515 | input_ids = [example["input_ids"] for example in examples]
516 | padded_tokens = tokenizer.pad({"input_ids": input_ids}, padding=True, return_tensors="pt")
517 | return {
518 | "pixel_values": pixel_values,
519 | "input_ids": padded_tokens.input_ids,
520 | "attention_mask": padded_tokens.attention_mask,
521 | }
522 |
523 | train_dataloader = torch.utils.data.DataLoader(
524 | train_dataset, shuffle=True, collate_fn=collate_fn, batch_size=args.train_batch_size
525 | )
526 |
527 |
528 | overrode_max_train_steps = False
529 | num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
530 | if args.max_train_steps is None:
531 | args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
532 | overrode_max_train_steps = True
533 |
534 | lr_scheduler = get_scheduler(
535 | args.lr_scheduler,
536 | optimizer=optimizer,
537 | num_warmup_steps=args.lr_warmup_steps * args.gradient_accumulation_steps,
538 | num_training_steps=args.max_train_steps * args.gradient_accumulation_steps,
539 | )
540 |
541 | def train_loop(config, model, noise_scheduler, optimizer, train_dataloader, lr_scheduler):
542 | # Initialize accelerator and tensorboard logging
543 | args = config
544 | unet = model
545 |
546 | logging_dir = os.path.join(args.output_dir, args.logging_dir)
547 | accelerator = Accelerator(
548 | gradient_accumulation_steps=args.gradient_accumulation_steps,
549 | mixed_precision=args.mixed_precision,
550 | log_with=args.report_to,
551 | logging_dir=logging_dir,
552 | )
553 | logging.basicConfig(
554 | format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
555 | datefmt="%m/%d/%Y %H:%M:%S",
556 | level=logging.INFO,
557 | )
558 | if args.seed is not None:
559 | set_seed(args.seed)
560 | if accelerator.is_main_process:
561 | os.makedirs(args.output_dir, exist_ok=True)
562 |
563 | weight_dtype = torch.float32
564 | if args.mixed_precision == "fp16":
565 | weight_dtype = torch.float16
566 | elif args.mixed_precision == "bf16":
567 | weight_dtype = torch.bfloat16
568 |
569 | text_encoder.to(accelerator.device, dtype=weight_dtype)
570 | vae.to(accelerator.device, dtype=weight_dtype)
571 |
572 | if args.use_ema:
573 | ema_unet = EMAModel(unet.parameters())
574 | ema_unet.to(accelerator.device, dtype=weight_dtype)
575 |
576 | num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
577 | if overrode_max_train_steps:
578 | args_dict = transform_named_tuple_to_dict(args)
579 | args_dict["max_train_steps"] = args.num_train_epochs * num_update_steps_per_epoch
580 | args = transform_dict_to_named_tuple(args_dict)
581 |
582 |
583 | args_dict = transform_named_tuple_to_dict(args)
584 | args_dict["num_train_epochs"] = math.ceil(args.max_train_steps / num_update_steps_per_epoch)
585 | args = transform_dict_to_named_tuple(args_dict)
586 |
587 | if accelerator.is_main_process:
588 | if config.push_to_hub:
589 | #repo = init_git_repo(config, at_init=True)
590 | pass
591 | accelerator.init_trackers("train_example")
592 |
593 | total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps
594 |
595 | logger.info("***** Running training *****")
596 | logger.info(f" Num examples = {len(train_dataset)}")
597 | logger.info(f" Num Epochs = {args.num_train_epochs}")
598 | logger.info(f" Instantaneous batch size per device = {args.train_batch_size}")
599 | logger.info(f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}")
600 | logger.info(f" Gradient Accumulation steps = {args.gradient_accumulation_steps}")
601 | logger.info(f" Total optimization steps = {args.max_train_steps}")
602 |
603 | # Prepare everything
604 | # There is no specific order to remember, you just need to unpack the
605 | # objects in the same order you gave them to the prepare method.
606 | '''
607 | model, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
608 | model, optimizer, train_dataloader, lr_scheduler
609 | )
610 | '''
611 | unet, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
612 | unet, optimizer, train_dataloader, lr_scheduler
613 | )
614 | progress_bar = tqdm(range(args.max_train_steps), disable=not accelerator.is_local_main_process)
615 | progress_bar.set_description("Steps")
616 | global_step = 0
617 |
618 | for epoch in range(args.num_train_epochs):
619 | #unet.train()
620 | train_loss = 0.0
621 | for step, batch in enumerate(train_dataloader):
622 | with accelerator.accumulate(unet):
623 | # Convert images to latent space
624 | latents = vae.encode(batch["pixel_values"].to(weight_dtype)).latent_dist.sample()
625 | latents = latents * 0.18215
626 |
627 | # Sample noise that we'll add to the latents
628 | noise = torch.randn_like(latents)
629 | bsz = latents.shape[0]
630 | # Sample a random timestep for each image
631 | timesteps = torch.randint(0, noise_scheduler.num_train_timesteps, (bsz,), device=latents.device)
632 | timesteps = timesteps.long()
633 |
634 | # Add noise to the latents according to the noise magnitude at each timestep
635 | # (this is the forward diffusion process)
636 | noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)
637 |
638 | # Get the text embedding for conditioning
639 | encoder_hidden_states = text_encoder(batch["input_ids"])[0]
640 |
641 | # Predict the noise residual and compute loss
642 | noise_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample
643 |
644 | loss = F.mse_loss(noise_pred.float(), noise.float(), reduction="mean")
645 | #loss = F.mse_loss(noise.float(), noise.float(), reduction="mean")
646 |
647 | # Gather the losses across all processes for logging (if we use distributed training).
648 | avg_loss = accelerator.gather(loss.repeat(args.train_batch_size)).mean()
649 | train_loss += avg_loss.item() / args.gradient_accumulation_steps
650 |
651 | # Backpropagate
652 | accelerator.backward(loss)
653 | if accelerator.sync_gradients:
654 | accelerator.clip_grad_norm_(unet.parameters(), args.max_grad_norm)
655 | optimizer.step()
656 | lr_scheduler.step()
657 | optimizer.zero_grad()
658 |
659 | # Checks if the accelerator has performed an optimization step behind the scenes
660 | if accelerator.sync_gradients:
661 | if args.use_ema:
662 | ema_unet.step(unet.parameters())
663 | progress_bar.update(1)
664 | global_step += 1
665 | accelerator.log({"train_loss": train_loss}, step=global_step)
666 | train_loss = 0.0
667 |
668 | logs = {"step_loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0]}
669 | progress_bar.set_postfix(**logs)
670 |
671 | if global_step >= args.max_train_steps:
672 | break
673 |
674 | from accelerate import notebook_launcher
675 |
676 | #### train it.
677 | args_ = (args, unet, noise_scheduler, optimizer, train_dataloader, lr_scheduler)
678 | notebook_launcher(train_loop, args_, num_processes=1)
679 |
680 | from diffusers import LMSDiscreteScheduler
681 | scheduler = LMSDiscreteScheduler(beta_start=0.00085, beta_end=0.012,
682 | beta_schedule="scaled_linear", num_train_timesteps=1000)
683 |
684 | #### save to local
685 | pipeline = StableDiffusionPipeline(
686 | text_encoder=text_encoder,
687 | vae=vae,
688 | unet=unet,
689 | tokenizer=tokenizer,
690 | scheduler=scheduler,
691 | safety_checker=StableDiffusionSafetyChecker.from_pretrained("CompVis/stable-diffusion-safety-checker"),
692 | feature_extractor=CLIPFeatureExtractor.from_pretrained("openai/clip-vit-base-patch32"),
693 | )
694 | pipeline.save_pretrained("en_model_26000")
695 | pipeline.safety_checker = lambda images, clip_input: (images, False)
696 | pipeline = pipeline.to("cuda")
697 |
698 |
699 | imgs = pipeline("A cartoon character with a potted plant on his head",
700 | num_inference_steps = 100
701 | )
702 | imgs.images[0]
703 |
704 | imgs = pipeline("cartoon bird",
705 | num_inference_steps = 100
706 | )
707 | imgs.images[0]
708 |
709 | imgs = pipeline("yellow ball",
710 | num_inference_steps = 100
711 | )
712 | imgs.images[0]
713 |
714 | imgs = pipeline("blue dragon illustration",
715 | num_inference_steps = 100
716 | )
717 | imgs.images[0]
718 |
--------------------------------------------------------------------------------
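
Note: both fine-tuned pipelines are written with save_pretrained ("jap_model_50000" and "en_model_26000"), so inference does not require rerunning the training cells. A minimal reload sketch for the English pipeline, assuming a CUDA device and the same diffusers version the scripts were written against:

from diffusers import StableDiffusionPipeline

# reload the locally saved pipeline directory produced by train_en_model.py
pipeline = StableDiffusionPipeline.from_pretrained("en_model_26000")
pipeline.safety_checker = lambda images, clip_input: (images, False)  # same bypass used in the script
pipeline = pipeline.to("cuda")

image = pipeline("cartoon bird", num_inference_steps=100).images[0]
image.save("cartoon_bird.png")
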