├── imgs ├── en_bird.jpg ├── en_plant.jpg ├── ja_bird.jpg ├── ja_plant.jpg ├── zh_bird.jpg ├── zh_plant.jpg ├── en_blue_dragon.jpg ├── ja_blue_dragon.jpg └── zh_blue_dragon.jpg ├── requirements.txt ├── install.sh ├── install.md ├── run_en_model.py ├── run_ja_model.py ├── README.md ├── README_EN.md ├── run_zh_model.py ├── train_zh_model.py ├── train_ja_model.py └── train_en_model.py /imgs/en_bird.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/svjack/Stable-Diffusion-Pokemon/HEAD/imgs/en_bird.jpg -------------------------------------------------------------------------------- /imgs/en_plant.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/svjack/Stable-Diffusion-Pokemon/HEAD/imgs/en_plant.jpg -------------------------------------------------------------------------------- /imgs/ja_bird.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/svjack/Stable-Diffusion-Pokemon/HEAD/imgs/ja_bird.jpg -------------------------------------------------------------------------------- /imgs/ja_plant.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/svjack/Stable-Diffusion-Pokemon/HEAD/imgs/ja_plant.jpg -------------------------------------------------------------------------------- /imgs/zh_bird.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/svjack/Stable-Diffusion-Pokemon/HEAD/imgs/zh_bird.jpg -------------------------------------------------------------------------------- /imgs/zh_plant.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/svjack/Stable-Diffusion-Pokemon/HEAD/imgs/zh_plant.jpg -------------------------------------------------------------------------------- /imgs/en_blue_dragon.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/svjack/Stable-Diffusion-Pokemon/HEAD/imgs/en_blue_dragon.jpg -------------------------------------------------------------------------------- /imgs/ja_blue_dragon.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/svjack/Stable-Diffusion-Pokemon/HEAD/imgs/ja_blue_dragon.jpg -------------------------------------------------------------------------------- /imgs/zh_blue_dragon.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/svjack/Stable-Diffusion-Pokemon/HEAD/imgs/zh_blue_dragon.jpg -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | diffusers==0.4.1 2 | accelerate 3 | torchvision 4 | transformers>=4.21.0 5 | ftfy 6 | tensorboard 7 | modelcards 8 | jieba 9 | pandas 10 | datasets 11 | -------------------------------------------------------------------------------- /install.sh: -------------------------------------------------------------------------------- 1 | pip install git+https://github.com/rinnakk/japanese-stable-diffusion 2 | pip install -r requirements.txt 3 | 4 | sudo apt-get install git-lfs 5 | git clone https://huggingface.co/rinna/japanese-stable-diffusion 6 | git clone 
https://huggingface.co/CompVis/stable-diffusion-v1-4 7 | 8 | #### pretrained download 9 | git clone https://huggingface.co/svjack/Stable-Diffusion-Pokemon-en 10 | git clone https://huggingface.co/svjack/Stable-Diffusion-Pokemon-ja 11 | git clone https://huggingface.co/svjack/Stable-Diffusion-Pokemon-zh 12 | -------------------------------------------------------------------------------- /install.md: -------------------------------------------------------------------------------- 1 | pip install git+https://github.com/rinnakk/japanese-stable-diffusion 2 | ####pip install diffusers 3 | huggingface-cli login 4 | 5 | sudo apt-get install git-lfs 6 | git clone https://huggingface.co/rinna/japanese-stable-diffusion 7 | 8 | #### from text_to_image_train.py requirements: https://github.com/huggingface/diffusers/tree/main/examples/text_to_image 9 | git clone https://github.com/huggingface/diffusers 10 | cd examples/text_to_image 11 | pip install -r requirements.txt 12 | 13 | diffusers==0.4.1 14 | accelerate 15 | torchvision 16 | transformers>=4.21.0 17 | ftfy 18 | tensorboard 19 | modelcards 20 | jieba 21 | pandas 22 | datasets 23 | -------------------------------------------------------------------------------- /run_en_model.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import pandas as pd 3 | 4 | from torch import autocast 5 | from diffusers import LMSDiscreteScheduler, StableDiffusionPipeline 6 | 7 | scheduler = LMSDiscreteScheduler(beta_start=0.00085, beta_end=0.012, 8 | beta_schedule="scaled_linear", num_train_timesteps=1000) 9 | 10 | #pretrained_model_name_or_path = "en_model_26000" 11 | pretrained_model_name_or_path = "svjack/Stable-Diffusion-Pokemon-en" 12 | pipe = StableDiffusionPipeline.from_pretrained(pretrained_model_name_or_path, 13 | scheduler=scheduler, use_auth_token=True) 14 | 15 | pipe = pipe.to("cuda") 16 | 17 | #### disable safety_checker 18 | pipe.safety_checker = lambda images, clip_input: (images, False) 19 | 20 | imgs = pipe("A cartoon character with a potted plant on his head", 21 | num_inference_steps = 100 22 | ) 23 | imgs.images[0] 24 | 25 | imgs = pipe("cartoon bird", 26 | num_inference_steps = 100 27 | ) 28 | imgs.images[0] 29 | 30 | imgs = pipe("yellow ball", 31 | num_inference_steps = 100 32 | ) 33 | imgs.images[0] 34 | 35 | imgs = pipe("blue dragon illustration", 36 | num_inference_steps = 100 37 | ) 38 | imgs.images[0] 39 | 40 | ###### person "Zhuge Liang" 41 | ###### penis 42 | -------------------------------------------------------------------------------- /run_ja_model.py: -------------------------------------------------------------------------------- 1 | from japanese_stable_diffusion import JapaneseStableDiffusionPipeline 2 | import torch 3 | import pandas as pd 4 | 5 | from torch import autocast 6 | from diffusers import LMSDiscreteScheduler 7 | 8 | scheduler = LMSDiscreteScheduler(beta_start=0.00085, beta_end=0.012, 9 | beta_schedule="scaled_linear", num_train_timesteps=1000) 10 | 11 | #pretrained_model_name_or_path = "jap_model_26000" 12 | 13 | #### sudo apt-get install git-lfs 14 | #### git clone https://huggingface.co/svjack/Stable-Diffusion-Pokemon-ja 15 | pretrained_model_name_or_path = "Stable-Diffusion-Pokemon-ja" 16 | 17 | pipe = JapaneseStableDiffusionPipeline.from_pretrained(pretrained_model_name_or_path, 18 | scheduler=scheduler, use_auth_token=True) 19 | 20 | pipe = pipe.to("cuda") 21 | 22 | #### disable safety_checker 23 | pipe.safety_checker = lambda images, clip_input: (images, 
False) 24 | 25 | imgs = pipe("鉢植えの植物を頭に載せた漫画のキャラクター", 26 | num_inference_steps = 100 27 | ) 28 | imgs.images[0] 29 | 30 | imgs = pipe("漫画の鳥", 31 | num_inference_steps = 100 32 | ) 33 | imgs.images[0] 34 | 35 | imgs = pipe("黄色いボール", 36 | num_inference_steps = 100 37 | ) 38 | imgs.images[0] 39 | 40 | imgs = pipe("ブルードラゴンのイラスト", 41 | num_inference_steps = 100 42 | ) 43 | imgs.images[0] 44 | 45 | ##### 緑のピエロ 46 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | 2 |
### Stable-Diffusion-Pokemon

在宝可梦数据集(Pokemon-Blip-Captions)的英语、日语、中文版本上微调Stable Diffusion的示例

10 |

11 | 12 | [In English](README_EN.md) 13 | 14 | ### 简要引述 15 | [Stable Diffusion](https://stability.ai/blog/stable-diffusion-public-release)是现在一流的文本转图片生成模型。
16 | 现如今,借助于提供多模态预训练扩散模型的[diffusers](https://github.com/huggingface/diffusers)工程,人们可以自定义自己的条件(以文本提示为条件)或非条件图像生成模型。
17 | 这个工程的目标是运行diffusers提供的基于[lambdalabs/pokemon-blip-captions](https://huggingface.co/datasets/lambdalabs/pokemon-blip-captions)的[text to image example](https://github.com/huggingface/diffusers/tree/main/examples/text_to_image)微调任务,并将该任务迁移到日文及中文数据集上进行实践。
18 | 通过比较三者的结果,可以为Stable Diffusion在不同语言上的微调任务给出指导。 19 | 所有的代码都依据官方的[train_text_to_image.py](https://github.com/huggingface/diffusers/blob/main/examples/text_to_image/train_text_to_image.py)进行修改,使其对中文和日文同样有效。 20 | 得到的三种语言的训练模型如下:[English](https://huggingface.co/svjack/Stable-Diffusion-Pokemon-en) , [Japanese](https://huggingface.co/svjack/Stable-Diffusion-Pokemon-ja) 及 [Chinese](https://huggingface.co/svjack/Stable-Diffusion-Pokemon-zh). 21 | 22 | ### 安装和运行 23 | 运行install.sh将会安装所有的依赖并下载所有需要的模型(请确保您已注册huggingface账号并持有[token](https://huggingface.co/docs/hub/security-tokens)) 24 | 下载后可尝试运行[run_en_model.py](run_en_model.py), [run_ja_model.py](run_ja_model.py) 及 [run_zh_model.py](run_zh_model.py) 25 | 26 | ### 数据集准备 27 | 为了在日文及中文领域进行微调,我们需要[lambdalabs/pokemon-blip-captions](https://huggingface.co/datasets/lambdalabs/pokemon-blip-captions)的日文及中文版本。我已经使用 [DeepL](https://www.deepl.com/translator) 对其进行翻译并上传至 huggingface dataset hub,分别位于 [svjack/pokemon-blip-captions-en-ja](https://huggingface.co/datasets/svjack/pokemon-blip-captions-en-ja) 和 [svjack/pokemon-blip-captions-en-zh](https://huggingface.co/datasets/svjack/pokemon-blip-captions-en-zh)。 28 | 29 | 30 | ### 微调预训练模型 31 | 英文版本仅仅是将[train_text_to_image.py](https://github.com/huggingface/diffusers/blob/main/examples/text_to_image/train_text_to_image.py)脚本中accelerate的启动方式由命令行改为notebook模式(通过notebook_launcher函数,见下方示例),位于[train_en_model.py](train_en_model.py)
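下面是一个极简示意,仅用于演示 accelerate 的 notebook_launcher 用法(train_loop 在此用占位函数代替;train_zh_model.py 末尾即以 `notebook_launcher(train_loop, args_, num_processes=1)` 启动真实的训练循环):

```python
from accelerate import notebook_launcher

def train_loop():
    # 占位的训练循环:实际脚本中这里执行完整的扩散模型微调
    print("training...")

# 在 notebook 中以单进程启动训练,替代命令行的 `accelerate launch`
notebook_launcher(train_loop, num_processes=1)
```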
32 | 33 | 日文版本使用[rinnakk/japanese-stable-diffusion](https://github.com/rinnakk/japanese-stable-diffusion)替换预训练模型,位于[train_ja_model.py](train_ja_model.py)
34 | 35 | 中文版本使用[IDEA-CCNL/Taiyi-CLIP-Roberta-102M-Chinese](https://huggingface.co/IDEA-CCNL/Taiyi-CLIP-Roberta-102M-Chinese) 替换预训练分词器以及文本编码器,并将BertForTokenClassification输出的logit由512维padding到768维,以代替CLIPTextModel的输出作为UNet的条件输入(见下方示例),位于[train_zh_model.py](train_zh_model.py)。
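下面是这一 padding 做法的简化示意(与 train_zh_model.py / run_zh_model.py 中的实现一致;512 为该文本编码器 logit 输出的维度,768 为 Stable Diffusion UNet 期望的 cross-attention 维度,此处用随机张量代替真实的编码器输出):

```python
import torch

# 文本编码器的 logit 输出:形状 (batch, seq_len, 512),此处以随机张量代替
encoder_hidden_states = torch.randn(1, 77, 512)

# 在最后一维补零到 768,以匹配 Stable Diffusion UNet 的 cross-attention 维度
eh_shape = encoder_hidden_states.shape
eh_pad = torch.zeros((eh_shape[0], eh_shape[1], 768 - 512))
encoder_hidden_states = torch.concat([encoder_hidden_states, eh_pad], -1)

# 之后 encoder_hidden_states 作为条件传入 UNet:
# unet(noisy_latents, timesteps, encoder_hidden_states).sample
print(encoder_hidden_states.shape)  # torch.Size([1, 77, 768])
```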
36 | 37 | 为了在推断阶段使得所有结果可见,我关闭了safety_checker。 38 | 39 | ### 生成器结果比较
| Prompt | English | Japanese | Chinese |
| --- | --- | --- | --- |
| A cartoon character with a potted plant on his head<br/>鉢植えの植物を頭に載せた漫画のキャラクター<br/>一个头上戴着盆栽的卡通人物 | ![en_plant](imgs/en_plant.jpg) | ![ja_plant](imgs/ja_plant.jpg) | ![zh_plant](imgs/zh_plant.jpg) |
| cartoon bird<br/>漫画の鳥<br/>卡通鸟 | ![en_bird](imgs/en_bird.jpg) | ![ja_bird](imgs/ja_bird.jpg) | ![zh_bird](imgs/zh_bird.jpg) |
| blue dragon illustration<br/>ブルードラゴンのイラスト<br/>蓝色的龙图 | ![en_blue_dragon](imgs/en_blue_dragon.jpg) | ![ja_blue_dragon](imgs/ja_blue_dragon.jpg) | ![zh_blue_dragon](imgs/zh_blue_dragon.jpg) |
72 | 73 | ### 讨论 74 | 在英文、日文、中文下的预训练模型分别训练了26000、26000及20000步,其中日文版本表现最好,中文版本位列第三。
75 | 对于训练结果的解释是这样的:[rinnakk/japanese-stable-diffusion](https://github.com/rinnakk/japanese-stable-diffusion)由于是日文的原生模型,所以含有很多宝可梦的特征。[Stable Diffusion](https://stability.ai/blog/stable-diffusion-public-release)在英文领域微调的很好。[IDEA-CCNL/Taiyi-CLIP-Roberta-102M-Chinese](https://huggingface.co/IDEA-CCNL/Taiyi-CLIP-Roberta-102M-Chinese) 如模型卡片所说在中文上能够起到基本的文本表征作用。 76 | 77 | 78 | ## Contact 79 | 80 | 83 | svjack - svjackbt@gmail.com - ehangzhou@outlook.com 84 | 85 | 88 | Project Link:[https://github.com/svjack/Stable-Diffusion-Pokemon](https://github.com/svjack/Stable-Diffusion-Pokemon) 89 | 90 | 91 | 92 | ## Acknowledgements 93 | 106 | * [Stable Diffusion](https://stability.ai/blog/stable-diffusion-public-release) 107 | * [diffusers](https://github.com/huggingface/diffusers) 108 | * [DeepL](https://www.deepl.com/translator) 109 | * [rinnakk/japanese-stable-diffusion](https://github.com/rinnakk/japanese-stable-diffusion) 110 | * [IDEA-CCNL/Taiyi-CLIP-Roberta-102M-Chinese](https://huggingface.co/IDEA-CCNL/Taiyi-CLIP-Roberta-102M-Chinese) 111 | * [svjack](https://huggingface.co/svjack) 112 | -------------------------------------------------------------------------------- /README_EN.md: -------------------------------------------------------------------------------- 1 | 2 |
### Stable-Diffusion-Pokemon

A demo of fine-tuning Stable Diffusion on the Pokemon-Blip-Captions dataset in English, Japanese and Chinese

10 |

11 | 12 | [中文简介](README.md) 13 | 14 | ### Brief introduction 15 | [Stable Diffusion](https://stability.ai/blog/stable-diffusion-public-release) is a state of the art text-to-image model that generates images from text.
16 | Nowadays, with the help of [diffusers](https://github.com/huggingface/diffusers), which provides pretrained diffusion models across multiple modalities, people can customize their own conditional (prompt-based) or unconditional image generators.
17 | This project focuses on running the [text to image example](https://github.com/huggingface/diffusers/tree/main/examples/text_to_image) that diffusers provides, based on [lambdalabs/pokemon-blip-captions](https://huggingface.co/datasets/lambdalabs/pokemon-blip-captions), and on migrating this task to the Japanese and Chinese domains 18 | in both the model and data dimensions. Comparing the results gives a guideline for fine-tuning Stable Diffusion in different languages.
19 | All the code consists of edited versions of the official [train_text_to_image.py](https://github.com/huggingface/diffusers/blob/main/examples/text_to_image/train_text_to_image.py) that make it work in the Japanese and Chinese domains. 20 | Three pretrained models are provided: [English](https://huggingface.co/svjack/Stable-Diffusion-Pokemon-en) , [Japanese](https://huggingface.co/svjack/Stable-Diffusion-Pokemon-ja) and [Chinese](https://huggingface.co/svjack/Stable-Diffusion-Pokemon-zh). 21 | 22 | ### Installation and Running 23 | Running install.sh will install all dependencies and download all the models needed (make sure you are logged in to your huggingface account and have your [token](https://huggingface.co/docs/hub/security-tokens)). 24 | After the download, you can try [run_en_model.py](run_en_model.py), [run_ja_model.py](run_ja_model.py) and [run_zh_model.py](run_zh_model.py) by yourself. 25 | 26 | ### Dataset preparation 27 | For fine-tuning in the Japanese and Chinese domains, all we need is the [lambdalabs/pokemon-blip-captions](https://huggingface.co/datasets/lambdalabs/pokemon-blip-captions) dataset in Japanese and Chinese. I have translated it with the help of [DeepL](https://www.deepl.com/translator) and uploaded the results to the huggingface dataset hub, located in [svjack/pokemon-blip-captions-en-ja](https://huggingface.co/datasets/svjack/pokemon-blip-captions-en-ja) and [svjack/pokemon-blip-captions-en-zh](https://huggingface.co/datasets/svjack/pokemon-blip-captions-en-zh). 28 | 29 | ### Fine tuning pretrained models 30 | The English version, located in [train_en_model.py](train_en_model.py), is simply a copy of [train_text_to_image.py](https://github.com/huggingface/diffusers/blob/main/examples/text_to_image/train_text_to_image.py) that changes the accelerate launch from a script to notebook mode via the function 31 | ```python 32 | notebook_launcher 33 | ``` 34 | 35 | The Japanese version, located in [train_ja_model.py](train_ja_model.py), 36 | replaces the pretrained model with [rinnakk/japanese-stable-diffusion](https://github.com/rinnakk/japanese-stable-diffusion); a condensed inference example is sketched below.
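For reference, inference with the Japanese model follows the same pattern as the English and Chinese ones; a condensed sketch (adapted from [run_ja_model.py](run_ja_model.py), assuming the fine-tuned weights have been cloned locally as in install.sh) looks like this:

```python
from diffusers import LMSDiscreteScheduler
from japanese_stable_diffusion import JapaneseStableDiffusionPipeline

scheduler = LMSDiscreteScheduler(beta_start=0.00085, beta_end=0.012,
                                 beta_schedule="scaled_linear", num_train_timesteps=1000)

# local clone of https://huggingface.co/svjack/Stable-Diffusion-Pokemon-ja
pipe = JapaneseStableDiffusionPipeline.from_pretrained("Stable-Diffusion-Pokemon-ja",
                                                       scheduler=scheduler, use_auth_token=True)
pipe = pipe.to("cuda")

# disable the safety_checker so that no output is blacked out
pipe.safety_checker = lambda images, clip_input: (images, False)

image = pipe("ブルードラゴンのイラスト", num_inference_steps=100).images[0]
```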
37 | 38 | The Chinese version, located in [train_zh_model.py](train_zh_model.py), 39 | replaces the pretrained tokenizer and text_encoder with [IDEA-CCNL/Taiyi-CLIP-Roberta-102M-Chinese](https://huggingface.co/IDEA-CCNL/Taiyi-CLIP-Roberta-102M-Chinese) and feeds the logit output of BertForTokenClassification, padded to the width the UNet expects, to the downstream network in place of CLIPTextModel. 40 | 41 | To be able to look at all outputs, I disable the safety_checker so that no output gets covered during inference. 42 | 43 | ### Generator Results comparison
| Prompt | English | Japanese | Chinese |
| --- | --- | --- | --- |
| A cartoon character with a potted plant on his head<br/>鉢植えの植物を頭に載せた漫画のキャラクター<br/>一个头上戴着盆栽的卡通人物 | ![en_plant](imgs/en_plant.jpg) | ![ja_plant](imgs/ja_plant.jpg) | ![zh_plant](imgs/zh_plant.jpg) |
| cartoon bird<br/>漫画の鳥<br/>卡通鸟 | ![en_bird](imgs/en_bird.jpg) | ![ja_bird](imgs/ja_bird.jpg) | ![zh_bird](imgs/zh_bird.jpg) |
| blue dragon illustration<br/>ブルードラゴンのイラスト<br/>蓝色的龙图 | ![en_blue_dragon](imgs/en_blue_dragon.jpg) | ![ja_blue_dragon](imgs/ja_blue_dragon.jpg) | ![zh_blue_dragon](imgs/zh_blue_dragon.jpg) |
76 | 77 | ### Discussion 78 | The pretrained models in English, Japanese and Chinese are trained for 26000, 26000 and 20000 steps respectively. The Japanese outperform others and the Chinese version seems the third. The interpretation can be 79 | [rinnakk/japanese-stable-diffusion](https://github.com/rinnakk/japanese-stable-diffusion) have many culture and features about Pokemon, [Stable Diffusion](https://stability.ai/blog/stable-diffusion-public-release) in English domain is finetuned favourable. [IDEA-CCNL/Taiyi-CLIP-Roberta-102M-Chinese](https://huggingface.co/IDEA-CCNL/Taiyi-CLIP-Roberta-102M-Chinese) in Chinese as the model card introduction is only a text feature finetuned version. 80 | 81 | 82 | ## Contact 83 | 84 | 87 | svjack - svjackbt@gmail.com - ehangzhou@outlook.com 88 | 89 | 92 | Project Link:[https://github.com/svjack/Stable-Diffusion-Pokemon](https://github.com/svjack/Stable-Diffusion-Pokemon) 93 | 94 | 95 | 96 | ## Acknowledgements 97 | 110 | * [Stable Diffusion](https://stability.ai/blog/stable-diffusion-public-release) 111 | * [diffusers](https://github.com/huggingface/diffusers) 112 | * [DeepL](https://www.deepl.com/translator) 113 | * [rinnakk/japanese-stable-diffusion](https://github.com/rinnakk/japanese-stable-diffusion) 114 | * [IDEA-CCNL/Taiyi-CLIP-Roberta-102M-Chinese](https://huggingface.co/IDEA-CCNL/Taiyi-CLIP-Roberta-102M-Chinese) 115 | * [svjack](https://huggingface.co/svjack) 116 | -------------------------------------------------------------------------------- /run_zh_model.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import pandas as pd 3 | 4 | from torch import autocast 5 | from diffusers import LMSDiscreteScheduler 6 | 7 | import torch 8 | from transformers import BertForSequenceClassification, BertConfig, BertTokenizer, BertForTokenClassification 9 | from transformers import CLIPProcessor, CLIPModel 10 | import numpy as np 11 | 12 | from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion import * 13 | from japanese_stable_diffusion.pipeline_stable_diffusion import * 14 | 15 | class StableDiffusionPipelineWrapper(StableDiffusionPipeline): 16 | 17 | @torch.no_grad() 18 | def __call__( 19 | self, 20 | prompt: Union[str, List[str]], 21 | height: int = 512, 22 | width: int = 512, 23 | num_inference_steps: int = 50, 24 | guidance_scale: float = 7.5, 25 | negative_prompt: Optional[Union[str, List[str]]] = None, 26 | num_images_per_prompt: Optional[int] = 1, 27 | eta: float = 0.0, 28 | generator: Optional[torch.Generator] = None, 29 | latents: Optional[torch.FloatTensor] = None, 30 | output_type: Optional[str] = "pil", 31 | return_dict: bool = True, 32 | callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None, 33 | callback_steps: Optional[int] = 1, 34 | **kwargs, 35 | ): 36 | if isinstance(prompt, str): 37 | batch_size = 1 38 | elif isinstance(prompt, list): 39 | batch_size = len(prompt) 40 | else: 41 | raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") 42 | 43 | if height % 8 != 0 or width % 8 != 0: 44 | raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.") 45 | 46 | if (callback_steps is None) or ( 47 | callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0) 48 | ): 49 | raise ValueError( 50 | f"`callback_steps` has to be a positive integer but is {callback_steps} of type" 51 | f" {type(callback_steps)}." 
52 | ) 53 | 54 | # get prompt text embeddings 55 | text_inputs = self.tokenizer( 56 | prompt, 57 | padding="max_length", 58 | max_length=self.tokenizer.model_max_length, 59 | return_tensors="pt", 60 | ) 61 | text_input_ids = text_inputs.input_ids 62 | 63 | if text_input_ids.shape[-1] > self.tokenizer.model_max_length: 64 | removed_text = self.tokenizer.batch_decode(text_input_ids[:, self.tokenizer.model_max_length :]) 65 | logger.warning( 66 | "The following part of your input was truncated because CLIP can only handle sequences up to" 67 | f" {self.tokenizer.model_max_length} tokens: {removed_text}" 68 | ) 69 | text_input_ids = text_input_ids[:, : self.tokenizer.model_max_length] 70 | text_embeddings = self.text_encoder(text_input_ids.to(self.device))[0] 71 | 72 | # duplicate text embeddings for each generation per prompt, using mps friendly method 73 | bs_embed, seq_len, _ = text_embeddings.shape 74 | text_embeddings = text_embeddings.repeat(1, num_images_per_prompt, 1) 75 | text_embeddings = text_embeddings.view(bs_embed * num_images_per_prompt, seq_len, -1) 76 | 77 | # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) 78 | # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` 79 | # corresponds to doing no classifier free guidance. 80 | do_classifier_free_guidance = guidance_scale > 1.0 81 | # get unconditional embeddings for classifier free guidance 82 | if do_classifier_free_guidance: 83 | uncond_tokens: List[str] 84 | if negative_prompt is None: 85 | uncond_tokens = [""] 86 | elif type(prompt) is not type(negative_prompt): 87 | raise TypeError( 88 | f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" 89 | f" {type(prompt)}." 90 | ) 91 | elif isinstance(negative_prompt, str): 92 | uncond_tokens = [negative_prompt] 93 | elif batch_size != len(negative_prompt): 94 | raise ValueError( 95 | f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" 96 | f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" 97 | " the batch size of `prompt`." 98 | ) 99 | else: 100 | uncond_tokens = negative_prompt 101 | 102 | max_length = text_input_ids.shape[-1] 103 | uncond_input = self.tokenizer( 104 | uncond_tokens, 105 | padding="max_length", 106 | max_length=max_length, 107 | truncation=True, 108 | return_tensors="pt", 109 | ) 110 | uncond_embeddings = self.text_encoder(uncond_input.input_ids.to(self.device))[0] 111 | 112 | # duplicate unconditional embeddings for each generation per prompt, using mps friendly method 113 | seq_len = uncond_embeddings.shape[1] 114 | uncond_embeddings = uncond_embeddings.repeat(batch_size, num_images_per_prompt, 1) 115 | uncond_embeddings = uncond_embeddings.view(batch_size * num_images_per_prompt, seq_len, -1) 116 | 117 | # For classifier free guidance, we need to do two forward passes. 118 | # Here we concatenate the unconditional and text embeddings into a single batch 119 | # to avoid doing two forward passes 120 | text_embeddings = torch.cat([uncond_embeddings, text_embeddings]) 121 | 122 | # get the initial random noise unless the user supplied it 123 | 124 | # Unlike in other pipelines, latents need to be generated in the target device 125 | # for 1-to-1 results reproducibility with the CompVis implementation. 126 | # However this currently doesn't work in `mps`. 
127 | latents_shape = (batch_size * num_images_per_prompt, self.unet.in_channels, height // 8, width // 8) 128 | latents_dtype = text_embeddings.dtype 129 | if latents is None: 130 | if self.device.type == "mps": 131 | # randn does not work reproducibly on mps 132 | latents = torch.randn(latents_shape, generator=generator, device="cpu", dtype=latents_dtype).to( 133 | self.device 134 | ) 135 | else: 136 | latents = torch.randn(latents_shape, generator=generator, device=self.device, dtype=latents_dtype) 137 | else: 138 | if latents.shape != latents_shape: 139 | raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {latents_shape}") 140 | latents = latents.to(self.device) 141 | 142 | # set timesteps 143 | self.scheduler.set_timesteps(num_inference_steps) 144 | 145 | # Some schedulers like PNDM have timesteps as arrays 146 | # It's more optimized to move all timesteps to correct device beforehand 147 | timesteps_tensor = self.scheduler.timesteps.to(self.device) 148 | 149 | # scale the initial noise by the standard deviation required by the scheduler 150 | latents = latents * self.scheduler.init_noise_sigma 151 | 152 | # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature 153 | # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers. 154 | # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502 155 | # and should be between [0, 1] 156 | accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys()) 157 | extra_step_kwargs = {} 158 | if accepts_eta: 159 | extra_step_kwargs["eta"] = eta 160 | 161 | for i, t in enumerate(self.progress_bar(timesteps_tensor)): 162 | # expand the latents if we are doing classifier free guidance 163 | latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents 164 | latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) 165 | 166 | # predict the noise residual 167 | ###text_embeddings 168 | #print("before :" ,text_embeddings.shape) 169 | eh_shape = text_embeddings.shape 170 | if i == 0: 171 | eh_pad = torch.zeros((eh_shape[0], eh_shape[1], 768 - 512)) 172 | eh_pad = eh_pad.to(self.device) 173 | text_embeddings = torch.concat([text_embeddings, eh_pad], -1) 174 | 175 | #print("after :" ,text_embeddings.shape) 176 | noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=text_embeddings).sample 177 | 178 | # perform guidance 179 | if do_classifier_free_guidance: 180 | noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) 181 | noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) 182 | 183 | # compute the previous noisy sample x_t -> x_t-1 184 | latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample 185 | 186 | # call the callback, if provided 187 | if callback is not None and i % callback_steps == 0: 188 | callback(i, t, latents) 189 | 190 | latents = 1 / 0.18215 * latents 191 | image = self.vae.decode(latents).sample 192 | 193 | image = (image / 2 + 0.5).clamp(0, 1) 194 | 195 | # we always cast to float32 as this does not cause significant overhead and is compatible with bfloa16 196 | image = image.cpu().permute(0, 2, 3, 1).float().numpy() 197 | 198 | if self.safety_checker is not None: 199 | safety_checker_input = self.feature_extractor(self.numpy_to_pil(image), return_tensors="pt").to( 200 | self.device 201 | ) 202 | image, has_nsfw_concept = self.safety_checker( 203 | images=image, 
clip_input=safety_checker_input.pixel_values.to(text_embeddings.dtype) 204 | ) 205 | else: 206 | has_nsfw_concept = None 207 | 208 | if output_type == "pil": 209 | image = self.numpy_to_pil(image) 210 | 211 | if not return_dict: 212 | return (image, has_nsfw_concept) 213 | 214 | return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept) 215 | 216 | if __name__ == "__main__": 217 | scheduler = LMSDiscreteScheduler(beta_start=0.00085, beta_end=0.012, 218 | beta_schedule="scaled_linear", num_train_timesteps=1000) 219 | 220 | #pretrained_model_name_or_path = "zh_model_20000" 221 | #### sudo apt-get install git-lfs 222 | #### git clone https://huggingface.co/svjack/Stable-Diffusion-Pokemon-zh 223 | pretrained_model_name_or_path = "Stable-Diffusion-Pokemon-zh" 224 | 225 | tokenizer = BertTokenizer.from_pretrained(pretrained_model_name_or_path, subfolder = "tokenizer") 226 | text_encoder = BertForTokenClassification.from_pretrained(pretrained_model_name_or_path, subfolder = "text_encoder") 227 | 228 | vae = AutoencoderKL.from_pretrained(pretrained_model_name_or_path, subfolder="vae") 229 | unet = UNet2DConditionModel.from_pretrained(pretrained_model_name_or_path, subfolder="unet") 230 | 231 | tokenizer.model_max_length = 77 232 | pipeline_wrap = StableDiffusionPipelineWrapper( 233 | text_encoder=text_encoder, 234 | vae=vae, 235 | unet=unet, 236 | tokenizer=tokenizer, 237 | scheduler=scheduler, 238 | safety_checker=StableDiffusionSafetyChecker.from_pretrained("CompVis/stable-diffusion-safety-checker"), 239 | feature_extractor=CLIPFeatureExtractor.from_pretrained("openai/clip-vit-base-patch32"), 240 | ) 241 | pipeline_wrap.safety_checker = lambda images, clip_input: (images, False) 242 | pipeline_wrap = pipeline_wrap.to("cuda") 243 | 244 | imgs = pipeline_wrap("一个头上戴着盆栽的卡通人物", 245 | num_inference_steps = 100 246 | ) 247 | imgs.images[0] 248 | 249 | imgs = pipeline_wrap("卡通鸟", 250 | num_inference_steps = 100 251 | ) 252 | imgs.images[0] 253 | 254 | imgs = pipeline_wrap("黄色的球", 255 | num_inference_steps = 100 256 | ) 257 | imgs.images[0] 258 | 259 | imgs = pipeline_wrap("蓝色的龙图", 260 | num_inference_steps = 100 261 | ) 262 | imgs.images[0] 263 | -------------------------------------------------------------------------------- /train_zh_model.py: -------------------------------------------------------------------------------- 1 | #requires_grad_ 2 | REQUIRES_GRAD = False 3 | 4 | import pandas as pd 5 | from collections import namedtuple 6 | 7 | import os 8 | os.environ['KMP_DUPLICATE_LIB_OK']='True' 9 | 10 | import argparse 11 | import logging 12 | import math 13 | import os 14 | import random 15 | from pathlib import Path 16 | from typing import Iterable, Optional 17 | 18 | import numpy as np 19 | import torch 20 | import torch.nn.functional as F 21 | import torch.utils.checkpoint 22 | 23 | from accelerate import Accelerator 24 | from accelerate.logging import get_logger 25 | from accelerate.utils import set_seed 26 | from datasets import load_dataset 27 | from diffusers import AutoencoderKL, DDPMScheduler, PNDMScheduler, StableDiffusionPipeline, UNet2DConditionModel 28 | from diffusers.optimization import get_scheduler 29 | from diffusers.pipelines.stable_diffusion import StableDiffusionSafetyChecker 30 | from huggingface_hub import HfFolder, Repository, whoami 31 | from torchvision import transforms 32 | from tqdm.auto import tqdm 33 | from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer 34 | 35 | 36 | logger = get_logger(__name__) 37 | 38 | 39 | 
def parse_args(): 40 | parser = argparse.ArgumentParser(description="Simple example of a training script.") 41 | parser.add_argument( 42 | "--pretrained_model_name_or_path", 43 | type=str, 44 | default=None, 45 | required=True, 46 | help="Path to pretrained model or model identifier from huggingface.co/models.", 47 | ) 48 | parser.add_argument( 49 | "--dataset_name", 50 | type=str, 51 | default=None, 52 | help=( 53 | "The name of the Dataset (from the HuggingFace hub) to train on (could be your own, possibly private," 54 | " dataset). It can also be a path pointing to a local copy of a dataset in your filesystem," 55 | " or to a folder containing files that 🤗 Datasets can understand." 56 | ), 57 | ) 58 | parser.add_argument( 59 | "--dataset_config_name", 60 | type=str, 61 | default=None, 62 | help="The config of the Dataset, leave as None if there's only one config.", 63 | ) 64 | parser.add_argument( 65 | "--train_data_dir", 66 | type=str, 67 | default=None, 68 | help=( 69 | "A folder containing the training data. Folder contents must follow the structure described in" 70 | " https://huggingface.co/docs/datasets/image_dataset#imagefolder. In particular, a `metadata.jsonl` file" 71 | " must exist to provide the captions for the images. Ignored if `dataset_name` is specified." 72 | ), 73 | ) 74 | parser.add_argument( 75 | "--image_column", type=str, default="image", help="The column of the dataset containing an image." 76 | ) 77 | parser.add_argument( 78 | "--caption_column", 79 | type=str, 80 | default="text", 81 | help="The column of the dataset containing a caption or a list of captions.", 82 | ) 83 | parser.add_argument( 84 | "--max_train_samples", 85 | type=int, 86 | default=None, 87 | help=( 88 | "For debugging purposes or quicker training, truncate the number of training examples to this " 89 | "value if set." 90 | ), 91 | ) 92 | parser.add_argument( 93 | "--output_dir", 94 | type=str, 95 | default="sd-model-finetuned", 96 | help="The output directory where the model predictions and checkpoints will be written.", 97 | ) 98 | parser.add_argument( 99 | "--cache_dir", 100 | type=str, 101 | default=None, 102 | help="The directory where the downloaded models and datasets will be stored.", 103 | ) 104 | parser.add_argument("--seed", type=int, default=None, help="A seed for reproducible training.") 105 | parser.add_argument( 106 | "--resolution", 107 | type=int, 108 | default=512, 109 | help=( 110 | "The resolution for input images, all the images in the train/validation dataset will be resized to this" 111 | " resolution" 112 | ), 113 | ) 114 | parser.add_argument( 115 | "--center_crop", 116 | action="store_true", 117 | help="Whether to center crop images before resizing to resolution (if not set, random crop will be used)", 118 | ) 119 | parser.add_argument( 120 | "--random_flip", 121 | action="store_true", 122 | help="whether to randomly flip images horizontally", 123 | ) 124 | parser.add_argument( 125 | "--train_batch_size", type=int, default=16, help="Batch size (per device) for the training dataloader." 126 | ) 127 | parser.add_argument("--num_train_epochs", type=int, default=100) 128 | parser.add_argument( 129 | "--max_train_steps", 130 | type=int, 131 | default=None, 132 | help="Total number of training steps to perform. 
If provided, overrides num_train_epochs.", 133 | ) 134 | parser.add_argument( 135 | "--gradient_accumulation_steps", 136 | type=int, 137 | default=1, 138 | help="Number of updates steps to accumulate before performing a backward/update pass.", 139 | ) 140 | parser.add_argument( 141 | "--gradient_checkpointing", 142 | action="store_true", 143 | help="Whether or not to use gradient checkpointing to save memory at the expense of slower backward pass.", 144 | ) 145 | parser.add_argument( 146 | "--learning_rate", 147 | type=float, 148 | default=1e-4, 149 | help="Initial learning rate (after the potential warmup period) to use.", 150 | ) 151 | parser.add_argument( 152 | "--scale_lr", 153 | action="store_true", 154 | default=False, 155 | help="Scale the learning rate by the number of GPUs, gradient accumulation steps, and batch size.", 156 | ) 157 | parser.add_argument( 158 | "--lr_scheduler", 159 | type=str, 160 | default="constant", 161 | help=( 162 | 'The scheduler type to use. Choose between ["linear", "cosine", "cosine_with_restarts", "polynomial",' 163 | ' "constant", "constant_with_warmup"]' 164 | ), 165 | ) 166 | parser.add_argument( 167 | "--lr_warmup_steps", type=int, default=500, help="Number of steps for the warmup in the lr scheduler." 168 | ) 169 | parser.add_argument( 170 | "--use_8bit_adam", action="store_true", help="Whether or not to use 8-bit Adam from bitsandbytes." 171 | ) 172 | parser.add_argument("--use_ema", action="store_true", help="Whether to use EMA model.") 173 | parser.add_argument("--adam_beta1", type=float, default=0.9, help="The beta1 parameter for the Adam optimizer.") 174 | parser.add_argument("--adam_beta2", type=float, default=0.999, help="The beta2 parameter for the Adam optimizer.") 175 | parser.add_argument("--adam_weight_decay", type=float, default=1e-2, help="Weight decay to use.") 176 | parser.add_argument("--adam_epsilon", type=float, default=1e-08, help="Epsilon value for the Adam optimizer") 177 | parser.add_argument("--max_grad_norm", default=1.0, type=float, help="Max gradient norm.") 178 | parser.add_argument("--push_to_hub", action="store_true", help="Whether or not to push the model to the Hub.") 179 | parser.add_argument("--hub_token", type=str, default=None, help="The token to use to push to the Model Hub.") 180 | parser.add_argument( 181 | "--hub_model_id", 182 | type=str, 183 | default=None, 184 | help="The name of the repository to keep in sync with the local `output_dir`.", 185 | ) 186 | parser.add_argument( 187 | "--logging_dir", 188 | type=str, 189 | default="logs", 190 | help=( 191 | "[TensorBoard](https://www.tensorflow.org/tensorboard) log directory. Will default to" 192 | " *output_dir/runs/**CURRENT_DATETIME_HOSTNAME***." 193 | ), 194 | ) 195 | parser.add_argument( 196 | "--mixed_precision", 197 | type=str, 198 | default="no", 199 | choices=["no", "fp16", "bf16"], 200 | help=( 201 | "Whether to use mixed precision. Choose" 202 | "between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >= 1.10." 203 | "and an Nvidia Ampere GPU." 204 | ), 205 | ) 206 | parser.add_argument( 207 | "--report_to", 208 | type=str, 209 | default="tensorboard", 210 | help=( 211 | 'The integration to report the results and logs to. Supported platforms are `"tensorboard"`,' 212 | ' `"wandb"` and `"comet_ml"`. Use `"all"` (default) to report to all integrations.' 213 | "Only applicable when `--with_tracking` is passed." 
214 | ), 215 | ) 216 | parser.add_argument("--local_rank", type=int, default=-1, help="For distributed training: local_rank") 217 | 218 | #args = parser.parse_args() 219 | return parser 220 | env_local_rank = int(os.environ.get("LOCAL_RANK", -1)) 221 | if env_local_rank != -1 and env_local_rank != args.local_rank: 222 | args.local_rank = env_local_rank 223 | 224 | # Sanity checks 225 | if args.dataset_name is None and args.train_data_dir is None: 226 | raise ValueError("Need either a dataset name or a training folder.") 227 | 228 | return args 229 | 230 | 231 | def parse_parser_add_arg(parser, as_named_tuple = False): 232 | args_df = pd.DataFrame( 233 | pd.Series(parser.__dict__["_actions"]).map( 234 | lambda x:x.__dict__ 235 | ).values.tolist()) 236 | args_df = args_df.explode("option_strings") 237 | args_df["option_strings"] = args_df["option_strings"].map( 238 | lambda x: x[2:] if x.startswith("--") else x 239 | ).map( 240 | lambda x: x[1:] if x.startswith("-") else x 241 | ) 242 | args_df = args_df[["option_strings", "default"]] 243 | args = dict(args_df.values.tolist()) 244 | if as_named_tuple: 245 | args_parser_namedtuple = namedtuple("args_config", args) 246 | return args_parser_namedtuple(**args) 247 | return args_df 248 | 249 | def transform_named_tuple_to_dict(N_tuple): 250 | return dict(map( 251 | lambda x: (x, getattr(N_tuple, x)) 252 | ,filter(lambda x: not x.startswith("_") ,dir(N_tuple)) 253 | )) 254 | 255 | def transform_dict_to_named_tuple(dict_, name = "NamedTuple"): 256 | args_parser_namedtuple = namedtuple(name, dict_) 257 | return args_parser_namedtuple(**dict_) 258 | 259 | def setattr_gen_option_cls(src_obj): 260 | assert isinstance(src_obj, tuple) or isinstance(src_obj, dict) 261 | if isinstance(src_obj, tuple): 262 | src_obj_ = transform_named_tuple_to_dict(src_obj) 263 | else: 264 | src_obj_ = src_obj 265 | assert isinstance(src_obj_, dict) 266 | class Option(object): 267 | pass 268 | option = Option() 269 | for k, v in src_obj_.items(): 270 | setattr(option, k, v) 271 | return option 272 | 273 | args = parse_args() 274 | args = parse_parser_add_arg(args, as_named_tuple = True) 275 | ''' 276 | export MODEL_NAME="stable-diffusion-v1-4/" 277 | export dataset_name="lambdalabs/pokemon-blip-captions" 278 | 279 | accelerate launch train_text_to_image_ori.py \ 280 | --pretrained_model_name_or_path=$MODEL_NAME \ 281 | --dataset_name=$dataset_name \ 282 | --use_ema \ 283 | --resolution=32 --center_crop --random_flip \ 284 | --train_batch_size=1 \ 285 | --gradient_accumulation_steps=4 \ 286 | --gradient_checkpointing \ 287 | --max_train_steps=15000 \ 288 | --learning_rate=1e-05 \ 289 | --max_grad_norm=1 \ 290 | --lr_scheduler="constant" --lr_warmup_steps=0 \ 291 | --output_dir="sd-pokemon-model" 292 | ''' 293 | args_dict = transform_named_tuple_to_dict(args) 294 | args_dict["pretrained_model_name_or_path"] = "stable-diffusion-v1-4/" 295 | #args_dict["pretrained_model_name_or_path"] = "japanese-stable-diffusion/" 296 | args_dict["dataset_name"] = "svjack/pokemon-blip-captions-en-zh" 297 | args_dict["use_ema"] = True 298 | ###args_dict["use_ema"] = False 299 | args_dict["resolution"] = 256 300 | args_dict["center_crop"] = True 301 | args_dict["random_flip"] = True 302 | args_dict["train_batch_size"] = 1 303 | args_dict["gradient_accumulation_steps"] = 4 304 | args_dict["train_batch_size"] = 4 305 | args_dict["gradient_checkpointing"] = True 306 | #### to 15000 307 | args_dict["max_train_steps"] = 50000 308 | args_dict["learning_rate"] = 1e-05 309 | args_dict["max_grad_norm"] = 
1 310 | args_dict["lr_scheduler"] = "constant" 311 | args_dict["lr_warmup_steps"] = 0 312 | args_dict["output_dir"] = "sd-pokemon-model" 313 | args_dict["caption_column"] = "zh_text" 314 | args_dict["mixed_precision"] = "no" 315 | args = transform_dict_to_named_tuple(args_dict) 316 | 317 | def get_full_repo_name(model_id: str, organization: Optional[str] = None, token: Optional[str] = None): 318 | if token is None: 319 | token = HfFolder.get_token() 320 | if organization is None: 321 | username = whoami(token)["name"] 322 | return f"{username}/{model_id}" 323 | else: 324 | return f"{organization}/{model_id}" 325 | 326 | dataset_name_mapping = { 327 | "svjack/pokemon-blip-captions-en-zh": ("image", "zh_text"), 328 | } 329 | 330 | 331 | class EMAModel: 332 | """ 333 | Exponential Moving Average of models weights 334 | """ 335 | 336 | def __init__(self, parameters: Iterable[torch.nn.Parameter], decay=0.9999): 337 | parameters = list(parameters) 338 | self.shadow_params = [p.clone().detach() for p in parameters] 339 | 340 | self.decay = decay 341 | self.optimization_step = 0 342 | 343 | def get_decay(self, optimization_step): 344 | """ 345 | Compute the decay factor for the exponential moving average. 346 | """ 347 | value = (1 + optimization_step) / (10 + optimization_step) 348 | return 1 - min(self.decay, value) 349 | 350 | @torch.no_grad() 351 | def step(self, parameters): 352 | parameters = list(parameters) 353 | 354 | self.optimization_step += 1 355 | self.decay = self.get_decay(self.optimization_step) 356 | 357 | for s_param, param in zip(self.shadow_params, parameters): 358 | if param.requires_grad: 359 | tmp = self.decay * (s_param - param) 360 | s_param.sub_(tmp) 361 | else: 362 | s_param.copy_(param) 363 | 364 | torch.cuda.empty_cache() 365 | 366 | def copy_to(self, parameters: Iterable[torch.nn.Parameter]) -> None: 367 | """ 368 | Copy current averaged parameters into given collection of parameters. 369 | 370 | Args: 371 | parameters: Iterable of `torch.nn.Parameter`; the parameters to be 372 | updated with the stored moving averages. If `None`, the 373 | parameters with which this `ExponentialMovingAverage` was 374 | initialized will be used. 375 | """ 376 | parameters = list(parameters) 377 | for s_param, param in zip(self.shadow_params, parameters): 378 | param.data.copy_(s_param.data) 379 | 380 | def to(self, device=None, dtype=None) -> None: 381 | r"""Move internal buffers of the ExponentialMovingAverage to `device`. 
382 | 383 | Args: 384 | device: like `device` argument to `torch.Tensor.to` 385 | """ 386 | # .to() on the tensors handles None correctly 387 | self.shadow_params = [ 388 | p.to(device=device, dtype=dtype) if p.is_floating_point() else p.to(device=device) 389 | for p in self.shadow_params 390 | ] 391 | 392 | import torch 393 | from transformers import BertForSequenceClassification, BertConfig, BertTokenizer, BertForTokenClassification 394 | from transformers import CLIPProcessor, CLIPModel 395 | import numpy as np 396 | 397 | tokenizer = BertTokenizer.from_pretrained("IDEA-CCNL/Taiyi-CLIP-Roberta-102M-Chinese") 398 | text_encoder = BertForTokenClassification.from_pretrained("IDEA-CCNL/Taiyi-CLIP-Roberta-102M-Chinese") 399 | 400 | vae = AutoencoderKL.from_pretrained(args.pretrained_model_name_or_path, subfolder="vae") 401 | unet = UNet2DConditionModel.from_pretrained(args.pretrained_model_name_or_path, subfolder="unet") 402 | 403 | vae.requires_grad_(REQUIRES_GRAD) 404 | text_encoder.requires_grad_(REQUIRES_GRAD) 405 | 406 | if args.gradient_checkpointing: 407 | unet.enable_gradient_checkpointing() 408 | 409 | if args.scale_lr: 410 | args.learning_rate = ( 411 | args.learning_rate * args.gradient_accumulation_steps * args.train_batch_size * accelerator.num_processes 412 | ) 413 | 414 | # Initialize the optimizer 415 | if args.use_8bit_adam: 416 | try: 417 | import bitsandbytes as bnb 418 | except ImportError: 419 | raise ImportError( 420 | "Please install bitsandbytes to use 8-bit Adam. You can do so by running `pip install bitsandbytes`" 421 | ) 422 | 423 | optimizer_cls = bnb.optim.AdamW8bit 424 | else: 425 | optimizer_cls = torch.optim.AdamW 426 | 427 | optimizer = optimizer_cls( 428 | unet.parameters(), 429 | lr=args.learning_rate, 430 | betas=(args.adam_beta1, args.adam_beta2), 431 | weight_decay=args.adam_weight_decay, 432 | eps=args.adam_epsilon, 433 | ) 434 | noise_scheduler = DDPMScheduler( 435 | beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", num_train_timesteps=1000, 436 | #tensor_format="pt" 437 | ) 438 | 439 | 440 | dataset = load_dataset( 441 | args.dataset_name, 442 | args.dataset_config_name, 443 | cache_dir=args.cache_dir, 444 | ) 445 | column_names = dataset["train"].column_names 446 | dataset_columns = dataset_name_mapping.get(args.dataset_name, None) 447 | if args.image_column is None: 448 | image_column = dataset_columns[0] if dataset_columns is not None else column_names[0] 449 | else: 450 | image_column = args.image_column 451 | if image_column not in column_names: 452 | raise ValueError( 453 | f"--image_column' value '{args.image_column}' needs to be one of: {', '.join(column_names)}" 454 | ) 455 | if args.caption_column is None: 456 | caption_column = dataset_columns[1] if dataset_columns is not None else column_names[1] 457 | else: 458 | caption_column = args.caption_column 459 | if caption_column not in column_names: 460 | raise ValueError( 461 | f"--caption_column' value '{args.caption_column}' needs to be one of: {', '.join(column_names)}" 462 | ) 463 | 464 | 465 | # Preprocessing the datasets. 466 | # We need to tokenize input captions and transform the images. 
467 | def tokenize_captions(examples, is_train=True): 468 | captions = [] 469 | for caption in examples[caption_column]: 470 | if isinstance(caption, str): 471 | captions.append(caption) 472 | elif isinstance(caption, (list, np.ndarray)): 473 | # take a random caption if there are multiple 474 | captions.append(random.choice(caption) if is_train else caption[0]) 475 | else: 476 | raise ValueError( 477 | f"Caption column `{caption_column}` should contain either strings or lists of strings." 478 | ) 479 | inputs = tokenizer(captions, max_length=tokenizer.model_max_length, padding="do_not_pad", truncation=True) 480 | input_ids = inputs.input_ids 481 | return input_ids 482 | 483 | 484 | train_transforms = transforms.Compose( 485 | [ 486 | transforms.Resize((args.resolution, args.resolution), interpolation=transforms.InterpolationMode.BILINEAR), 487 | transforms.CenterCrop(args.resolution) if args.center_crop else transforms.RandomCrop(args.resolution), 488 | transforms.RandomHorizontalFlip() if args.random_flip else transforms.Lambda(lambda x: x), 489 | transforms.ToTensor(), 490 | transforms.Normalize([0.5], [0.5]), 491 | ] 492 | ) 493 | 494 | def preprocess_train(examples): 495 | images = [image.convert("RGB") for image in examples[image_column]] 496 | examples["pixel_values"] = [train_transforms(image) for image in images] 497 | examples["input_ids"] = tokenize_captions(examples) 498 | return examples 499 | 500 | if args.max_train_samples is not None: 501 | dataset["train"] = dataset["train"].shuffle(seed=args.seed).select(range(args.max_train_samples)) 502 | train_dataset = dataset["train"].with_transform(preprocess_train) 503 | 504 | def collate_fn(examples): 505 | pixel_values = torch.stack([example["pixel_values"] for example in examples]) 506 | pixel_values = pixel_values.to(memory_format=torch.contiguous_format).float() 507 | input_ids = [example["input_ids"] for example in examples] 508 | padded_tokens = tokenizer.pad({"input_ids": input_ids}, padding=True, return_tensors="pt") 509 | return { 510 | "pixel_values": pixel_values, 511 | "input_ids": padded_tokens.input_ids, 512 | "attention_mask": padded_tokens.attention_mask, 513 | } 514 | 515 | train_dataloader = torch.utils.data.DataLoader( 516 | train_dataset, shuffle=True, collate_fn=collate_fn, batch_size=args.train_batch_size 517 | ) 518 | 519 | 520 | overrode_max_train_steps = False 521 | num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps) 522 | if args.max_train_steps is None: 523 | args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch 524 | overrode_max_train_steps = True 525 | 526 | lr_scheduler = get_scheduler( 527 | args.lr_scheduler, 528 | optimizer=optimizer, 529 | num_warmup_steps=args.lr_warmup_steps * args.gradient_accumulation_steps, 530 | num_training_steps=args.max_train_steps * args.gradient_accumulation_steps, 531 | ) 532 | 533 | def train_loop(config, model, noise_scheduler, optimizer, train_dataloader, lr_scheduler): 534 | # Initialize accelerator and tensorboard logging 535 | args = config 536 | unet = model 537 | 538 | logging_dir = os.path.join(args.output_dir, args.logging_dir) 539 | accelerator = Accelerator( 540 | gradient_accumulation_steps=args.gradient_accumulation_steps, 541 | mixed_precision=args.mixed_precision, 542 | log_with=args.report_to, 543 | logging_dir=logging_dir, 544 | ) 545 | logging.basicConfig( 546 | format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", 547 | datefmt="%m/%d/%Y %H:%M:%S", 548 | level=logging.INFO, 
549 | ) 550 | if args.seed is not None: 551 | set_seed(args.seed) 552 | if accelerator.is_main_process: 553 | os.makedirs(args.output_dir, exist_ok=True) 554 | 555 | weight_dtype = torch.float32 556 | if args.mixed_precision == "fp16": 557 | weight_dtype = torch.float16 558 | elif args.mixed_precision == "bf16": 559 | weight_dtype = torch.bfloat16 560 | 561 | text_encoder.to(accelerator.device, dtype=weight_dtype) 562 | vae.to(accelerator.device, dtype=weight_dtype) 563 | 564 | if args.use_ema: 565 | ema_unet = EMAModel(unet.parameters()) 566 | ema_unet.to(accelerator.device, dtype=weight_dtype) 567 | 568 | num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps) 569 | if overrode_max_train_steps: 570 | args_dict = transform_named_tuple_to_dict(args) 571 | args_dict["max_train_steps"] = args.num_train_epochs * num_update_steps_per_epoch 572 | args = transform_dict_to_named_tuple(args_dict) 573 | 574 | 575 | args_dict = transform_named_tuple_to_dict(args) 576 | args_dict["num_train_epochs"] = math.ceil(args.max_train_steps / num_update_steps_per_epoch) 577 | args = transform_dict_to_named_tuple(args_dict) 578 | 579 | if accelerator.is_main_process: 580 | if config.push_to_hub: 581 | #repo = init_git_repo(config, at_init=True) 582 | pass 583 | accelerator.init_trackers("train_example") 584 | 585 | total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps 586 | 587 | logger.info("***** Running training *****") 588 | logger.info(f" Num examples = {len(train_dataset)}") 589 | logger.info(f" Num Epochs = {args.num_train_epochs}") 590 | logger.info(f" Instantaneous batch size per device = {args.train_batch_size}") 591 | logger.info(f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}") 592 | logger.info(f" Gradient Accumulation steps = {args.gradient_accumulation_steps}") 593 | logger.info(f" Total optimization steps = {args.max_train_steps}") 594 | 595 | # Prepare everything 596 | # There is no specific order to remember, you just need to unpack the 597 | # objects in the same order you gave them to the prepare method. 
598 | ''' 599 | model, optimizer, train_dataloader, lr_scheduler = accelerator.prepare( 600 | model, optimizer, train_dataloader, lr_scheduler 601 | ) 602 | ''' 603 | unet, optimizer, train_dataloader, lr_scheduler = accelerator.prepare( 604 | unet, optimizer, train_dataloader, lr_scheduler 605 | ) 606 | progress_bar = tqdm(range(args.max_train_steps), disable=not accelerator.is_local_main_process) 607 | progress_bar.set_description("Steps") 608 | global_step = 0 609 | 610 | for epoch in range(args.num_train_epochs): 611 | #unet.train() 612 | train_loss = 0.0 613 | for step, batch in enumerate(train_dataloader): 614 | with accelerator.accumulate(unet): 615 | # Convert images to latent space 616 | latents = vae.encode(batch["pixel_values"].to(weight_dtype)).latent_dist.sample() 617 | latents = latents * 0.18215 618 | 619 | # Sample noise that we'll add to the latents 620 | noise = torch.randn_like(latents) 621 | bsz = latents.shape[0] 622 | # Sample a random timestep for each image 623 | timesteps = torch.randint(0, noise_scheduler.num_train_timesteps, (bsz,), device=latents.device) 624 | timesteps = timesteps.long() 625 | 626 | # Add noise to the latents according to the noise magnitude at each timestep 627 | # (this is the forward diffusion process) 628 | noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps) 629 | 630 | # Get the text embedding for conditioning 631 | encoder_hidden_states = text_encoder(batch["input_ids"])[0] 632 | #print(encoder_hidden_states.shape) 633 | eh_shape = encoder_hidden_states.shape 634 | eh_pad = torch.zeros((eh_shape[0], eh_shape[1], 768 - 512)) 635 | eh_pad = eh_pad.to(accelerator.device) 636 | encoder_hidden_states = torch.concat([encoder_hidden_states, eh_pad], -1) 637 | 638 | # Predict the noise residual and compute loss 639 | noise_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample 640 | 641 | loss = F.mse_loss(noise_pred.float(), noise.float(), reduction="mean") 642 | #loss = F.mse_loss(noise.float(), noise.float(), reduction="mean") 643 | 644 | # Gather the losses across all processes for logging (if we use distributed training). 
645 | avg_loss = accelerator.gather(loss.repeat(args.train_batch_size)).mean() 646 | train_loss += avg_loss.item() / args.gradient_accumulation_steps 647 | 648 | # Backpropagate 649 | accelerator.backward(loss) 650 | if accelerator.sync_gradients: 651 | accelerator.clip_grad_norm_(unet.parameters(), args.max_grad_norm) 652 | optimizer.step() 653 | lr_scheduler.step() 654 | optimizer.zero_grad() 655 | 656 | # Checks if the accelerator has performed an optimization step behind the scenes 657 | if accelerator.sync_gradients: 658 | if args.use_ema: 659 | ema_unet.step(unet.parameters()) 660 | progress_bar.update(1) 661 | global_step += 1 662 | accelerator.log({"train_loss": train_loss}, step=global_step) 663 | train_loss = 0.0 664 | 665 | logs = {"step_loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0]} 666 | progress_bar.set_postfix(**logs) 667 | 668 | if global_step >= args.max_train_steps: 669 | break 670 | 671 | from accelerate import notebook_launcher 672 | 673 | ####args = (config, model, noise_scheduler, optimizer, train_dataloader, lr_scheduler) 674 | args_ = (args, unet, noise_scheduler, optimizer, train_dataloader, lr_scheduler) 675 | notebook_launcher(train_loop, args_, num_processes=1) 676 | 677 | #### save to local 678 | import os 679 | save_path = "zh_model_20000" 680 | if not os.path.exists(save_path): 681 | os.mkdir(save_path) 682 | 683 | tokenizer.save_pretrained(os.path.join(save_path, "tokenizer")) 684 | text_encoder.save_pretrained(os.path.join(save_path, "text_encoder")) 685 | vae.save_pretrained(os.path.join(save_path, "vae")) 686 | unet.save_pretrained(os.path.join(save_path, "unet")) 687 | -------------------------------------------------------------------------------- /train_ja_model.py: -------------------------------------------------------------------------------- 1 | import pandas as pd 2 | from collections import namedtuple 3 | 4 | import os 5 | os.environ['KMP_DUPLICATE_LIB_OK']='True' 6 | 7 | import argparse 8 | import logging 9 | import math 10 | import os 11 | import random 12 | from pathlib import Path 13 | from typing import Iterable, Optional 14 | 15 | import numpy as np 16 | import torch 17 | import torch.nn.functional as F 18 | import torch.utils.checkpoint 19 | 20 | from accelerate import Accelerator 21 | from accelerate.logging import get_logger 22 | from accelerate.utils import set_seed 23 | from datasets import load_dataset 24 | from diffusers import AutoencoderKL, DDPMScheduler, PNDMScheduler, StableDiffusionPipeline, UNet2DConditionModel 25 | from diffusers.optimization import get_scheduler 26 | from diffusers.pipelines.stable_diffusion import StableDiffusionSafetyChecker 27 | from huggingface_hub import HfFolder, Repository, whoami 28 | from torchvision import transforms 29 | from tqdm.auto import tqdm 30 | from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer 31 | 32 | 33 | logger = get_logger(__name__) 34 | 35 | 36 | def parse_args(): 37 | parser = argparse.ArgumentParser(description="Simple example of a training script.") 38 | parser.add_argument( 39 | "--pretrained_model_name_or_path", 40 | type=str, 41 | default=None, 42 | required=True, 43 | help="Path to pretrained model or model identifier from huggingface.co/models.", 44 | ) 45 | parser.add_argument( 46 | "--dataset_name", 47 | type=str, 48 | default=None, 49 | help=( 50 | "The name of the Dataset (from the HuggingFace hub) to train on (could be your own, possibly private," 51 | " dataset). 
It can also be a path pointing to a local copy of a dataset in your filesystem," 52 | " or to a folder containing files that 🤗 Datasets can understand." 53 | ), 54 | ) 55 | parser.add_argument( 56 | "--dataset_config_name", 57 | type=str, 58 | default=None, 59 | help="The config of the Dataset, leave as None if there's only one config.", 60 | ) 61 | parser.add_argument( 62 | "--train_data_dir", 63 | type=str, 64 | default=None, 65 | help=( 66 | "A folder containing the training data. Folder contents must follow the structure described in" 67 | " https://huggingface.co/docs/datasets/image_dataset#imagefolder. In particular, a `metadata.jsonl` file" 68 | " must exist to provide the captions for the images. Ignored if `dataset_name` is specified." 69 | ), 70 | ) 71 | parser.add_argument( 72 | "--image_column", type=str, default="image", help="The column of the dataset containing an image." 73 | ) 74 | parser.add_argument( 75 | "--caption_column", 76 | type=str, 77 | default="text", 78 | help="The column of the dataset containing a caption or a list of captions.", 79 | ) 80 | parser.add_argument( 81 | "--max_train_samples", 82 | type=int, 83 | default=None, 84 | help=( 85 | "For debugging purposes or quicker training, truncate the number of training examples to this " 86 | "value if set." 87 | ), 88 | ) 89 | parser.add_argument( 90 | "--output_dir", 91 | type=str, 92 | default="sd-model-finetuned", 93 | help="The output directory where the model predictions and checkpoints will be written.", 94 | ) 95 | parser.add_argument( 96 | "--cache_dir", 97 | type=str, 98 | default=None, 99 | help="The directory where the downloaded models and datasets will be stored.", 100 | ) 101 | parser.add_argument("--seed", type=int, default=None, help="A seed for reproducible training.") 102 | parser.add_argument( 103 | "--resolution", 104 | type=int, 105 | default=512, 106 | help=( 107 | "The resolution for input images, all the images in the train/validation dataset will be resized to this" 108 | " resolution" 109 | ), 110 | ) 111 | parser.add_argument( 112 | "--center_crop", 113 | action="store_true", 114 | help="Whether to center crop images before resizing to resolution (if not set, random crop will be used)", 115 | ) 116 | parser.add_argument( 117 | "--random_flip", 118 | action="store_true", 119 | help="whether to randomly flip images horizontally", 120 | ) 121 | parser.add_argument( 122 | "--train_batch_size", type=int, default=16, help="Batch size (per device) for the training dataloader." 123 | ) 124 | parser.add_argument("--num_train_epochs", type=int, default=100) 125 | parser.add_argument( 126 | "--max_train_steps", 127 | type=int, 128 | default=None, 129 | help="Total number of training steps to perform. 
If provided, overrides num_train_epochs.", 130 | ) 131 | parser.add_argument( 132 | "--gradient_accumulation_steps", 133 | type=int, 134 | default=1, 135 | help="Number of updates steps to accumulate before performing a backward/update pass.", 136 | ) 137 | parser.add_argument( 138 | "--gradient_checkpointing", 139 | action="store_true", 140 | help="Whether or not to use gradient checkpointing to save memory at the expense of slower backward pass.", 141 | ) 142 | parser.add_argument( 143 | "--learning_rate", 144 | type=float, 145 | default=1e-4, 146 | help="Initial learning rate (after the potential warmup period) to use.", 147 | ) 148 | parser.add_argument( 149 | "--scale_lr", 150 | action="store_true", 151 | default=False, 152 | help="Scale the learning rate by the number of GPUs, gradient accumulation steps, and batch size.", 153 | ) 154 | parser.add_argument( 155 | "--lr_scheduler", 156 | type=str, 157 | default="constant", 158 | help=( 159 | 'The scheduler type to use. Choose between ["linear", "cosine", "cosine_with_restarts", "polynomial",' 160 | ' "constant", "constant_with_warmup"]' 161 | ), 162 | ) 163 | parser.add_argument( 164 | "--lr_warmup_steps", type=int, default=500, help="Number of steps for the warmup in the lr scheduler." 165 | ) 166 | parser.add_argument( 167 | "--use_8bit_adam", action="store_true", help="Whether or not to use 8-bit Adam from bitsandbytes." 168 | ) 169 | parser.add_argument("--use_ema", action="store_true", help="Whether to use EMA model.") 170 | parser.add_argument("--adam_beta1", type=float, default=0.9, help="The beta1 parameter for the Adam optimizer.") 171 | parser.add_argument("--adam_beta2", type=float, default=0.999, help="The beta2 parameter for the Adam optimizer.") 172 | parser.add_argument("--adam_weight_decay", type=float, default=1e-2, help="Weight decay to use.") 173 | parser.add_argument("--adam_epsilon", type=float, default=1e-08, help="Epsilon value for the Adam optimizer") 174 | parser.add_argument("--max_grad_norm", default=1.0, type=float, help="Max gradient norm.") 175 | parser.add_argument("--push_to_hub", action="store_true", help="Whether or not to push the model to the Hub.") 176 | parser.add_argument("--hub_token", type=str, default=None, help="The token to use to push to the Model Hub.") 177 | parser.add_argument( 178 | "--hub_model_id", 179 | type=str, 180 | default=None, 181 | help="The name of the repository to keep in sync with the local `output_dir`.", 182 | ) 183 | parser.add_argument( 184 | "--logging_dir", 185 | type=str, 186 | default="logs", 187 | help=( 188 | "[TensorBoard](https://www.tensorflow.org/tensorboard) log directory. Will default to" 189 | " *output_dir/runs/**CURRENT_DATETIME_HOSTNAME***." 190 | ), 191 | ) 192 | parser.add_argument( 193 | "--mixed_precision", 194 | type=str, 195 | default="no", 196 | choices=["no", "fp16", "bf16"], 197 | help=( 198 | "Whether to use mixed precision. Choose" 199 | "between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >= 1.10." 200 | "and an Nvidia Ampere GPU." 201 | ), 202 | ) 203 | parser.add_argument( 204 | "--report_to", 205 | type=str, 206 | default="tensorboard", 207 | help=( 208 | 'The integration to report the results and logs to. Supported platforms are `"tensorboard"`,' 209 | ' `"wandb"` and `"comet_ml"`. Use `"all"` (default) to report to all integrations.' 210 | "Only applicable when `--with_tracking` is passed." 
211 | ), 212 | ) 213 | parser.add_argument("--local_rank", type=int, default=-1, help="For distributed training: local_rank") 214 | 215 | #args = parser.parse_args() 216 | return parser 217 | env_local_rank = int(os.environ.get("LOCAL_RANK", -1)) 218 | if env_local_rank != -1 and env_local_rank != args.local_rank: 219 | args.local_rank = env_local_rank 220 | 221 | # Sanity checks 222 | if args.dataset_name is None and args.train_data_dir is None: 223 | raise ValueError("Need either a dataset name or a training folder.") 224 | 225 | return args 226 | 227 | 228 | def parse_parser_add_arg(parser, as_named_tuple = False): 229 | args_df = pd.DataFrame( 230 | pd.Series(parser.__dict__["_actions"]).map( 231 | lambda x:x.__dict__ 232 | ).values.tolist()) 233 | args_df = args_df.explode("option_strings") 234 | args_df["option_strings"] = args_df["option_strings"].map( 235 | lambda x: x[2:] if x.startswith("--") else x 236 | ).map( 237 | lambda x: x[1:] if x.startswith("-") else x 238 | ) 239 | args_df = args_df[["option_strings", "default"]] 240 | args = dict(args_df.values.tolist()) 241 | if as_named_tuple: 242 | args_parser_namedtuple = namedtuple("args_config", args) 243 | return args_parser_namedtuple(**args) 244 | return args_df 245 | 246 | def transform_named_tuple_to_dict(N_tuple): 247 | return dict(map( 248 | lambda x: (x, getattr(N_tuple, x)) 249 | ,filter(lambda x: not x.startswith("_") ,dir(N_tuple)) 250 | )) 251 | 252 | def transform_dict_to_named_tuple(dict_, name = "NamedTuple"): 253 | args_parser_namedtuple = namedtuple(name, dict_) 254 | return args_parser_namedtuple(**dict_) 255 | 256 | def setattr_gen_option_cls(src_obj): 257 | assert isinstance(src_obj, tuple) or isinstance(src_obj, dict) 258 | if isinstance(src_obj, tuple): 259 | src_obj_ = transform_named_tuple_to_dict(src_obj) 260 | else: 261 | src_obj_ = src_obj 262 | assert isinstance(src_obj_, dict) 263 | class Option(object): 264 | pass 265 | option = Option() 266 | for k, v in src_obj_.items(): 267 | setattr(option, k, v) 268 | return option 269 | 270 | args = parse_args() 271 | args = parse_parser_add_arg(args, as_named_tuple = True) 272 | ''' 273 | export MODEL_NAME="stable-diffusion-v1-4/" 274 | export dataset_name="lambdalabs/pokemon-blip-captions" 275 | 276 | accelerate launch train_text_to_image_ori.py \ 277 | --pretrained_model_name_or_path=$MODEL_NAME \ 278 | --dataset_name=$dataset_name \ 279 | --use_ema \ 280 | --resolution=32 --center_crop --random_flip \ 281 | --train_batch_size=1 \ 282 | --gradient_accumulation_steps=4 \ 283 | --gradient_checkpointing \ 284 | --max_train_steps=15000 \ 285 | --learning_rate=1e-05 \ 286 | --max_grad_norm=1 \ 287 | --lr_scheduler="constant" --lr_warmup_steps=0 \ 288 | --output_dir="sd-pokemon-model" 289 | ''' 290 | args_dict = transform_named_tuple_to_dict(args) 291 | #args_dict["pretrained_model_name_or_path"] = "stable-diffusion-v1-4/" 292 | args_dict["pretrained_model_name_or_path"] = "japanese-stable-diffusion/" 293 | args_dict["dataset_name"] = "svjack/pokemon-blip-captions-en-ja" 294 | args_dict["use_ema"] = True 295 | ###args_dict["use_ema"] = False 296 | args_dict["resolution"] = 256 297 | args_dict["center_crop"] = True 298 | args_dict["random_flip"] = True 299 | args_dict["train_batch_size"] = 1 300 | args_dict["gradient_accumulation_steps"] = 4 301 | args_dict["train_batch_size"] = 4 302 | args_dict["gradient_checkpointing"] = True 303 | #### to 15000 304 | args_dict["max_train_steps"] = 50000 305 | args_dict["learning_rate"] = 1e-05 306 | args_dict["max_grad_norm"] = 
1 307 | args_dict["lr_scheduler"] = "constant" 308 | args_dict["lr_warmup_steps"] = 0 309 | args_dict["output_dir"] = "sd-pokemon-model" 310 | args_dict["caption_column"] = "ja_text" 311 | args_dict["mixed_precision"] = "no" 312 | args = transform_dict_to_named_tuple(args_dict) 313 | 314 | def get_full_repo_name(model_id: str, organization: Optional[str] = None, token: Optional[str] = None): 315 | if token is None: 316 | token = HfFolder.get_token() 317 | if organization is None: 318 | username = whoami(token)["name"] 319 | return f"{username}/{model_id}" 320 | else: 321 | return f"{organization}/{model_id}" 322 | 323 | dataset_name_mapping = { 324 | "svjack/pokemon-blip-captions-en-ja": ("image", "ja_text"), 325 | } 326 | 327 | 328 | class EMAModel: 329 | """ 330 | Exponential Moving Average of models weights 331 | """ 332 | 333 | def __init__(self, parameters: Iterable[torch.nn.Parameter], decay=0.9999): 334 | parameters = list(parameters) 335 | self.shadow_params = [p.clone().detach() for p in parameters] 336 | 337 | self.decay = decay 338 | self.optimization_step = 0 339 | 340 | def get_decay(self, optimization_step): 341 | """ 342 | Compute the decay factor for the exponential moving average. 343 | """ 344 | value = (1 + optimization_step) / (10 + optimization_step) 345 | return 1 - min(self.decay, value) 346 | 347 | @torch.no_grad() 348 | def step(self, parameters): 349 | parameters = list(parameters) 350 | 351 | self.optimization_step += 1 352 | self.decay = self.get_decay(self.optimization_step) 353 | 354 | for s_param, param in zip(self.shadow_params, parameters): 355 | if param.requires_grad: 356 | tmp = self.decay * (s_param - param) 357 | s_param.sub_(tmp) 358 | else: 359 | s_param.copy_(param) 360 | 361 | torch.cuda.empty_cache() 362 | 363 | def copy_to(self, parameters: Iterable[torch.nn.Parameter]) -> None: 364 | """ 365 | Copy current averaged parameters into given collection of parameters. 366 | 367 | Args: 368 | parameters: Iterable of `torch.nn.Parameter`; the parameters to be 369 | updated with the stored moving averages. If `None`, the 370 | parameters with which this `ExponentialMovingAverage` was 371 | initialized will be used. 372 | """ 373 | parameters = list(parameters) 374 | for s_param, param in zip(self.shadow_params, parameters): 375 | param.data.copy_(s_param.data) 376 | 377 | def to(self, device=None, dtype=None) -> None: 378 | r"""Move internal buffers of the ExponentialMovingAverage to `device`. 
379 | 380 | Args: 381 | device: like `device` argument to `torch.Tensor.to` 382 | """ 383 | # .to() on the tensors handles None correctly 384 | self.shadow_params = [ 385 | p.to(device=device, dtype=dtype) if p.is_floating_point() else p.to(device=device) 386 | for p in self.shadow_params 387 | ] 388 | 389 | #### model prepare 390 | from japanese_stable_diffusion import JapaneseStableDiffusionPipeline 391 | import torch 392 | import pandas as pd 393 | 394 | from torch import autocast 395 | from diffusers import LMSDiscreteScheduler 396 | scheduler = LMSDiscreteScheduler(beta_start=0.00085, beta_end=0.012, 397 | beta_schedule="scaled_linear", num_train_timesteps=1000) 398 | 399 | # pretrained_model_name_or_path = "jap_to_zh_35000" 400 | pretrained_model_name_or_path = "japanese-stable-diffusion/" 401 | pipe = JapaneseStableDiffusionPipeline.from_pretrained(pretrained_model_name_or_path, 402 | scheduler=scheduler, use_auth_token=True) 403 | tokenizer, text_encoder, vae, unet = pipe.tokenizer, pipe.text_encoder, pipe.vae, pipe.unet 404 | 405 | vae.requires_grad_(False) 406 | text_encoder.requires_grad_(False) 407 | 408 | if args.gradient_checkpointing: 409 | unet.enable_gradient_checkpointing() 410 | 411 | if args.scale_lr: 412 | args.learning_rate = ( 413 | args.learning_rate * args.gradient_accumulation_steps * args.train_batch_size * accelerator.num_processes 414 | ) 415 | 416 | # Initialize the optimizer 417 | if args.use_8bit_adam: 418 | try: 419 | import bitsandbytes as bnb 420 | except ImportError: 421 | raise ImportError( 422 | "Please install bitsandbytes to use 8-bit Adam. You can do so by running `pip install bitsandbytes`" 423 | ) 424 | 425 | optimizer_cls = bnb.optim.AdamW8bit 426 | else: 427 | optimizer_cls = torch.optim.AdamW 428 | 429 | optimizer = optimizer_cls( 430 | unet.parameters(), 431 | lr=args.learning_rate, 432 | betas=(args.adam_beta1, args.adam_beta2), 433 | weight_decay=args.adam_weight_decay, 434 | eps=args.adam_epsilon, 435 | ) 436 | noise_scheduler = DDPMScheduler( 437 | beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", num_train_timesteps=1000, 438 | #tensor_format="pt" 439 | ) 440 | 441 | 442 | dataset = load_dataset( 443 | args.dataset_name, 444 | args.dataset_config_name, 445 | cache_dir=args.cache_dir, 446 | ) 447 | column_names = dataset["train"].column_names 448 | dataset_columns = dataset_name_mapping.get(args.dataset_name, None) 449 | if args.image_column is None: 450 | image_column = dataset_columns[0] if dataset_columns is not None else column_names[0] 451 | else: 452 | image_column = args.image_column 453 | if image_column not in column_names: 454 | raise ValueError( 455 | f"--image_column' value '{args.image_column}' needs to be one of: {', '.join(column_names)}" 456 | ) 457 | if args.caption_column is None: 458 | caption_column = dataset_columns[1] if dataset_columns is not None else column_names[1] 459 | else: 460 | caption_column = args.caption_column 461 | if caption_column not in column_names: 462 | raise ValueError( 463 | f"--caption_column' value '{args.caption_column}' needs to be one of: {', '.join(column_names)}" 464 | ) 465 | 466 | 467 | # Preprocessing the datasets. 468 | # We need to tokenize input captions and transform the images. 
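# Note: captions are tokenized below without padding (`padding="do_not_pad"`); dynamic per-batch
# padding and the attention mask are produced later by `tokenizer.pad` inside `collate_fn`.
# If the caption column holds a list of captions, a random one is used at train time and the
# first one otherwise.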
469 | def tokenize_captions(examples, is_train=True): 470 | captions = [] 471 | for caption in examples[caption_column]: 472 | if isinstance(caption, str): 473 | captions.append(caption) 474 | elif isinstance(caption, (list, np.ndarray)): 475 | # take a random caption if there are multiple 476 | captions.append(random.choice(caption) if is_train else caption[0]) 477 | else: 478 | raise ValueError( 479 | f"Caption column `{caption_column}` should contain either strings or lists of strings." 480 | ) 481 | inputs = tokenizer(captions, max_length=tokenizer.model_max_length, padding="do_not_pad", truncation=True) 482 | input_ids = inputs.input_ids 483 | return input_ids 484 | 485 | 486 | train_transforms = transforms.Compose( 487 | [ 488 | transforms.Resize((args.resolution, args.resolution), interpolation=transforms.InterpolationMode.BILINEAR), 489 | transforms.CenterCrop(args.resolution) if args.center_crop else transforms.RandomCrop(args.resolution), 490 | transforms.RandomHorizontalFlip() if args.random_flip else transforms.Lambda(lambda x: x), 491 | transforms.ToTensor(), 492 | transforms.Normalize([0.5], [0.5]), 493 | ] 494 | ) 495 | 496 | def preprocess_train(examples): 497 | images = [image.convert("RGB") for image in examples[image_column]] 498 | examples["pixel_values"] = [train_transforms(image) for image in images] 499 | examples["input_ids"] = tokenize_captions(examples) 500 | return examples 501 | 502 | if args.max_train_samples is not None: 503 | dataset["train"] = dataset["train"].shuffle(seed=args.seed).select(range(args.max_train_samples)) 504 | train_dataset = dataset["train"].with_transform(preprocess_train) 505 | 506 | def collate_fn(examples): 507 | pixel_values = torch.stack([example["pixel_values"] for example in examples]) 508 | pixel_values = pixel_values.to(memory_format=torch.contiguous_format).float() 509 | input_ids = [example["input_ids"] for example in examples] 510 | padded_tokens = tokenizer.pad({"input_ids": input_ids}, padding=True, return_tensors="pt") 511 | return { 512 | "pixel_values": pixel_values, 513 | "input_ids": padded_tokens.input_ids, 514 | "attention_mask": padded_tokens.attention_mask, 515 | } 516 | 517 | train_dataloader = torch.utils.data.DataLoader( 518 | train_dataset, shuffle=True, collate_fn=collate_fn, batch_size=args.train_batch_size 519 | ) 520 | 521 | 522 | overrode_max_train_steps = False 523 | num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps) 524 | if args.max_train_steps is None: 525 | args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch 526 | overrode_max_train_steps = True 527 | 528 | lr_scheduler = get_scheduler( 529 | args.lr_scheduler, 530 | optimizer=optimizer, 531 | num_warmup_steps=args.lr_warmup_steps * args.gradient_accumulation_steps, 532 | num_training_steps=args.max_train_steps * args.gradient_accumulation_steps, 533 | ) 534 | 535 | def train_loop(config, model, noise_scheduler, optimizer, train_dataloader, lr_scheduler): 536 | # Initialize accelerator and tensorboard logging 537 | args = config 538 | unet = model 539 | 540 | logging_dir = os.path.join(args.output_dir, args.logging_dir) 541 | accelerator = Accelerator( 542 | gradient_accumulation_steps=args.gradient_accumulation_steps, 543 | mixed_precision=args.mixed_precision, 544 | log_with=args.report_to, 545 | logging_dir=logging_dir, 546 | ) 547 | logging.basicConfig( 548 | format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", 549 | datefmt="%m/%d/%Y %H:%M:%S", 550 | level=logging.INFO, 
551 | ) 552 | if args.seed is not None: 553 | set_seed(args.seed) 554 | if accelerator.is_main_process: 555 | os.makedirs(args.output_dir, exist_ok=True) 556 | 557 | weight_dtype = torch.float32 558 | if args.mixed_precision == "fp16": 559 | weight_dtype = torch.float16 560 | elif args.mixed_precision == "bf16": 561 | weight_dtype = torch.bfloat16 562 | 563 | text_encoder.to(accelerator.device, dtype=weight_dtype) 564 | vae.to(accelerator.device, dtype=weight_dtype) 565 | 566 | if args.use_ema: 567 | ema_unet = EMAModel(unet.parameters()) 568 | ema_unet.to(accelerator.device, dtype=weight_dtype) 569 | 570 | num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps) 571 | if overrode_max_train_steps: 572 | args_dict = transform_named_tuple_to_dict(args) 573 | args_dict["max_train_steps"] = args.num_train_epochs * num_update_steps_per_epoch 574 | args = transform_dict_to_named_tuple(args_dict) 575 | 576 | 577 | args_dict = transform_named_tuple_to_dict(args) 578 | args_dict["num_train_epochs"] = math.ceil(args.max_train_steps / num_update_steps_per_epoch) 579 | args = transform_dict_to_named_tuple(args_dict) 580 | 581 | if accelerator.is_main_process: 582 | if config.push_to_hub: 583 | #repo = init_git_repo(config, at_init=True) 584 | pass 585 | accelerator.init_trackers("train_example") 586 | 587 | total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps 588 | 589 | logger.info("***** Running training *****") 590 | logger.info(f" Num examples = {len(train_dataset)}") 591 | logger.info(f" Num Epochs = {args.num_train_epochs}") 592 | logger.info(f" Instantaneous batch size per device = {args.train_batch_size}") 593 | logger.info(f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}") 594 | logger.info(f" Gradient Accumulation steps = {args.gradient_accumulation_steps}") 595 | logger.info(f" Total optimization steps = {args.max_train_steps}") 596 | 597 | # Prepare everything 598 | # There is no specific order to remember, you just need to unpack the 599 | # objects in the same order you gave them to the prepare method. 
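    # Only the UNet (together with the optimizer, dataloader and lr scheduler) goes through
    # `accelerator.prepare`; the VAE and text encoder were frozen with `requires_grad_(False)`
    # and moved to the device in `weight_dtype` above, so they act purely as fixed encoders.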
600 | ''' 601 | model, optimizer, train_dataloader, lr_scheduler = accelerator.prepare( 602 | model, optimizer, train_dataloader, lr_scheduler 603 | ) 604 | ''' 605 | unet, optimizer, train_dataloader, lr_scheduler = accelerator.prepare( 606 | unet, optimizer, train_dataloader, lr_scheduler 607 | ) 608 | progress_bar = tqdm(range(args.max_train_steps), disable=not accelerator.is_local_main_process) 609 | progress_bar.set_description("Steps") 610 | global_step = 0 611 | 612 | for epoch in range(args.num_train_epochs): 613 | #unet.train() 614 | train_loss = 0.0 615 | for step, batch in enumerate(train_dataloader): 616 | with accelerator.accumulate(unet): 617 | # Convert images to latent space 618 | latents = vae.encode(batch["pixel_values"].to(weight_dtype)).latent_dist.sample() 619 | latents = latents * 0.18215 620 | 621 | # Sample noise that we'll add to the latents 622 | noise = torch.randn_like(latents) 623 | bsz = latents.shape[0] 624 | # Sample a random timestep for each image 625 | timesteps = torch.randint(0, noise_scheduler.num_train_timesteps, (bsz,), device=latents.device) 626 | timesteps = timesteps.long() 627 | 628 | # Add noise to the latents according to the noise magnitude at each timestep 629 | # (this is the forward diffusion process) 630 | noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps) 631 | 632 | # Get the text embedding for conditioning 633 | encoder_hidden_states = text_encoder(batch["input_ids"])[0] 634 | 635 | # Predict the noise residual and compute loss 636 | noise_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample 637 | 638 | loss = F.mse_loss(noise_pred.float(), noise.float(), reduction="mean") 639 | #loss = F.mse_loss(noise.float(), noise.float(), reduction="mean") 640 | 641 | # Gather the losses across all processes for logging (if we use distributed training). 642 | avg_loss = accelerator.gather(loss.repeat(args.train_batch_size)).mean() 643 | train_loss += avg_loss.item() / args.gradient_accumulation_steps 644 | 645 | # Backpropagate 646 | accelerator.backward(loss) 647 | if accelerator.sync_gradients: 648 | accelerator.clip_grad_norm_(unet.parameters(), args.max_grad_norm) 649 | optimizer.step() 650 | lr_scheduler.step() 651 | optimizer.zero_grad() 652 | 653 | # Checks if the accelerator has performed an optimization step behind the scenes 654 | if accelerator.sync_gradients: 655 | if args.use_ema: 656 | ema_unet.step(unet.parameters()) 657 | progress_bar.update(1) 658 | global_step += 1 659 | accelerator.log({"train_loss": train_loss}, step=global_step) 660 | train_loss = 0.0 661 | 662 | logs = {"step_loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0]} 663 | progress_bar.set_postfix(**logs) 664 | 665 | if global_step >= args.max_train_steps: 666 | break 667 | 668 | from accelerate import notebook_launcher 669 | 670 | #### train it. 
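# With num_processes=1, `notebook_launcher` effectively just calls `train_loop(*args_)` in the
# current process, so the script can run as-is without `accelerate launch`. Note that
# `train_loop` keeps its EMA copy (`ema_unet`) local and never calls
# `ema_unet.copy_to(unet.parameters())`, so the pipeline saved below holds the live UNet
# weights rather than the exponential moving average.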
671 | args_ = (args, unet, noise_scheduler, optimizer, train_dataloader, lr_scheduler) 672 | notebook_launcher(train_loop, args_, num_processes=1) 673 | 674 | #### save to local 675 | pipeline = StableDiffusionPipeline( 676 | text_encoder=text_encoder, 677 | vae=vae, 678 | unet=unet, 679 | tokenizer=tokenizer, 680 | scheduler=scheduler, 681 | safety_checker=StableDiffusionSafetyChecker.from_pretrained("CompVis/stable-diffusion-safety-checker"), 682 | feature_extractor=CLIPFeatureExtractor.from_pretrained("openai/clip-vit-base-patch32"), 683 | ) 684 | pipeline.save_pretrained("jap_model_50000") 685 | pipeline.safety_checker = lambda images, clip_input: (images, False) 686 | pipeline = pipeline.to("cuda") 687 | 688 | imgs = pipeline("鉢植えの植物を頭に載せた漫画のキャラクター", 689 | num_inference_steps = 100 690 | ) 691 | imgs.images[0] 692 | 693 | imgs = pipeline("漫画の鳥", 694 | num_inference_steps = 100 695 | ) 696 | imgs.images[0] 697 | 698 | imgs = pipeline("黄色いボール", 699 | num_inference_steps = 100 700 | ) 701 | imgs.images[0] 702 | 703 | imgs = pipeline("ブルードラゴンのイラスト", 704 | num_inference_steps = 100 705 | ) 706 | imgs.images[0] 707 | -------------------------------------------------------------------------------- /train_en_model.py: -------------------------------------------------------------------------------- 1 | import pandas as pd 2 | from collections import namedtuple 3 | 4 | import os 5 | os.environ['KMP_DUPLICATE_LIB_OK']='True' 6 | 7 | import argparse 8 | import logging 9 | import math 10 | import os 11 | import random 12 | from pathlib import Path 13 | from typing import Iterable, Optional 14 | 15 | import numpy as np 16 | import torch 17 | import torch.nn.functional as F 18 | import torch.utils.checkpoint 19 | 20 | from accelerate import Accelerator 21 | from accelerate.logging import get_logger 22 | from accelerate.utils import set_seed 23 | from datasets import load_dataset 24 | from diffusers import AutoencoderKL, DDPMScheduler, PNDMScheduler, StableDiffusionPipeline, UNet2DConditionModel 25 | from diffusers.optimization import get_scheduler 26 | from diffusers.pipelines.stable_diffusion import StableDiffusionSafetyChecker 27 | from huggingface_hub import HfFolder, Repository, whoami 28 | from torchvision import transforms 29 | from tqdm.auto import tqdm 30 | from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer 31 | 32 | 33 | logger = get_logger(__name__) 34 | 35 | 36 | def parse_args(): 37 | parser = argparse.ArgumentParser(description="Simple example of a training script.") 38 | parser.add_argument( 39 | "--pretrained_model_name_or_path", 40 | type=str, 41 | default=None, 42 | required=True, 43 | help="Path to pretrained model or model identifier from huggingface.co/models.", 44 | ) 45 | parser.add_argument( 46 | "--dataset_name", 47 | type=str, 48 | default=None, 49 | help=( 50 | "The name of the Dataset (from the HuggingFace hub) to train on (could be your own, possibly private," 51 | " dataset). It can also be a path pointing to a local copy of a dataset in your filesystem," 52 | " or to a folder containing files that 🤗 Datasets can understand." 53 | ), 54 | ) 55 | parser.add_argument( 56 | "--dataset_config_name", 57 | type=str, 58 | default=None, 59 | help="The config of the Dataset, leave as None if there's only one config.", 60 | ) 61 | parser.add_argument( 62 | "--train_data_dir", 63 | type=str, 64 | default=None, 65 | help=( 66 | "A folder containing the training data. 
Folder contents must follow the structure described in" 67 | " https://huggingface.co/docs/datasets/image_dataset#imagefolder. In particular, a `metadata.jsonl` file" 68 | " must exist to provide the captions for the images. Ignored if `dataset_name` is specified." 69 | ), 70 | ) 71 | parser.add_argument( 72 | "--image_column", type=str, default="image", help="The column of the dataset containing an image." 73 | ) 74 | parser.add_argument( 75 | "--caption_column", 76 | type=str, 77 | default="text", 78 | help="The column of the dataset containing a caption or a list of captions.", 79 | ) 80 | parser.add_argument( 81 | "--max_train_samples", 82 | type=int, 83 | default=None, 84 | help=( 85 | "For debugging purposes or quicker training, truncate the number of training examples to this " 86 | "value if set." 87 | ), 88 | ) 89 | parser.add_argument( 90 | "--output_dir", 91 | type=str, 92 | default="sd-model-finetuned", 93 | help="The output directory where the model predictions and checkpoints will be written.", 94 | ) 95 | parser.add_argument( 96 | "--cache_dir", 97 | type=str, 98 | default=None, 99 | help="The directory where the downloaded models and datasets will be stored.", 100 | ) 101 | parser.add_argument("--seed", type=int, default=None, help="A seed for reproducible training.") 102 | parser.add_argument( 103 | "--resolution", 104 | type=int, 105 | default=512, 106 | help=( 107 | "The resolution for input images, all the images in the train/validation dataset will be resized to this" 108 | " resolution" 109 | ), 110 | ) 111 | parser.add_argument( 112 | "--center_crop", 113 | action="store_true", 114 | help="Whether to center crop images before resizing to resolution (if not set, random crop will be used)", 115 | ) 116 | parser.add_argument( 117 | "--random_flip", 118 | action="store_true", 119 | help="whether to randomly flip images horizontally", 120 | ) 121 | parser.add_argument( 122 | "--train_batch_size", type=int, default=16, help="Batch size (per device) for the training dataloader." 123 | ) 124 | parser.add_argument("--num_train_epochs", type=int, default=100) 125 | parser.add_argument( 126 | "--max_train_steps", 127 | type=int, 128 | default=None, 129 | help="Total number of training steps to perform. If provided, overrides num_train_epochs.", 130 | ) 131 | parser.add_argument( 132 | "--gradient_accumulation_steps", 133 | type=int, 134 | default=1, 135 | help="Number of updates steps to accumulate before performing a backward/update pass.", 136 | ) 137 | parser.add_argument( 138 | "--gradient_checkpointing", 139 | action="store_true", 140 | help="Whether or not to use gradient checkpointing to save memory at the expense of slower backward pass.", 141 | ) 142 | parser.add_argument( 143 | "--learning_rate", 144 | type=float, 145 | default=1e-4, 146 | help="Initial learning rate (after the potential warmup period) to use.", 147 | ) 148 | parser.add_argument( 149 | "--scale_lr", 150 | action="store_true", 151 | default=False, 152 | help="Scale the learning rate by the number of GPUs, gradient accumulation steps, and batch size.", 153 | ) 154 | parser.add_argument( 155 | "--lr_scheduler", 156 | type=str, 157 | default="constant", 158 | help=( 159 | 'The scheduler type to use. Choose between ["linear", "cosine", "cosine_with_restarts", "polynomial",' 160 | ' "constant", "constant_with_warmup"]' 161 | ), 162 | ) 163 | parser.add_argument( 164 | "--lr_warmup_steps", type=int, default=500, help="Number of steps for the warmup in the lr scheduler." 
165 | ) 166 | parser.add_argument( 167 | "--use_8bit_adam", action="store_true", help="Whether or not to use 8-bit Adam from bitsandbytes." 168 | ) 169 | parser.add_argument("--use_ema", action="store_true", help="Whether to use EMA model.") 170 | parser.add_argument("--adam_beta1", type=float, default=0.9, help="The beta1 parameter for the Adam optimizer.") 171 | parser.add_argument("--adam_beta2", type=float, default=0.999, help="The beta2 parameter for the Adam optimizer.") 172 | parser.add_argument("--adam_weight_decay", type=float, default=1e-2, help="Weight decay to use.") 173 | parser.add_argument("--adam_epsilon", type=float, default=1e-08, help="Epsilon value for the Adam optimizer") 174 | parser.add_argument("--max_grad_norm", default=1.0, type=float, help="Max gradient norm.") 175 | parser.add_argument("--push_to_hub", action="store_true", help="Whether or not to push the model to the Hub.") 176 | parser.add_argument("--hub_token", type=str, default=None, help="The token to use to push to the Model Hub.") 177 | parser.add_argument( 178 | "--hub_model_id", 179 | type=str, 180 | default=None, 181 | help="The name of the repository to keep in sync with the local `output_dir`.", 182 | ) 183 | parser.add_argument( 184 | "--logging_dir", 185 | type=str, 186 | default="logs", 187 | help=( 188 | "[TensorBoard](https://www.tensorflow.org/tensorboard) log directory. Will default to" 189 | " *output_dir/runs/**CURRENT_DATETIME_HOSTNAME***." 190 | ), 191 | ) 192 | parser.add_argument( 193 | "--mixed_precision", 194 | type=str, 195 | default="no", 196 | choices=["no", "fp16", "bf16"], 197 | help=( 198 | "Whether to use mixed precision. Choose" 199 | "between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >= 1.10." 200 | "and an Nvidia Ampere GPU." 201 | ), 202 | ) 203 | parser.add_argument( 204 | "--report_to", 205 | type=str, 206 | default="tensorboard", 207 | help=( 208 | 'The integration to report the results and logs to. Supported platforms are `"tensorboard"`,' 209 | ' `"wandb"` and `"comet_ml"`. Use `"all"` (default) to report to all integrations.' 210 | "Only applicable when `--with_tracking` is passed." 
211 | ), 212 | ) 213 | parser.add_argument("--local_rank", type=int, default=-1, help="For distributed training: local_rank") 214 | 215 | #args = parser.parse_args() 216 | return parser 217 | env_local_rank = int(os.environ.get("LOCAL_RANK", -1)) 218 | if env_local_rank != -1 and env_local_rank != args.local_rank: 219 | args.local_rank = env_local_rank 220 | 221 | # Sanity checks 222 | if args.dataset_name is None and args.train_data_dir is None: 223 | raise ValueError("Need either a dataset name or a training folder.") 224 | 225 | return args 226 | 227 | 228 | def parse_parser_add_arg(parser, as_named_tuple = False): 229 | args_df = pd.DataFrame( 230 | pd.Series(parser.__dict__["_actions"]).map( 231 | lambda x:x.__dict__ 232 | ).values.tolist()) 233 | args_df = args_df.explode("option_strings") 234 | args_df["option_strings"] = args_df["option_strings"].map( 235 | lambda x: x[2:] if x.startswith("--") else x 236 | ).map( 237 | lambda x: x[1:] if x.startswith("-") else x 238 | ) 239 | args_df = args_df[["option_strings", "default"]] 240 | args = dict(args_df.values.tolist()) 241 | if as_named_tuple: 242 | args_parser_namedtuple = namedtuple("args_config", args) 243 | return args_parser_namedtuple(**args) 244 | return args_df 245 | 246 | def transform_named_tuple_to_dict(N_tuple): 247 | return dict(map( 248 | lambda x: (x, getattr(N_tuple, x)) 249 | ,filter(lambda x: not x.startswith("_") ,dir(N_tuple)) 250 | )) 251 | 252 | def transform_dict_to_named_tuple(dict_, name = "NamedTuple"): 253 | args_parser_namedtuple = namedtuple(name, dict_) 254 | return args_parser_namedtuple(**dict_) 255 | 256 | def setattr_gen_option_cls(src_obj): 257 | assert isinstance(src_obj, tuple) or isinstance(src_obj, dict) 258 | if isinstance(src_obj, tuple): 259 | src_obj_ = transform_named_tuple_to_dict(src_obj) 260 | else: 261 | src_obj_ = src_obj 262 | assert isinstance(src_obj_, dict) 263 | class Option(object): 264 | pass 265 | option = Option() 266 | for k, v in src_obj_.items(): 267 | setattr(option, k, v) 268 | return option 269 | 270 | args = parse_args() 271 | args = parse_parser_add_arg(args, as_named_tuple = True) 272 | ''' 273 | export MODEL_NAME="stable-diffusion-v1-4/" 274 | export dataset_name="lambdalabs/pokemon-blip-captions" 275 | 276 | accelerate launch train_text_to_image_ori.py \ 277 | --pretrained_model_name_or_path=$MODEL_NAME \ 278 | --dataset_name=$dataset_name \ 279 | --use_ema \ 280 | --resolution=32 --center_crop --random_flip \ 281 | --train_batch_size=1 \ 282 | --gradient_accumulation_steps=4 \ 283 | --gradient_checkpointing \ 284 | --max_train_steps=15000 \ 285 | --learning_rate=1e-05 \ 286 | --max_grad_norm=1 \ 287 | --lr_scheduler="constant" --lr_warmup_steps=0 \ 288 | --output_dir="sd-pokemon-model" 289 | ''' 290 | args_dict = transform_named_tuple_to_dict(args) 291 | args_dict["pretrained_model_name_or_path"] = "stable-diffusion-v1-4/" 292 | #args_dict["pretrained_model_name_or_path"] = "japanese-stable-diffusion/" 293 | args_dict["dataset_name"] = "svjack/pokemon-blip-captions-en-ja" 294 | args_dict["use_ema"] = True 295 | ###args_dict["use_ema"] = False 296 | args_dict["resolution"] = 256 297 | args_dict["center_crop"] = True 298 | args_dict["random_flip"] = True 299 | args_dict["train_batch_size"] = 1 300 | args_dict["gradient_accumulation_steps"] = 4 301 | args_dict["train_batch_size"] = 4 302 | args_dict["gradient_checkpointing"] = True 303 | #### to 15000 304 | args_dict["max_train_steps"] = 50000 305 | args_dict["learning_rate"] = 1e-05 306 | args_dict["max_grad_norm"] = 
1 307 | args_dict["lr_scheduler"] = "constant" 308 | args_dict["lr_warmup_steps"] = 0 309 | args_dict["output_dir"] = "sd-pokemon-model" 310 | args_dict["caption_column"] = "en_text" 311 | args_dict["mixed_precision"] = "no" 312 | args = transform_dict_to_named_tuple(args_dict) 313 | 314 | def get_full_repo_name(model_id: str, organization: Optional[str] = None, token: Optional[str] = None): 315 | if token is None: 316 | token = HfFolder.get_token() 317 | if organization is None: 318 | username = whoami(token)["name"] 319 | return f"{username}/{model_id}" 320 | else: 321 | return f"{organization}/{model_id}" 322 | 323 | dataset_name_mapping = { 324 | "svjack/pokemon-blip-captions-en-ja": ("image", "en_text"), 325 | } 326 | 327 | 328 | class EMAModel: 329 | """ 330 | Exponential Moving Average of models weights 331 | """ 332 | 333 | def __init__(self, parameters: Iterable[torch.nn.Parameter], decay=0.9999): 334 | parameters = list(parameters) 335 | self.shadow_params = [p.clone().detach() for p in parameters] 336 | 337 | self.decay = decay 338 | self.optimization_step = 0 339 | 340 | def get_decay(self, optimization_step): 341 | """ 342 | Compute the decay factor for the exponential moving average. 343 | """ 344 | value = (1 + optimization_step) / (10 + optimization_step) 345 | return 1 - min(self.decay, value) 346 | 347 | @torch.no_grad() 348 | def step(self, parameters): 349 | parameters = list(parameters) 350 | 351 | self.optimization_step += 1 352 | self.decay = self.get_decay(self.optimization_step) 353 | 354 | for s_param, param in zip(self.shadow_params, parameters): 355 | if param.requires_grad: 356 | tmp = self.decay * (s_param - param) 357 | s_param.sub_(tmp) 358 | else: 359 | s_param.copy_(param) 360 | 361 | torch.cuda.empty_cache() 362 | 363 | def copy_to(self, parameters: Iterable[torch.nn.Parameter]) -> None: 364 | """ 365 | Copy current averaged parameters into given collection of parameters. 366 | 367 | Args: 368 | parameters: Iterable of `torch.nn.Parameter`; the parameters to be 369 | updated with the stored moving averages. If `None`, the 370 | parameters with which this `ExponentialMovingAverage` was 371 | initialized will be used. 372 | """ 373 | parameters = list(parameters) 374 | for s_param, param in zip(self.shadow_params, parameters): 375 | param.data.copy_(s_param.data) 376 | 377 | def to(self, device=None, dtype=None) -> None: 378 | r"""Move internal buffers of the ExponentialMovingAverage to `device`. 
379 | 380 | Args: 381 | device: like `device` argument to `torch.Tensor.to` 382 | """ 383 | # .to() on the tensors handles None correctly 384 | self.shadow_params = [ 385 | p.to(device=device, dtype=dtype) if p.is_floating_point() else p.to(device=device) 386 | for p in self.shadow_params 387 | ] 388 | 389 | #### model prepare 390 | ''' 391 | from japanese_stable_diffusion import JapaneseStableDiffusionPipeline 392 | import torch 393 | import pandas as pd 394 | 395 | from torch import autocast 396 | from diffusers import LMSDiscreteScheduler 397 | scheduler = LMSDiscreteScheduler(beta_start=0.00085, beta_end=0.012, 398 | beta_schedule="scaled_linear", num_train_timesteps=1000) 399 | 400 | # pretrained_model_name_or_path = "jap_to_zh_35000" 401 | pretrained_model_name_or_path = "japanese-stable-diffusion/" 402 | pipe = JapaneseStableDiffusionPipeline.from_pretrained(pretrained_model_name_or_path, 403 | scheduler=scheduler, use_auth_token=True) 404 | tokenizer, text_encoder, vae, unet = pipe.tokenizer, pipe.text_encoder, pipe.vae, pipe.unet 405 | ''' 406 | tokenizer = CLIPTokenizer.from_pretrained(args.pretrained_model_name_or_path, subfolder="tokenizer") 407 | text_encoder = CLIPTextModel.from_pretrained(args.pretrained_model_name_or_path, subfolder="text_encoder") 408 | vae = AutoencoderKL.from_pretrained(args.pretrained_model_name_or_path, subfolder="vae") 409 | unet = UNet2DConditionModel.from_pretrained(args.pretrained_model_name_or_path, subfolder="unet") 410 | 411 | vae.requires_grad_(False) 412 | text_encoder.requires_grad_(False) 413 | 414 | if args.gradient_checkpointing: 415 | unet.enable_gradient_checkpointing() 416 | 417 | if args.scale_lr: 418 | args.learning_rate = ( 419 | args.learning_rate * args.gradient_accumulation_steps * args.train_batch_size * accelerator.num_processes 420 | ) 421 | 422 | # Initialize the optimizer 423 | if args.use_8bit_adam: 424 | try: 425 | import bitsandbytes as bnb 426 | except ImportError: 427 | raise ImportError( 428 | "Please install bitsandbytes to use 8-bit Adam. 
You can do so by running `pip install bitsandbytes`" 429 | ) 430 | 431 | optimizer_cls = bnb.optim.AdamW8bit 432 | else: 433 | optimizer_cls = torch.optim.AdamW 434 | 435 | optimizer = optimizer_cls( 436 | unet.parameters(), 437 | lr=args.learning_rate, 438 | betas=(args.adam_beta1, args.adam_beta2), 439 | weight_decay=args.adam_weight_decay, 440 | eps=args.adam_epsilon, 441 | ) 442 | noise_scheduler = DDPMScheduler( 443 | beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", num_train_timesteps=1000, 444 | #tensor_format="pt" 445 | ) 446 | 447 | 448 | dataset = load_dataset( 449 | args.dataset_name, 450 | args.dataset_config_name, 451 | cache_dir=args.cache_dir, 452 | ) 453 | column_names = dataset["train"].column_names 454 | dataset_columns = dataset_name_mapping.get(args.dataset_name, None) 455 | if args.image_column is None: 456 | image_column = dataset_columns[0] if dataset_columns is not None else column_names[0] 457 | else: 458 | image_column = args.image_column 459 | if image_column not in column_names: 460 | raise ValueError( 461 | f"--image_column' value '{args.image_column}' needs to be one of: {', '.join(column_names)}" 462 | ) 463 | if args.caption_column is None: 464 | caption_column = dataset_columns[1] if dataset_columns is not None else column_names[1] 465 | else: 466 | caption_column = args.caption_column 467 | if caption_column not in column_names: 468 | raise ValueError( 469 | f"--caption_column' value '{args.caption_column}' needs to be one of: {', '.join(column_names)}" 470 | ) 471 | 472 | 473 | # Preprocessing the datasets. 474 | # We need to tokenize input captions and transform the images. 475 | def tokenize_captions(examples, is_train=True): 476 | captions = [] 477 | for caption in examples[caption_column]: 478 | if isinstance(caption, str): 479 | captions.append(caption) 480 | elif isinstance(caption, (list, np.ndarray)): 481 | # take a random caption if there are multiple 482 | captions.append(random.choice(caption) if is_train else caption[0]) 483 | else: 484 | raise ValueError( 485 | f"Caption column `{caption_column}` should contain either strings or lists of strings." 
486 | ) 487 | inputs = tokenizer(captions, max_length=tokenizer.model_max_length, padding="do_not_pad", truncation=True) 488 | input_ids = inputs.input_ids 489 | return input_ids 490 | 491 | 492 | train_transforms = transforms.Compose( 493 | [ 494 | transforms.Resize((args.resolution, args.resolution), interpolation=transforms.InterpolationMode.BILINEAR), 495 | transforms.CenterCrop(args.resolution) if args.center_crop else transforms.RandomCrop(args.resolution), 496 | transforms.RandomHorizontalFlip() if args.random_flip else transforms.Lambda(lambda x: x), 497 | transforms.ToTensor(), 498 | transforms.Normalize([0.5], [0.5]), 499 | ] 500 | ) 501 | 502 | def preprocess_train(examples): 503 | images = [image.convert("RGB") for image in examples[image_column]] 504 | examples["pixel_values"] = [train_transforms(image) for image in images] 505 | examples["input_ids"] = tokenize_captions(examples) 506 | return examples 507 | 508 | if args.max_train_samples is not None: 509 | dataset["train"] = dataset["train"].shuffle(seed=args.seed).select(range(args.max_train_samples)) 510 | train_dataset = dataset["train"].with_transform(preprocess_train) 511 | 512 | def collate_fn(examples): 513 | pixel_values = torch.stack([example["pixel_values"] for example in examples]) 514 | pixel_values = pixel_values.to(memory_format=torch.contiguous_format).float() 515 | input_ids = [example["input_ids"] for example in examples] 516 | padded_tokens = tokenizer.pad({"input_ids": input_ids}, padding=True, return_tensors="pt") 517 | return { 518 | "pixel_values": pixel_values, 519 | "input_ids": padded_tokens.input_ids, 520 | "attention_mask": padded_tokens.attention_mask, 521 | } 522 | 523 | train_dataloader = torch.utils.data.DataLoader( 524 | train_dataset, shuffle=True, collate_fn=collate_fn, batch_size=args.train_batch_size 525 | ) 526 | 527 | 528 | overrode_max_train_steps = False 529 | num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps) 530 | if args.max_train_steps is None: 531 | args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch 532 | overrode_max_train_steps = True 533 | 534 | lr_scheduler = get_scheduler( 535 | args.lr_scheduler, 536 | optimizer=optimizer, 537 | num_warmup_steps=args.lr_warmup_steps * args.gradient_accumulation_steps, 538 | num_training_steps=args.max_train_steps * args.gradient_accumulation_steps, 539 | ) 540 | 541 | def train_loop(config, model, noise_scheduler, optimizer, train_dataloader, lr_scheduler): 542 | # Initialize accelerator and tensorboard logging 543 | args = config 544 | unet = model 545 | 546 | logging_dir = os.path.join(args.output_dir, args.logging_dir) 547 | accelerator = Accelerator( 548 | gradient_accumulation_steps=args.gradient_accumulation_steps, 549 | mixed_precision=args.mixed_precision, 550 | log_with=args.report_to, 551 | logging_dir=logging_dir, 552 | ) 553 | logging.basicConfig( 554 | format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", 555 | datefmt="%m/%d/%Y %H:%M:%S", 556 | level=logging.INFO, 557 | ) 558 | if args.seed is not None: 559 | set_seed(args.seed) 560 | if accelerator.is_main_process: 561 | os.makedirs(args.output_dir, exist_ok=True) 562 | 563 | weight_dtype = torch.float32 564 | if args.mixed_precision == "fp16": 565 | weight_dtype = torch.float16 566 | elif args.mixed_precision == "bf16": 567 | weight_dtype = torch.bfloat16 568 | 569 | text_encoder.to(accelerator.device, dtype=weight_dtype) 570 | vae.to(accelerator.device, dtype=weight_dtype) 571 | 572 | if 
args.use_ema: 573 | ema_unet = EMAModel(unet.parameters()) 574 | ema_unet.to(accelerator.device, dtype=weight_dtype) 575 | 576 | num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps) 577 | if overrode_max_train_steps: 578 | args_dict = transform_named_tuple_to_dict(args) 579 | args_dict["max_train_steps"] = args.num_train_epochs * num_update_steps_per_epoch 580 | args = transform_dict_to_named_tuple(args_dict) 581 | 582 | 583 | args_dict = transform_named_tuple_to_dict(args) 584 | args_dict["num_train_epochs"] = math.ceil(args.max_train_steps / num_update_steps_per_epoch) 585 | args = transform_dict_to_named_tuple(args_dict) 586 | 587 | if accelerator.is_main_process: 588 | if config.push_to_hub: 589 | #repo = init_git_repo(config, at_init=True) 590 | pass 591 | accelerator.init_trackers("train_example") 592 | 593 | total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps 594 | 595 | logger.info("***** Running training *****") 596 | logger.info(f" Num examples = {len(train_dataset)}") 597 | logger.info(f" Num Epochs = {args.num_train_epochs}") 598 | logger.info(f" Instantaneous batch size per device = {args.train_batch_size}") 599 | logger.info(f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}") 600 | logger.info(f" Gradient Accumulation steps = {args.gradient_accumulation_steps}") 601 | logger.info(f" Total optimization steps = {args.max_train_steps}") 602 | 603 | # Prepare everything 604 | # There is no specific order to remember, you just need to unpack the 605 | # objects in the same order you gave them to the prepare method. 606 | ''' 607 | model, optimizer, train_dataloader, lr_scheduler = accelerator.prepare( 608 | model, optimizer, train_dataloader, lr_scheduler 609 | ) 610 | ''' 611 | unet, optimizer, train_dataloader, lr_scheduler = accelerator.prepare( 612 | unet, optimizer, train_dataloader, lr_scheduler 613 | ) 614 | progress_bar = tqdm(range(args.max_train_steps), disable=not accelerator.is_local_main_process) 615 | progress_bar.set_description("Steps") 616 | global_step = 0 617 | 618 | for epoch in range(args.num_train_epochs): 619 | #unet.train() 620 | train_loss = 0.0 621 | for step, batch in enumerate(train_dataloader): 622 | with accelerator.accumulate(unet): 623 | # Convert images to latent space 624 | latents = vae.encode(batch["pixel_values"].to(weight_dtype)).latent_dist.sample() 625 | latents = latents * 0.18215 626 | 627 | # Sample noise that we'll add to the latents 628 | noise = torch.randn_like(latents) 629 | bsz = latents.shape[0] 630 | # Sample a random timestep for each image 631 | timesteps = torch.randint(0, noise_scheduler.num_train_timesteps, (bsz,), device=latents.device) 632 | timesteps = timesteps.long() 633 | 634 | # Add noise to the latents according to the noise magnitude at each timestep 635 | # (this is the forward diffusion process) 636 | noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps) 637 | 638 | # Get the text embedding for conditioning 639 | encoder_hidden_states = text_encoder(batch["input_ids"])[0] 640 | 641 | # Predict the noise residual and compute loss 642 | noise_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample 643 | 644 | loss = F.mse_loss(noise_pred.float(), noise.float(), reduction="mean") 645 | #loss = F.mse_loss(noise.float(), noise.float(), reduction="mean") 646 | 647 | # Gather the losses across all processes for logging (if we use distributed training). 
648 | avg_loss = accelerator.gather(loss.repeat(args.train_batch_size)).mean() 649 | train_loss += avg_loss.item() / args.gradient_accumulation_steps 650 | 651 | # Backpropagate 652 | accelerator.backward(loss) 653 | if accelerator.sync_gradients: 654 | accelerator.clip_grad_norm_(unet.parameters(), args.max_grad_norm) 655 | optimizer.step() 656 | lr_scheduler.step() 657 | optimizer.zero_grad() 658 | 659 | # Checks if the accelerator has performed an optimization step behind the scenes 660 | if accelerator.sync_gradients: 661 | if args.use_ema: 662 | ema_unet.step(unet.parameters()) 663 | progress_bar.update(1) 664 | global_step += 1 665 | accelerator.log({"train_loss": train_loss}, step=global_step) 666 | train_loss = 0.0 667 | 668 | logs = {"step_loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0]} 669 | progress_bar.set_postfix(**logs) 670 | 671 | if global_step >= args.max_train_steps: 672 | break 673 | 674 | from accelerate import notebook_launcher 675 | 676 | #### train it. 677 | args_ = (args, unet, noise_scheduler, optimizer, train_dataloader, lr_scheduler) 678 | notebook_launcher(train_loop, args_, num_processes=1) 679 | 680 | from diffusers import LMSDiscreteScheduler 681 | scheduler = LMSDiscreteScheduler(beta_start=0.00085, beta_end=0.012, 682 | beta_schedule="scaled_linear", num_train_timesteps=1000) 683 | 684 | #### save to local 685 | pipeline = StableDiffusionPipeline( 686 | text_encoder=text_encoder, 687 | vae=vae, 688 | unet=unet, 689 | tokenizer=tokenizer, 690 | scheduler=scheduler, 691 | safety_checker=StableDiffusionSafetyChecker.from_pretrained("CompVis/stable-diffusion-safety-checker"), 692 | feature_extractor=CLIPFeatureExtractor.from_pretrained("openai/clip-vit-base-patch32"), 693 | ) 694 | pipeline.save_pretrained("en_model_26000") 695 | pipeline.safety_checker = lambda images, clip_input: (images, False) 696 | pipeline = pipeline.to("cuda") 697 | 698 | 699 | imgs = pipeline("A cartoon character with a potted plant on his head", 700 | num_inference_steps = 100 701 | ) 702 | imgs.images[0] 703 | 704 | imgs = pipeline("cartoon bird", 705 | num_inference_steps = 100 706 | ) 707 | imgs.images[0] 708 | 709 | imgs = pipeline("yellow ball", 710 | num_inference_steps = 100 711 | ) 712 | imgs.images[0] 713 | 714 | imgs = pipeline("blue dragon illustration", 715 | num_inference_steps = 100 716 | ) 717 | imgs.images[0] 718 | --------------------------------------------------------------------------------
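For reference, a minimal sketch of the batch-size and epoch arithmetic shared by the three train_*_model.py scripts above, using their hard-coded values (train_batch_size=4, gradient_accumulation_steps=4, max_train_steps=50000, num_processes=1); `num_train_examples` is an assumed placeholder, not a value taken from this repository.

import math

# Hard-coded values from the training scripts above.
train_batch_size = 4
gradient_accumulation_steps = 4
max_train_steps = 50000
num_processes = 1                     # notebook_launcher(..., num_processes=1)

num_train_examples = 833              # assumption: replace with len(train_dataset)

# "Total train batch size (w. parallel, distributed & accumulation)" in the training log.
total_batch_size = train_batch_size * num_processes * gradient_accumulation_steps   # -> 16

# How the scripts turn max_train_steps into num_train_epochs.
len_train_dataloader = math.ceil(num_train_examples / train_batch_size)
num_update_steps_per_epoch = math.ceil(len_train_dataloader / gradient_accumulation_steps)
num_train_epochs = math.ceil(max_train_steps / num_update_steps_per_epoch)

print(total_batch_size, num_update_steps_per_epoch, num_train_epochs)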