├── .github └── FUNDING.yml ├── LICENSE ├── README.md ├── prepare_dataset_colab.ipynb ├── push.py └── train_text_to_image_flax.py /.github/FUNDING.yml: -------------------------------------------------------------------------------- 1 | # These are supported funding model platforms 2 | 3 | github: # Replace with up to 4 GitHub Sponsors-enabled usernames e.g., [user1, user2] 4 | patreon: camenduru 5 | open_collective: # Replace with a single Open Collective username 6 | ko_fi: camenduru 7 | tidelift: # Replace with a single Tidelift platform-name/package-name e.g., npm/babel 8 | community_bridge: # Replace with a single Community Bridge project-name e.g., cloud-foundry 9 | liberapay: # Replace with a single Liberapay username 10 | issuehunt: # Replace with a single IssueHunt username 11 | otechie: # Replace with a single Otechie username 12 | lfx_crowdfunding: # Replace with a single LFX Crowdfunding project-name e.g., cloud-foundry 13 | custom: # Replace with up to 4 custom sponsorship URLs e.g., ['link1', 'link2'] 14 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | This is free and unencumbered software released into the public domain. 2 | 3 | Anyone is free to copy, modify, publish, use, compile, sell, or 4 | distribute this software, either in source code form or as a compiled 5 | binary, for any purpose, commercial or non-commercial, and by any 6 | means. 7 | 8 | In jurisdictions that recognize copyright laws, the author or authors 9 | of this software dedicate any and all copyright interest in the 10 | software to the public domain. We make this dedication for the benefit 11 | of the public at large and to the detriment of our heirs and 12 | successors. We intend this dedication to be an overt act of 13 | relinquishment in perpetuity of all present and future rights to this 14 | software under copyright law. 15 | 16 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, 17 | EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF 18 | MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 19 | IN NO EVENT SHALL THE AUTHORS BE LIABLE FOR ANY CLAIM, DAMAGES OR 20 | OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, 21 | ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR 22 | OTHER DEALINGS IN THE SOFTWARE. 23 | 24 | For more information, please refer to 25 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | 🐣 Please follow me for new updates https://twitter.com/camenduru
2 | 🔥 Please join our discord server https://discord.gg/k5BwmmvJJU
3 | 🥳 Please join my patreon community https://patreon.com/camenduru
4 | 5 | This repo contains all the code and commands used in the `train text to image with tpu tutorial` video: https://youtu.be/NGta-t4BoLY 6 | 7 | ## Prepare TPU VM 8 | Use 🐧 Linux, or Linux inside Windows (WSL). 9 | 10 | https://console.cloud.google.com 11 | 12 | ```sh 13 | echo "deb [signed-by=/usr/share/keyrings/cloud.google.gpg] https://packages.cloud.google.com/apt cloud-sdk main" | sudo tee -a /etc/apt/sources.list.d/google-cloud-sdk.list 14 | curl https://packages.cloud.google.com/apt/doc/apt-key.gpg | sudo apt-key --keyring /usr/share/keyrings/cloud.google.gpg add - 15 | sudo apt-get update && sudo apt-get install google-cloud-cli 16 | gcloud init 17 | ``` 18 | 19 | ```sh 20 | gcloud alpha compute tpus tpu-vm ssh node-1 --zone us-central1-f 21 | ``` 22 | 23 | ```sh 24 | pip install -U zipp "jax[tpu]==0.3.23" -f https://storage.googleapis.com/jax-releases/libtpu_releases.html packaging flax==0.6.1 chex==0.1.5 orbax==0.0.13 numpy diffusers==0.10.0 transformers piexif fold_to_ascii discord ftfy dill urllib3 datasets importlib-metadata accelerate==0.16.0 OmegaConf wandb==0.13.4 optax torch torchvision modelcards pytorch_lightning protobuf==3.20.* tensorboard markupsafe==2.0.1 gradio 25 | 26 | sudo apt install git-lfs 27 | ``` 28 | 29 | ```sh 30 | mkdir tpu 31 | cd tpu 32 | mkdir train 33 | cd train 34 | ``` 35 | 36 | ```sh 37 | wget https://raw.githubusercontent.com/camenduru/train-text-to-image-tpu-tutorial/main/train_text_to_image_flax.py 38 | ``` 39 | 40 | ```sh 41 | tmux 42 | ``` 43 | 44 | ## Prepare Dataset 45 | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/camenduru/train-text-to-image-tpu-tutorial/blob/main/prepare_dataset_colab.ipynb) 46 | 47 | 48 | ## Train 49 | 50 | ```sh 51 | python3 train_text_to_image_flax.py \ 52 | --pretrained_model_name_or_path="flax/sd15-non-ema" \ 53 | --dataset_name="camenduru/test" \ 54 | --resolution=512 \ 55 | --mixed_precision="bf16" \ 56 | --train_batch_size=1 \ 57 | --num_train_epochs=650 \ 58 | --learning_rate=1e-6 \ 59 | --max_grad_norm=1 \ 60 | --output_dir="test" \ 61 | --report_to="wandb" 62 | ``` 63 | 64 | ## Push Trained model to 🤗 65 | 66 | ```sh 67 | wget https://raw.githubusercontent.com/camenduru/train-text-to-image-tpu-tutorial/main/push.py 68 | python3 push.py 69 | ``` 70 | 71 | ## Convert Flax model to PyTorch 72 | https://huggingface.co/spaces/camenduru/converter (a minimal conversion sketch is appended at the end of this document) 73 | 74 | ## Test Flax model or PyTorch model 75 | https://github.com/camenduru/stable-diffusion-diffusers-colab (a quick inference sketch is appended at the end of this document) 76 | 77 | ## Outputs 78 | https://huggingface.co/camenduru/tpu-train-tutorial-flax<br/>
79 | https://huggingface.co/camenduru/tpu-train-tutorial-pt 80 | 81 | ## YouTube Live VOD 82 | https://youtu.be/NGta-t4BoLY 83 | 84 | ## Main Repo ♥ 85 | https://github.com/huggingface/diffusers 86 | 87 | ## Scripts From ♥ 88 | https://github.com/huggingface/diffusers/blob/main/examples/text_to_image/train_text_to_image_flax.py 89 | -------------------------------------------------------------------------------- /prepare_dataset_colab.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": { 6 | "id": "view-in-github" 7 | }, 8 | "source": [ 9 | "[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/camenduru/train-text-to-image-tpu-tutorial/blob/main/prepare_dataset_colab.ipynb)" 10 | ] 11 | }, 12 | { 13 | "cell_type": "markdown", 14 | "metadata": { 15 | "id": "3QARKc9y2c8e" 16 | }, 17 | "source": [ 18 | "`create test.txt and copy 🍝 this html inside`\n", 19 | "```html\n", 20 | "\"cat\"\n", 21 | "\"cat\"\n", 22 | "\"cat\"\n", 23 | "\"cat\"\n", 24 | "\"cat\"\n", 25 | "\"cat\"\n", 26 | "\"cat\"\n", 27 | "\"cat\"\n", 28 | "```\n", 29 | "\n" 30 | ] 31 | }, 32 | { 33 | "cell_type": "code", 34 | "execution_count": null, 35 | "metadata": { 36 | "id": "FvYV6MbSdf40" 37 | }, 38 | "outputs": [], 39 | "source": [ 40 | "!pip install datasets bs4" 41 | ] 42 | }, 43 | { 44 | "cell_type": "code", 45 | "execution_count": null, 46 | "metadata": { 47 | "id": "i3s_ZLPUdk9a" 48 | }, 49 | "outputs": [], 50 | "source": [ 51 | "from huggingface_hub import notebook_login\n", 52 | "!git config --global credential.helper store\n", 53 | "notebook_login()" 54 | ] 55 | }, 56 | { 57 | "cell_type": "code", 58 | "execution_count": null, 59 | "metadata": { 60 | "id": "e8ZYugchdl_w" 61 | }, 62 | "outputs": [], 63 | "source": [ 64 | "!mkdir test" 65 | ] 66 | }, 67 | { 68 | "cell_type": "code", 69 | "execution_count": null, 70 | "metadata": { 71 | "id": "AndErT11dneF" 72 | }, 73 | "outputs": [], 74 | "source": [ 75 | "import urllib.request, requests\n", 76 | "from bs4 import BeautifulSoup\n", 77 | "\n", 78 | "with open('/content/test.txt') as html:\n", 79 | " content = html.read()\n", 80 | "\n", 81 | "soup = BeautifulSoup(content)\n", 82 | "for imgtag in soup.find_all('img'):\n", 83 | " url=imgtag['src']\n", 84 | " name = url.split('/')[-1]\n", 85 | " headers={'user-agent': 'Mozilla/5.0'}\n", 86 | " r=requests.get(url, headers=headers)\n", 87 | " with open(f\"/content/test/{name}\", 'wb') as f:\n", 88 | " f.write(r.content)" 89 | ] 90 | }, 91 | { 92 | "cell_type": "code", 93 | "execution_count": null, 94 | "metadata": { 95 | "id": "TaDtKFqCdo4b" 96 | }, 97 | "outputs": [], 98 | "source": [ 99 | "from datasets import load_dataset, Dataset, Image\n", 100 | "from bs4 import BeautifulSoup\n", 101 | "\n", 102 | "with open('/content/test.txt') as html:\n", 103 | " content = html.read()\n", 104 | "\n", 105 | "texts = []\n", 106 | "images = []\n", 107 | "soup = BeautifulSoup(content)\n", 108 | "for imgtag in soup.find_all('img'):\n", 109 | " texts.append(imgtag['alt'])\n", 110 | " images.append(f\"/content/test/{imgtag['src'].split('/')[-1]}\")\n", 111 | " \n", 112 | "ds = Dataset.from_dict({\"image\": images, \"text\": texts})\n", 113 | "ds = ds.cast_column(\"image\", Image())" 114 | ] 115 | }, 116 | { 117 | "cell_type": "code", 118 | "execution_count": null, 119 | "metadata": { 120 | "id": "1AgjVY1pdqH5" 121 | }, 122 | "outputs": [], 123 | "source": [ 124 | 
"ds.push_to_hub(\"camenduru/test\")" 125 | ] 126 | }, 127 | { 128 | "cell_type": "code", 129 | "execution_count": null, 130 | "metadata": {}, 131 | "outputs": [], 132 | "source": [ 133 | "# for unzipped images\n", 134 | "import os\n", 135 | "directory = '/content/images'\n", 136 | "image_list = []\n", 137 | "for subdir in os.listdir(directory):\n", 138 | " subdir_path = os.path.join(directory, subdir)\n", 139 | " if os.path.isdir(subdir_path):\n", 140 | " for filename in os.listdir(subdir_path):\n", 141 | " if filename.endswith(('.jpg', '.jpeg', '.png', '.gif')):\n", 142 | " image_list.append(f'\"{subdir}\"')\n", 143 | "output_file = 'image_list.txt'\n", 144 | "with open(output_file, 'w') as f:\n", 145 | " for image_tag in image_list:\n", 146 | " f.write(image_tag + '\\n')" 147 | ] 148 | } 149 | ], 150 | "metadata": { 151 | "colab": { 152 | "provenance": [] 153 | }, 154 | "kernelspec": { 155 | "display_name": "Python 3", 156 | "name": "python3" 157 | }, 158 | "language_info": { 159 | "name": "python" 160 | } 161 | }, 162 | "nbformat": 4, 163 | "nbformat_minor": 0 164 | } 165 | -------------------------------------------------------------------------------- /push.py: -------------------------------------------------------------------------------- 1 | import gradio as gr 2 | import os 3 | from huggingface_hub import create_repo, upload_folder 4 | 5 | block = gr.Blocks() 6 | 7 | def build(hf_token): 8 | repo_id = "camenduru/tpu-train-tutorial-flax" 9 | path_in_repo = "" 10 | create_repo(repo_id, private=True, token=hf_token) 11 | upload_folder(folder_path="/home/camenduru/tpu/train/test", path_in_repo=path_in_repo, repo_id=repo_id, commit_message=f"train", token=hf_token) 12 | return "done" 13 | 14 | def init(): 15 | with block: 16 | hf_token = gr.Textbox(show_label=False, max_lines=1, placeholder="🤗 token") 17 | out = gr.Textbox(show_label=False) 18 | btn = gr.Button("Push to 🤗") 19 | btn.click(build, inputs=hf_token, outputs=out) 20 | block.launch(share=True) 21 | 22 | if __name__ == "__main__": 23 | init() 24 | -------------------------------------------------------------------------------- /train_text_to_image_flax.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import logging 3 | import math 4 | import os 5 | import random 6 | from pathlib import Path 7 | from typing import Optional 8 | 9 | import numpy as np 10 | import torch 11 | import torch.utils.checkpoint 12 | 13 | import jax 14 | import jax.numpy as jnp 15 | import optax 16 | import transformers 17 | from accelerate import Accelerator 18 | from datasets import load_dataset 19 | from diffusers import ( 20 | FlaxAutoencoderKL, 21 | FlaxDDPMScheduler, 22 | FlaxPNDMScheduler, 23 | FlaxStableDiffusionPipeline, 24 | FlaxUNet2DConditionModel, 25 | ) 26 | from diffusers.pipelines.stable_diffusion import FlaxStableDiffusionSafetyChecker 27 | from flax import jax_utils 28 | from flax.training import train_state 29 | from flax.training.common_utils import shard 30 | from huggingface_hub import HfFolder, Repository, whoami, HfApi 31 | from huggingface_hub.utils import RepositoryNotFoundError 32 | from torchvision import transforms 33 | from tqdm.auto import tqdm 34 | from transformers import CLIPFeatureExtractor, CLIPTokenizer, FlaxCLIPTextModel, set_seed 35 | 36 | 37 | logger = logging.getLogger(__name__) 38 | 39 | 40 | def parse_args(): 41 | parser = argparse.ArgumentParser(description="Simple example of a training script.") 42 | parser.add_argument( 43 | 
"--pretrained_model_name_or_path", 44 | type=str, 45 | default=None, 46 | required=True, 47 | help="Path to pretrained model or model identifier from huggingface.co/models.", 48 | ) 49 | parser.add_argument( 50 | "--dataset_name", 51 | type=str, 52 | default=None, 53 | help=( 54 | "The name of the Dataset (from the HuggingFace hub) to train on (could be your own, possibly private," 55 | " dataset). It can also be a path pointing to a local copy of a dataset in your filesystem," 56 | " or to a folder containing files that 🤗 Datasets can understand." 57 | ), 58 | ) 59 | parser.add_argument( 60 | "--dataset_config_name", 61 | type=str, 62 | default=None, 63 | help="The config of the Dataset, leave as None if there's only one config.", 64 | ) 65 | parser.add_argument( 66 | "--train_data_dir", 67 | type=str, 68 | default=None, 69 | help=( 70 | "A folder containing the training data. Folder contents must follow the structure described in" 71 | " https://huggingface.co/docs/datasets/image_dataset#imagefolder. In particular, a `metadata.jsonl` file" 72 | " must exist to provide the captions for the images. Ignored if `dataset_name` is specified." 73 | ), 74 | ) 75 | parser.add_argument( 76 | "--image_column", type=str, default="image", help="The column of the dataset containing an image." 77 | ) 78 | parser.add_argument( 79 | "--caption_column", 80 | type=str, 81 | default="text", 82 | help="The column of the dataset containing a caption or a list of captions.", 83 | ) 84 | parser.add_argument( 85 | "--max_train_samples", 86 | type=int, 87 | default=None, 88 | help=( 89 | "For debugging purposes or quicker training, truncate the number of training examples to this " 90 | "value if set." 91 | ), 92 | ) 93 | parser.add_argument( 94 | "--output_dir", 95 | type=str, 96 | default="sd-model-finetuned", 97 | help="The output directory where the model predictions and checkpoints will be written.", 98 | ) 99 | parser.add_argument( 100 | "--cache_dir", 101 | type=str, 102 | default=None, 103 | help="The directory where the downloaded models and datasets will be stored.", 104 | ) 105 | parser.add_argument("--seed", type=int, default=0, help="A seed for reproducible training.") 106 | parser.add_argument( 107 | "--resolution", 108 | type=int, 109 | default=512, 110 | help=( 111 | "The resolution for input images, all the images in the train/validation dataset will be resized to this" 112 | " resolution" 113 | ), 114 | ) 115 | parser.add_argument( 116 | "--center_crop", 117 | action="store_true", 118 | help="Whether to center crop images before resizing to resolution (if not set, random crop will be used)", 119 | ) 120 | parser.add_argument( 121 | "--random_flip", 122 | action="store_true", 123 | help="whether to randomly flip images horizontally", 124 | ) 125 | parser.add_argument( 126 | "--train_batch_size", type=int, default=16, help="Batch size (per device) for the training dataloader." 127 | ) 128 | parser.add_argument("--num_train_epochs", type=int, default=100) 129 | parser.add_argument( 130 | "--max_train_steps", 131 | type=int, 132 | default=None, 133 | help="Total number of training steps to perform. 
If provided, overrides num_train_epochs.", 134 | ) 135 | parser.add_argument( 136 | "--learning_rate", 137 | type=float, 138 | default=1e-4, 139 | help="Initial learning rate (after the potential warmup period) to use.", 140 | ) 141 | parser.add_argument( 142 | "--scale_lr", 143 | action="store_true", 144 | default=False, 145 | help="Scale the learning rate by the number of GPUs, gradient accumulation steps, and batch size.", 146 | ) 147 | parser.add_argument( 148 | "--lr_scheduler", 149 | type=str, 150 | default="constant", 151 | help=( 152 | 'The scheduler type to use. Choose between ["linear", "cosine", "cosine_with_restarts", "polynomial",' 153 | ' "constant", "constant_with_warmup"]' 154 | ), 155 | ) 156 | parser.add_argument("--adam_beta1", type=float, default=0.9, help="The beta1 parameter for the Adam optimizer.") 157 | parser.add_argument("--adam_beta2", type=float, default=0.999, help="The beta2 parameter for the Adam optimizer.") 158 | parser.add_argument("--adam_weight_decay", type=float, default=1e-2, help="Weight decay to use.") 159 | parser.add_argument("--adam_epsilon", type=float, default=1e-08, help="Epsilon value for the Adam optimizer") 160 | parser.add_argument("--max_grad_norm", default=1.0, type=float, help="Max gradient norm.") 161 | parser.add_argument("--push_to_hub", action="store_true", help="Whether or not to push the model to the Hub.") 162 | parser.add_argument("--hub_token", type=str, default=None, help="The token to use to push to the Model Hub.") 163 | parser.add_argument( 164 | "--hub_model_id", 165 | type=str, 166 | default=None, 167 | help="The name of the repository to keep in sync with the local `output_dir`.", 168 | ) 169 | parser.add_argument( 170 | "--logging_dir", 171 | type=str, 172 | default="logs", 173 | help=( 174 | "[TensorBoard](https://www.tensorflow.org/tensorboard) log directory. Will default to" 175 | " *output_dir/runs/**CURRENT_DATETIME_HOSTNAME***." 176 | ), 177 | ) 178 | parser.add_argument( 179 | "--report_to", 180 | type=str, 181 | default="tensorboard", 182 | help=( 183 | 'The integration to report the results and logs to. Supported platforms are `"tensorboard"`,' 184 | ' `"wandb"` and `"comet_ml"`. Use `"all"` (default) to report to all integrations.' 185 | "Only applicable when `--with_tracking` is passed." 186 | ), 187 | ) 188 | parser.add_argument( 189 | "--tracker_project_name", 190 | type=str, 191 | default="text2image-fine-tune", 192 | help="Tracker project name.", 193 | ) 194 | parser.add_argument( 195 | "--mixed_precision", 196 | type=str, 197 | default="no", 198 | choices=["no", "fp16", "bf16"], 199 | help=( 200 | "Whether to use mixed precision. Choose" 201 | "between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >= 1.10." 202 | "and an Nvidia Ampere GPU." 
203 | ), 204 | ) 205 | parser.add_argument("--local_rank", type=int, default=-1, help="For distributed training: local_rank") 206 | 207 | args = parser.parse_args() 208 | env_local_rank = int(os.environ.get("LOCAL_RANK", -1)) 209 | if env_local_rank != -1 and env_local_rank != args.local_rank: 210 | args.local_rank = env_local_rank 211 | 212 | # Sanity checks 213 | if args.dataset_name is None and args.train_data_dir is None: 214 | raise ValueError("Need either a dataset name or a training folder.") 215 | 216 | return args 217 | 218 | 219 | def get_full_repo_name(model_id: str, organization: Optional[str] = None, token: Optional[str] = None): 220 | if token is None: 221 | token = HfFolder.get_token() 222 | if organization is None: 223 | username = whoami(token)["name"] 224 | return f"{username}/{model_id}" 225 | else: 226 | return f"{organization}/{model_id}" 227 | 228 | 229 | dataset_name_mapping = { 230 | "lambdalabs/pokemon-blip-captions": ("image", "text"), 231 | } 232 | 233 | 234 | def get_params_to_save(params): 235 | return jax.device_get(jax.tree_util.tree_map(lambda x: x[0], params)) 236 | 237 | 238 | def main(): 239 | args = parse_args() 240 | 241 | logging.basicConfig( 242 | format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", 243 | datefmt="%m/%d/%Y %H:%M:%S", 244 | level=logging.INFO, 245 | ) 246 | # Setup logging, we only want one process per machine to log things on the screen. 247 | logger.setLevel(logging.INFO if jax.process_index() == 0 else logging.ERROR) 248 | if jax.process_index() == 0: 249 | transformers.utils.logging.set_verbosity_info() 250 | else: 251 | transformers.utils.logging.set_verbosity_error() 252 | 253 | if args.seed is not None: 254 | set_seed(args.seed) 255 | 256 | # Handle the repository creation 257 | if jax.process_index() == 0: 258 | if args.push_to_hub: 259 | api = HfApi() 260 | if args.hub_model_id is None: 261 | repo_name = get_full_repo_name(Path(args.output_dir).name, token=args.hub_token) 262 | else: 263 | repo_name = args.hub_model_id 264 | try: 265 | repo_exists = True 266 | model_info = api.model_info(repo_name, token=args.hub_token) 267 | logger.info(f" Model ID = {model_info.modelId}") 268 | except RepositoryNotFoundError: 269 | repo_exists = False 270 | finally: 271 | if repo_exists: 272 | repo = Repository(args.output_dir, clone_from=repo_name, use_auth_token=args.hub_token) 273 | else: 274 | api.create_repo(repo_name, private=True, token=args.hub_token) 275 | repo = Repository(args.output_dir, clone_from=repo_name, use_auth_token=args.hub_token) 276 | with open(os.path.join(args.output_dir, ".gitignore"), "w+") as gitignore: 277 | if "step_*" not in gitignore: 278 | gitignore.write("step_*\n") 279 | if "epoch_*" not in gitignore: 280 | gitignore.write("epoch_*\n") 281 | elif args.output_dir is not None: 282 | os.makedirs(args.output_dir, exist_ok=True) 283 | 284 | # Get the datasets: you can either provide your own training and evaluation files (see below) 285 | # or specify a Dataset from the hub (the dataset will be downloaded automatically from the datasets Hub). 286 | 287 | # In distributed training, the load_dataset function guarantees that only one local process can concurrently 288 | # download the dataset. 289 | if args.dataset_name is not None: 290 | # Downloading and loading a dataset from the hub. 
291 | dataset = load_dataset( 292 | args.dataset_name, 293 | args.dataset_config_name, 294 | cache_dir=args.cache_dir, 295 | ) 296 | else: 297 | data_files = {} 298 | if args.train_data_dir is not None: 299 | data_files["train"] = os.path.join(args.train_data_dir, "**") 300 | dataset = load_dataset( 301 | "imagefolder", 302 | data_files=data_files, 303 | cache_dir=args.cache_dir, 304 | ) 305 | # See more about loading custom images at 306 | # https://huggingface.co/docs/datasets/v2.4.0/en/image_load#imagefolder 307 | 308 | # Preprocessing the datasets. 309 | # We need to tokenize inputs and targets. 310 | column_names = dataset["train"].column_names 311 | 312 | # 6. Get the column names for input/target. 313 | dataset_columns = dataset_name_mapping.get(args.dataset_name, None) 314 | if args.image_column is None: 315 | image_column = dataset_columns[0] if dataset_columns is not None else column_names[0] 316 | else: 317 | image_column = args.image_column 318 | if image_column not in column_names: 319 | raise ValueError( 320 | f"--image_column' value '{args.image_column}' needs to be one of: {', '.join(column_names)}" 321 | ) 322 | if args.caption_column is None: 323 | caption_column = dataset_columns[1] if dataset_columns is not None else column_names[1] 324 | else: 325 | caption_column = args.caption_column 326 | if caption_column not in column_names: 327 | raise ValueError( 328 | f"--caption_column' value '{args.caption_column}' needs to be one of: {', '.join(column_names)}" 329 | ) 330 | 331 | # Preprocessing the datasets. 332 | # We need to tokenize input captions and transform the images. 333 | def tokenize_captions(examples, is_train=True): 334 | captions = [] 335 | for caption in examples[caption_column]: 336 | if isinstance(caption, str): 337 | captions.append(caption) 338 | elif isinstance(caption, (list, np.ndarray)): 339 | # take a random caption if there are multiple 340 | captions.append(random.choice(caption) if is_train else caption[0]) 341 | else: 342 | raise ValueError( 343 | f"Caption column `{caption_column}` should contain either strings or lists of strings." 
344 | ) 345 | inputs = tokenizer(captions, max_length=tokenizer.model_max_length, padding="do_not_pad", truncation=True) 346 | input_ids = inputs.input_ids 347 | return input_ids 348 | 349 | train_transforms = transforms.Compose( 350 | [ 351 | transforms.Resize((args.resolution, args.resolution), interpolation=transforms.InterpolationMode.BILINEAR), 352 | transforms.CenterCrop(args.resolution) if args.center_crop else transforms.RandomCrop(args.resolution), 353 | transforms.RandomHorizontalFlip() if args.random_flip else transforms.Lambda(lambda x: x), 354 | transforms.ToTensor(), 355 | transforms.Normalize([0.5], [0.5]), 356 | ] 357 | ) 358 | 359 | def preprocess_train(examples): 360 | images = [image.convert("RGB") for image in examples[image_column]] 361 | examples["pixel_values"] = [train_transforms(image) for image in images] 362 | examples["input_ids"] = tokenize_captions(examples) 363 | 364 | return examples 365 | 366 | if jax.process_index() == 0: 367 | if args.max_train_samples is not None: 368 | dataset["train"] = dataset["train"].shuffle(seed=args.seed).select(range(args.max_train_samples)) 369 | # Set the training transforms 370 | train_dataset = dataset["train"].with_transform(preprocess_train) 371 | 372 | def collate_fn(examples): 373 | pixel_values = torch.stack([example["pixel_values"] for example in examples]) 374 | pixel_values = pixel_values.to(memory_format=torch.contiguous_format).float() 375 | input_ids = [example["input_ids"] for example in examples] 376 | 377 | padded_tokens = tokenizer.pad( 378 | {"input_ids": input_ids}, padding="max_length", max_length=tokenizer.model_max_length, return_tensors="pt" 379 | ) 380 | batch = { 381 | "pixel_values": pixel_values, 382 | "input_ids": padded_tokens.input_ids, 383 | } 384 | batch = {k: v.numpy() for k, v in batch.items()} 385 | 386 | return batch 387 | 388 | total_train_batch_size = args.train_batch_size * jax.local_device_count() 389 | train_dataloader = torch.utils.data.DataLoader( 390 | train_dataset, shuffle=True, collate_fn=collate_fn, batch_size=total_train_batch_size, drop_last=True 391 | ) 392 | 393 | weight_dtype = jnp.float32 394 | if args.mixed_precision == "fp16": 395 | weight_dtype = jnp.float16 396 | elif args.mixed_precision == "bf16": 397 | weight_dtype = jnp.bfloat16 398 | 399 | # Load models and create wrapper for stable diffusion 400 | tokenizer = CLIPTokenizer.from_pretrained(args.pretrained_model_name_or_path, subfolder="tokenizer") 401 | text_encoder = FlaxCLIPTextModel.from_pretrained( 402 | args.pretrained_model_name_or_path, subfolder="text_encoder", dtype=weight_dtype 403 | ) 404 | vae, vae_params = FlaxAutoencoderKL.from_pretrained( 405 | args.pretrained_model_name_or_path, subfolder="vae", dtype=weight_dtype 406 | ) 407 | unet, unet_params = FlaxUNet2DConditionModel.from_pretrained( 408 | args.pretrained_model_name_or_path, subfolder="unet", dtype=weight_dtype 409 | ) 410 | 411 | # Optimization 412 | if args.scale_lr: 413 | args.learning_rate = args.learning_rate * total_train_batch_size 414 | 415 | constant_scheduler = optax.constant_schedule(args.learning_rate) 416 | 417 | adamw = optax.adamw( 418 | learning_rate=constant_scheduler, 419 | b1=args.adam_beta1, 420 | b2=args.adam_beta2, 421 | eps=args.adam_epsilon, 422 | weight_decay=args.adam_weight_decay, 423 | ) 424 | 425 | optimizer = optax.chain( 426 | optax.clip_by_global_norm(args.max_grad_norm), 427 | adamw, 428 | ) 429 | 430 | state = train_state.TrainState.create(apply_fn=unet.__call__, params=unet_params, tx=optimizer) 431 | 432 | 
noise_scheduler = FlaxDDPMScheduler( 433 | beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", num_train_timesteps=1000 434 | ) 435 | 436 | # Initialize our training 437 | rng = jax.random.PRNGKey(args.seed) 438 | train_rngs = jax.random.split(rng, jax.local_device_count()) 439 | 440 | def train_step(state, text_encoder_params, vae_params, batch, train_rng): 441 | dropout_rng, sample_rng, new_train_rng = jax.random.split(train_rng, 3) 442 | 443 | def compute_loss(params): 444 | # Convert images to latent space 445 | vae_outputs = vae.apply( 446 | {"params": vae_params}, batch["pixel_values"], deterministic=True, method=vae.encode 447 | ) 448 | latents = vae_outputs.latent_dist.sample(sample_rng) 449 | # (NHWC) -> (NCHW) 450 | latents = jnp.transpose(latents, (0, 3, 1, 2)) 451 | latents = latents * 0.18215 452 | 453 | # Sample noise that we'll add to the latents 454 | noise_rng, timestep_rng = jax.random.split(sample_rng) 455 | noise = jax.random.normal(noise_rng, latents.shape) 456 | # Sample a random timestep for each image 457 | bsz = latents.shape[0] 458 | timesteps = jax.random.randint( 459 | timestep_rng, 460 | (bsz,), 461 | 0, 462 | noise_scheduler.config.num_train_timesteps, 463 | ) 464 | 465 | # Add noise to the latents according to the noise magnitude at each timestep 466 | # (this is the forward diffusion process) 467 | noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps) 468 | 469 | # Get the text embedding for conditioning 470 | encoder_hidden_states = text_encoder( 471 | batch["input_ids"], 472 | params=text_encoder_params, 473 | train=False, 474 | )[0] 475 | 476 | # Predict the noise residual and compute loss 477 | unet_outputs = unet.apply({"params": params}, noisy_latents, timesteps, encoder_hidden_states, train=True) 478 | noise_pred = unet_outputs.sample 479 | loss = (noise - noise_pred) ** 2 480 | loss = loss.mean() 481 | 482 | return loss 483 | 484 | grad_fn = jax.value_and_grad(compute_loss) 485 | loss, grad = grad_fn(state.params) 486 | grad = jax.lax.pmean(grad, "batch") 487 | 488 | new_state = state.apply_gradients(grads=grad) 489 | 490 | metrics = {"loss": loss} 491 | metrics = jax.lax.pmean(metrics, axis_name="batch") 492 | 493 | return new_state, metrics, new_train_rng 494 | 495 | # Create parallel version of the train step 496 | p_train_step = jax.pmap(train_step, "batch", donate_argnums=(0,)) 497 | 498 | # Replicate the train state on each device 499 | state = jax_utils.replicate(state) 500 | text_encoder_params = jax_utils.replicate(text_encoder.params) 501 | vae_params = jax_utils.replicate(vae_params) 502 | 503 | # Train! 504 | num_update_steps_per_epoch = math.ceil(len(train_dataloader)) 505 | 506 | # Scheduler and math around the number of training steps. 507 | if args.max_train_steps is None: 508 | args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch 509 | 510 | args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch) 511 | 512 | logger.info("***** Running training *****") 513 | logger.info(f" Num examples = {len(train_dataset)}") 514 | logger.info(f" Num Epochs = {args.num_train_epochs}") 515 | logger.info(f" Instantaneous batch size per device = {args.train_batch_size}") 516 | logger.info(f" Total train batch size (w. 
parallel & distributed) = {total_train_batch_size}") 517 | logger.info(f" Total optimization steps = {args.max_train_steps}") 518 | logger.info(f" Mixed Precision = {args.mixed_precision}") 519 | logger.info(f" Learning Rate = {args.learning_rate}") 520 | logger.info(f" Max Grad Norm = {args.max_grad_norm}") 521 | logger.info(f" Dataset Name = {args.dataset_name}") 522 | 523 | accelerator = Accelerator(log_with=args.report_to) 524 | tracker_config = { 525 | "train_dataset": len(train_dataset), 526 | "num_train_epochs": args.num_train_epochs, 527 | "train_batch_size": args.train_batch_size, 528 | "total_train_batch_size": total_train_batch_size, 529 | "max_train_steps": args.max_train_steps, 530 | "mixed_precision": args.mixed_precision, 531 | "learning_rate": args.learning_rate, 532 | "max_grad_norm": args.max_grad_norm, 533 | "dataset_name": args.dataset_name, 534 | } 535 | accelerator.init_trackers(args.tracker_project_name, config=tracker_config) 536 | 537 | global_step = 0 538 | 539 | epochs = tqdm(range(args.num_train_epochs), desc="Epoch ... ", position=0) 540 | for epoch in epochs: 541 | # ======================== Training ================================ 542 | 543 | train_metrics = [] 544 | 545 | steps_per_epoch = len(train_dataset) // total_train_batch_size 546 | train_step_progress_bar = tqdm(total=steps_per_epoch, desc="Training...", position=1, leave=False) 547 | # train 548 | for batch in train_dataloader: 549 | batch = shard(batch) 550 | state, train_metric, train_rngs = p_train_step(state, text_encoder_params, vae_params, batch, train_rngs) 551 | train_metrics.append(train_metric) 552 | 553 | train_step_progress_bar.update(1) 554 | 555 | global_step += 1 556 | if global_step >= args.max_train_steps: 557 | break 558 | 559 | train_metric = jax_utils.unreplicate(train_metric) 560 | 561 | train_step_progress_bar.close() 562 | epochs.write(f"Epoch... ({epoch + 1}/{args.num_train_epochs} | Loss: {train_metric['loss']})") 563 | accelerator.log({"training_loss": train_metric["loss"]}, step=epoch) 564 | accelerator.end_training() 565 | 566 | # Create the pipeline using using the trained modules and save it. 567 | if jax.process_index() == 0: 568 | scheduler = FlaxPNDMScheduler( 569 | beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", skip_prk_steps=True 570 | ) 571 | safety_checker = FlaxStableDiffusionSafetyChecker.from_pretrained( 572 | "CompVis/stable-diffusion-safety-checker", from_pt=True 573 | ) 574 | pipeline = FlaxStableDiffusionPipeline( 575 | text_encoder=text_encoder, 576 | vae=vae, 577 | unet=unet, 578 | tokenizer=tokenizer, 579 | scheduler=scheduler, 580 | safety_checker=safety_checker, 581 | feature_extractor=CLIPFeatureExtractor.from_pretrained("openai/clip-vit-base-patch32"), 582 | ) 583 | 584 | pipeline.save_pretrained( 585 | args.output_dir, 586 | params={ 587 | "text_encoder": get_params_to_save(text_encoder_params), 588 | "vae": get_params_to_save(vae_params), 589 | "unet": get_params_to_save(state.params), 590 | "safety_checker": safety_checker.params, 591 | }, 592 | ) 593 | 594 | if args.push_to_hub: 595 | repo.git_add(auto_lfs_track=True) 596 | repo.git_commit(commit_message="End of training") 597 | repo.git_push() 598 | 599 | 600 | if __name__ == "__main__": 601 | main() 602 | --------------------------------------------------------------------------------
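
## Convert Flax model to PyTorch (example sketch)

The converter Space linked in the README takes a Flax checkpoint repo and re-saves it with PyTorch weights. The snippet below is a minimal sketch of that step, not part of the original repo: it assumes a diffusers version that supports `from_flax=True` in `from_pretrained`, with both JAX/Flax and PyTorch installed; `camenduru/tpu-train-tutorial-flax` is the repo pushed by `push.py`.

```py
# Sketch only: convert the Flax checkpoint produced by this tutorial to PyTorch weights.
# Assumes a diffusers build with Flax -> PyTorch loading (`from_flax=True`) plus jax and torch installed.
from diffusers import StableDiffusionPipeline

# Load the Flax (msgpack) weights and convert them to PyTorch tensors in memory.
pipe = StableDiffusionPipeline.from_pretrained(
    "camenduru/tpu-train-tutorial-flax",  # Flax repo uploaded by push.py
    from_flax=True,
)

# Save the converted pipeline locally; the folder can then be uploaded with
# huggingface_hub.upload_folder, the same way push.py uploads the Flax model.
pipe.save_pretrained("tpu-train-tutorial-pt")
```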
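
## Test PyTorch model (example sketch)

The tutorial tests the trained model in the linked stable-diffusion-diffusers-colab notebooks. As a quick local alternative, the sketch below is a standard diffusers text-to-image call, not part of the original repo; it assumes the converted PyTorch repo from the previous step (`camenduru/tpu-train-tutorial-pt` is listed under Outputs) and a CUDA GPU.

```py
# Sketch only: quick text-to-image test of the converted PyTorch checkpoint.
import torch
from diffusers import StableDiffusionPipeline

pipe = StableDiffusionPipeline.from_pretrained(
    "camenduru/tpu-train-tutorial-pt",  # converted PyTorch repo (see Outputs in the README)
    torch_dtype=torch.float16,
)
pipe = pipe.to("cuda")

# "cat" matches the captions used in the example dataset from prepare_dataset_colab.ipynb.
image = pipe("cat", num_inference_steps=30).images[0]
image.save("cat.png")
```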