├── LICENSE ├── Makefile ├── README.md ├── assets └── spright_good-1.png ├── eval ├── README.md └── gpt4 │ ├── README.md │ ├── collate_gpt4_results.py │ ├── create_eval_dataset.py │ ├── eval_with_gpt4.py │ ├── push_eval_dataset_to_hub.py │ └── requirements.txt ├── pyproject.toml └── training ├── README.md ├── requirements.txt ├── spright_t2i_multinode_example.sh └── train.py /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. 
For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. 
You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. 
You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 202 | -------------------------------------------------------------------------------- /Makefile: -------------------------------------------------------------------------------- 1 | 2 | check_dirs := . 
3 | 4 | quality: 5 | ruff check $(check_dirs) 6 | ruff format --check $(check_dirs) 7 | 8 | style: 9 | ruff check $(check_dirs) --fix 10 | ruff format $(check_dirs) 11 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # SPRIGHT 🖼️✨ 2 | 3 | Welcome to the official GitHub repository for our paper titled "Getting it Right: Improving Spatial Consistency in Text-to-Image Models". Our work introduces a simple approach to enhance spatial consistency in text-to-image diffusion models, alongside a high-quality dataset designed for this purpose. 4 | 5 | **_Getting it Right: Improving Spatial Consistency in Text-to-Image Models_** by Agneet Chatterjee$, Gabriela Ben Melech Stan$, Estelle Aflalo, Sayak Paul, Dhruba Ghosh, Tejas Gokhale, Ludwig Schmidt, Hannaneh Hajishirzi, Vasudev Lal, Chitta Baral, Yezhou Yang. 6 | 7 | $ denotes equal contribution. 8 | 9 |

10 | 🤗 Models & Datasets | 📃 Paper | 11 | ⚙️ Demo | 12 | 🎮 Project Website 13 |

14 | 15 | Update July 05, 2024: We got accepted to ECCV'24 🥳 16 | 17 | ## 📄 Abstract 18 | _One of the key shortcomings in current text-to-image (T2I) models is their inability to consistently generate images which faithfully follow the spatial relationships specified in the text prompt. In this paper, we offer a comprehensive investigation of this limitation, while also developing datasets and methods that achieve state-of-the-art performance. First, we find that current vision-language datasets do not represent spatial relationships well enough; to alleviate this bottleneck, we create SPRIGHT, the first spatially-focused, large scale dataset, by re-captioning 6 million images from 4 widely used vision datasets. Through a 3-fold evaluation and analysis pipeline, we find that SPRIGHT largely improves upon existing datasets in capturing spatial relationships. To demonstrate its efficacy, we leverage only \~0.25% of SPRIGHT and achieve a 22% improvement in generating spatially accurate images while also improving the FID and CMMD scores. Secondly, we find that training on images containing a large number of objects results in substantial improvements in spatial consistency. Notably, we attain state-of-the-art on T2I-CompBench with a spatial score of 0.2133, by fine-tuning on <500 images. Finally, through a set of controlled experiments and ablations, we document multiple findings that we believe will enhance the understanding of factors that affect spatial consistency in text-to-image models. We publicly release our dataset and 19 | model to foster further research in this area._ 20 | 21 | ## 📚 Contents 22 | - [Installation](#installation) 23 | - [Training](#training) 24 | - [Inference](#inference) 25 | - [The SPRIGHT Dataset](#the-spright-dataset) 26 | - [Eval](#evaluation) 27 | - [Citing](#citing) 28 | - [Acknowledgments](#ack) 29 | 30 | 31 | ## 💾 Installation 32 | 33 | Make sure you have CUDA and PyTorch set up. The PyTorch [official documentation](https://pytorch.org/) is the best place to refer to for that. The rest of the installation instructions are provided in the respective sections. 34 | 35 | If you have access to the Habana Gaudi accelerators, you can benefit from them, as our training script supports them. 36 | 37 | 38 | ## 🔍 Training 39 | 40 | Refer to [`training/`](./training). 41 | 42 | 43 | ## 🌺 Inference 44 | 45 | ```python 46 | from diffusers import DiffusionPipeline 47 | import torch 48 | 49 | spright_id = "SPRIGHT-T2I/spright-t2i-sd2" 50 | pipe = DiffusionPipeline.from_pretrained(spright_id, torch_dtype=torch.float16).to("cuda") 51 | 52 | image = pipe("A horse above a pizza").images[0] 53 | image 54 | ``` 55 | 56 | You can also run [the demo](https://huggingface.co/spaces/SPRIGHT-T2I/SPRIGHT-T2I) locally: 57 | 58 | ```bash 59 | git clone https://huggingface.co/spaces/SPRIGHT-T2I/SPRIGHT-T2I 60 | cd SPRIGHT-T2I 61 | python app.py 62 | ``` 63 | 64 | Make sure `gradio` and other dependencies are installed in your environment. 65 | 66 | 67 | ## 🖼️ The SPRIGHT Dataset 68 | 69 | Refer to our [paper](https://arxiv.org/abs/2404.01197) and [the dataset page](https://huggingface.co/datasets/SPRIGHT-T2I/spright) for more details. Below are some examples from the SPRIGHT dataset: 70 | 71 |

72 | 73 |
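If you want a quick look at how the dataset is organized before downloading any image shards, the sketch below fetches and prints the split metadata. It only assumes the `SPRIGHT-T2I/spright` dataset id linked above and the `metadata.json` file described in [`training/`](./training); the exact JSON schema is not assumed here, so the file is simply pretty-printed.

```python
import json

from huggingface_hub import hf_hub_download

# metadata.json records the split of each .tar shard and its sample count
# (see training/README.md); here we only download and inspect it.
metadata_path = hf_hub_download(
    repo_id="SPRIGHT-T2I/spright",
    filename="metadata.json",
    repo_type="dataset",
)
with open(metadata_path) as f:
    metadata = json.load(f)

print(json.dumps(metadata, indent=2)[:1000])  # peek at the beginning; the file covers every shard
```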

74 | 75 | 76 | ## 📊 Evaluation 77 | 78 | In the [`eval/`](./eval) directory, we provide details about the various evaluation methods we use in our work. 79 | 80 | 81 | ## 📜 Citing 82 | 83 | ```bibtex 84 | @misc{chatterjee2024getting, 85 | title={Getting it Right: Improving Spatial Consistency in Text-to-Image Models}, 86 | author={Agneet Chatterjee and Gabriela Ben Melech Stan and Estelle Aflalo and Sayak Paul and Dhruba Ghosh and Tejas Gokhale and Ludwig Schmidt and Hannaneh Hajishirzi and Vasudev Lal and Chitta Baral and Yezhou Yang}, 87 | year={2024}, 88 | eprint={2404.01197}, 89 | archivePrefix={arXiv}, 90 | primaryClass={cs.CV} 91 | } 92 | ``` 93 | 94 | 95 | ## 🙏 Acknowledgments 96 | 97 | We thank Lucain Pouget for helping us in uploading the dataset to the Hugging Face Hub and the Hugging Face team for providing computing resources to host our demo. The authors acknowledge resources and support from the Research Computing facilities at Arizona State University. This work was supported by NSF RI grants \#1750082 and \#2132724. The views and opinions of the authors expressed herein do not necessarily state or reflect those of the funding agencies and employers. 98 | -------------------------------------------------------------------------------- /assets/spright_good-1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SPRIGHT-T2I/SPRIGHT/1ce84339c7aceaf502a3d0959d612de8b8ac0171/assets/spright_good-1.png -------------------------------------------------------------------------------- /eval/README.md: -------------------------------------------------------------------------------- 1 | ## Quantitative evals 2 | 3 | Here we list the evaluations we carried out in our work and their corresponding code implementations. 4 | 5 | | **Eval** | **Implementation** | 6 | |:------------:|:------------:| 7 | | GPT-4 (V) | [`gpt4/`](./gpt4/) | 8 | | FAITHScore | [Official implementation](https://github.com/bcdnlp/FAITHSCORE) | 9 | | VISOR | [Official implementation](https://github.com/microsoft/VISOR) | 10 | | T2I-CompBench | [Official implementation](https://github.com/Karine-Huang/T2I-CompBench) | 11 | | GenEval | [Official implementation](https://github.com/djghosh13/geneval) | 12 | | CMMD | [Official implementation](https://github.com/google-research/google-research/tree/master/cmmd) | 13 | | FID | [Implementation](https://github.com/mseitzer/pytorch-fid) | 14 | | CKA | [Implementation](https://github.com/AntixK/PyTorch-Model-Compare) | 15 | | Attention Maps | [Official implementation](https://github.com/yuval-alaluf/Attend-and-Excite) | 16 | -------------------------------------------------------------------------------- /eval/gpt4/README.md: -------------------------------------------------------------------------------- 1 | # Automated quality evaluation with GPT-4 2 | 3 | This directory provides scripts for performing an automated quality evaluation with GPT-4. You will need an OpenAI API key (with at least $5 of credit) to run the `eval_with_gpt4.py` script. 4 | 5 | ## Dataset 6 | 7 | The dataset consists of 100 images which can be found [here](https://hf.co/datasets/SPRIGHT-T2I/100-images-for-eval). It was created from the [SAM](https://segment-anything.com/) and [CC12M](https://github.com/google-research-datasets/conceptual-12m) datasets.
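If you only want to inspect the evaluation dataset rather than recreate it, a minimal sketch along these lines works; it assumes the Hub id linked above (the scripts in this directory push to and read from an equivalent id) and the `spatial_caption`/`subset` columns defined in `push_eval_dataset_to_hub.py`:

```python
from datasets import load_dataset

# 100 images sampled from SAM and CC12M, each paired with a spatially-focused caption.
eval_ds = load_dataset("SPRIGHT-T2I/100-images-for-eval", split="train")
print(eval_ds)

sample = eval_ds[0]
print(sample["subset"], "->", sample["spatial_caption"])
sample["image"].show()  # decoded as a PIL image
```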
8 | 9 | To create the dataset run (make sure the dependencies are installed first): 10 | 11 | ```bash 12 | python create_eval_dataset.py 13 | python push_eval_dataset_to_hub.py 14 | ``` 15 | 16 | _(Run `huggingface-cli login` before running `python push_eval_dataset_to_hub.py`. You might have to change the `ds_id` in the script as well.)_ 17 | 18 | ## Evaluation 19 | 20 | ```bash 21 | python eval_with_gpt4.py 22 | ``` 23 | 24 | The script comes with limited support for handling rate-limiting issues. 25 | 26 | ## Collating GPT-4 results 27 | 28 | Once `python eval_with_gpt4.py` has been run it should produce JSON files prefixed with `gpt4_evals`. You can then run the following to collate the results and push the final dataset to the HF Hub for auditing: 29 | 30 | ```bash 31 | python collate_gpt4_results.py 32 | ``` 33 | -------------------------------------------------------------------------------- /eval/gpt4/collate_gpt4_results.py: -------------------------------------------------------------------------------- 1 | import glob 2 | import json 3 | 4 | from datasets import Dataset, Features, Value, load_dataset 5 | from datasets import Image as ImageFeature 6 | 7 | 8 | def sort_file_paths(file_paths): 9 | # Extract the starting id more accurately and use it as the key for sorting 10 | sorted_paths = sorted(file_paths, key=lambda x: int(x.split("_")[2])) 11 | return sorted_paths 12 | 13 | 14 | def get_ratings_from_json(json_path): 15 | all_ratings = [] 16 | with open(json_path, "r") as f: 17 | json_dict = json.load(f) 18 | for i in range(len(json_dict)): 19 | all_ratings.append(json_dict[i]) 20 | return all_ratings 21 | 22 | 23 | all_jsons = sorted(glob.glob("*.json")) 24 | sorted_all_jsons = sort_file_paths(all_jsons) 25 | 26 | all_ratings = [] 27 | for json_path in sorted_all_jsons: 28 | try: 29 | all_ratings.extend(get_ratings_from_json(json_path)) 30 | except: 31 | print(json_path) 32 | 33 | eval_dataset = load_dataset("ASU-HF/100-images-for-eval", split="train") 34 | 35 | 36 | def generation_fn(): 37 | for i in range(len(eval_dataset)): 38 | yield { 39 | "image": eval_dataset[i]["image"], 40 | "spatial_caption": eval_dataset[i]["spatial_caption"], 41 | "gpt4_rating": all_ratings[i]["rating"], 42 | "gpt4_explanation": all_ratings[i]["explanation"], 43 | } 44 | 45 | 46 | ds = Dataset.from_generator( 47 | generation_fn, 48 | features=Features( 49 | image=ImageFeature(), 50 | spatial_caption=Value("string"), 51 | gpt4_rating=Value("int32"), 52 | gpt4_explanation=Value("string"), 53 | ), 54 | ) 55 | ds_id = "ASU-HF/gpt4-evaluation" 56 | ds.push_to_hub(ds_id) 57 | -------------------------------------------------------------------------------- /eval/gpt4/create_eval_dataset.py: -------------------------------------------------------------------------------- 1 | import json 2 | import os 3 | import random 4 | import shutil 5 | 6 | from tqdm.auto import tqdm 7 | 8 | 9 | random.seed(2024) 10 | 11 | 12 | JSON_PATHS = ["cc12m/spatial_prompts_cc_res768.jsonl", "sa/spatial_prompts_sa_res768.jsonl"] 13 | CUT_OFF_FOR_EACH = 50 14 | SUBSET_DIR = "eval" 15 | ROOT_PATH = "/home/jupyter/test-images-spatial/human_eval_subset" 16 | 17 | 18 | def copy_images(tuple_entries, subset): 19 | final_dict = {} 20 | for entry in tqdm(tuple_entries): 21 | image_name = entry[0].split("/")[-1] 22 | image_to_copy_from = os.path.join(ROOT_PATH, subset, "images", image_name) 23 | image_to_copy_to = os.path.join(ROOT_PATH, SUBSET_DIR) 24 | shutil.copy(image_to_copy_from, image_to_copy_to) 25 | final_dict[image_name] = 
entry[1] 26 | return final_dict 27 | 28 | 29 | # Load the JSON files. 30 | cc12m_entries = [] 31 | with open(JSON_PATHS[0], "rb") as json_list: 32 | for json_str in json_list: 33 | cc12m_entries.append(json.loads(json_str)) 34 | 35 | sa_entries = [] 36 | with open(JSON_PATHS[1], "rb") as json_list: 37 | for json_str in json_list: 38 | sa_entries.append(json.loads(json_str)) 39 | 40 | # Prepare tuples and shuffle them for random sampling. 41 | print(len(cc12m_entries), len(sa_entries)) 42 | cc12m_tuples = [(line["file_name"], line["spatial_caption"]) for line in cc12m_entries] 43 | sa_tuples = [(line["file_name"], line["spatial_caption"]) for line in sa_entries] 44 | filtered_cc12m_tuples = [ 45 | (line[0], line[1]) 46 | for line in cc12m_tuples 47 | if os.path.exists(os.path.join(ROOT_PATH, "cc12m", "images", line[0].split("/")[-1])) 48 | ] 49 | 50 | # Keep paths that exist. 51 | filtered_sa_tuples = [ 52 | (line[0], line[1]) 53 | for line in sa_tuples 54 | if os.path.exists(os.path.join(ROOT_PATH, "sa", "images", line[0].split("/")[-1])) 55 | ] 56 | print(len(filtered_cc12m_tuples), len(filtered_sa_tuples)) 57 | random.shuffle(filtered_cc12m_tuples) 58 | random.shuffle(filtered_sa_tuples) 59 | 60 | # Cut off for subsets. 61 | subset_cc12m_tuples = filtered_cc12m_tuples[:CUT_OFF_FOR_EACH] 62 | subset_sa_tuples = filtered_sa_tuples[:CUT_OFF_FOR_EACH] 63 | 64 | # Copy over the images. 65 | if not os.path.exists(SUBSET_DIR): 66 | os.makedirs(SUBSET_DIR, exist_ok=True) 67 | 68 | final_data_dict = {} 69 | cc12m_dict = copy_images(subset_cc12m_tuples, "cc12m") 70 | sa_dict = copy_images(subset_sa_tuples, "sa") 71 | print(len(cc12m_dict), len(sa_dict)) 72 | final_data_dict = {**cc12m_dict, **sa_dict} 73 | 74 | # Create a json file to record metadata. 75 | with open("final_data_dict.json", "w") as f: 76 | json.dump(final_data_dict, f) 77 | -------------------------------------------------------------------------------- /eval/gpt4/eval_with_gpt4.py: -------------------------------------------------------------------------------- 1 | import base64 2 | import json 3 | import os 4 | import time 5 | from concurrent.futures import ThreadPoolExecutor 6 | from io import BytesIO 7 | 8 | import requests 9 | from datasets import load_dataset 10 | 11 | 12 | api_key = os.getenv("OPENAI_API_KEY") 13 | headers = {"Content-Type": "application/json", "Authorization": f"Bearer {api_key}"} 14 | 15 | 16 | def encode_image(image): 17 | buffered = BytesIO() 18 | image.save(buffered, format="JPEG") 19 | img_str = base64.b64encode(buffered.getvalue()) 20 | return img_str.decode("utf-8") 21 | 22 | 23 | def create_payload(image_string, caption): 24 | # System message adapted from DALL-E 3 (as found in the tech report): 25 | # https://cdn.openai.com/papers/dall-e-3.pdf 26 | messages = [ 27 | { 28 | "role": "system", 29 | "content": """You are part of a team of bots that evaluates images and their captions. Your job is to come up with a rating in between 1 to 10 to evaluate the provided caption for the provided image. While performing the assessment, consider the correctness of spatial relationships captured in the provided image. You should return the response formatted as a dictionary having two keys: 'rating', denoting the numeric rating and 'explanation', denoting a brief justification for the rating. 30 | 31 | The captions you are judging are designed to stress - test image captioning programs, and may include things such as: 32 | 1. 
Spatial phrases like above, below, left, right, front, behind, background, foreground (focus most on the correctness of these words) 33 | 2. Relative sizes between objects such as small & large, big & tiny (focus on the correctness of these words) 34 | 3. Scrambled or mis-spelled words (the image generator should produce an image associated with the probable meaning). 35 | 36 | You need to make a decision as to whether or not the caption is correct, given the image. 37 | 38 | A few rules: 39 | 1. It is ok if the caption does not explicitly mention each object in the image; as long as the caption is correct in its entirety, it is fine. 40 | 2. It is also ok if some captions don't have spatial relationships; judge them based on their correctness. A caption not containing spatial relationships should not be penalized. 41 | 3. You will think out loud about your eventual conclusion. Don't include your reasoning in the final output. 42 | 4. You should return the response formatted as a Python-formatted dictionary having 43 | two keys: 'rating', denoting the numeric rating and 'explanation', denoting 44 | a brief justification for the rating. 45 | """, 46 | } 47 | ] 48 | messages.append( 49 | { 50 | "role": "user", 51 | "content": [ 52 | { 53 | "type": "text", 54 | "text": "Come up with a rating in between 1 to 10 to evaluate the provided caption for the provided image." 55 | " While performing the assessment, consider the correctness of spatial relationships captured in the provided image." 56 | f" Caption provided: {caption}", 57 | }, 58 | { 59 | "type": "image_url", 60 | "image_url": {"url": f"data:image/jpeg;base64,{image_string}"}, 61 | }, 62 | ], 63 | } 64 | ) 65 | 66 | payload = { 67 | "model": "gpt-4-vision-preview", 68 | "messages": messages, 69 | "max_tokens": 250, 70 | "seed": 2024, 71 | } 72 | return payload 73 | 74 | 75 | def get_response(image_string, caption): 76 | payload = create_payload(image_string, caption) 77 | response = requests.post("https://api.openai.com/v1/chat/completions", headers=headers, json=payload) 78 | return response.json() 79 | 80 | 81 | def get_rating(response): 82 | content = response["choices"][0]["message"]["content"] 83 | # Best-effort clean-up: strip code fences and language tags so the content can be parsed as a Python dict.
84 | cleaned_content = ( 85 | content.strip().replace("```", "").replace("json", "").replace("python", "").strip().replace("\n", "") 86 | ) 87 | return cleaned_content 88 | 89 | 90 | dataset = load_dataset("ASU-HF/100-images-for-eval", split="train") 91 | image_strings = [] 92 | captions = [] 93 | for i in range(len(dataset)): 94 | image_strings.append(encode_image(dataset[i]["image"])) 95 | captions.append(dataset[i]["spatial_caption"]) 96 | 97 | 98 | chunk_size = 8 99 | json_retry = 4 100 | per_min_token_limit = 10000 101 | per_day_request_limit = 500 102 | total_requests_made = 0 103 | batch_total_tokens = 0 104 | 105 | with ThreadPoolExecutor(chunk_size) as e: 106 | for i in range(0, len(image_strings), chunk_size): 107 | responses = None 108 | cur_retry = 0 109 | 110 | # request handling with retries 111 | while responses is None and cur_retry <= json_retry: 112 | try: 113 | responses = list(e.map(get_response, image_strings[i : i + chunk_size], captions[i : i + chunk_size])) 114 | except Exception: # don't bind the exception to `e`, which would shadow the executor 115 | cur_retry = cur_retry + 1 116 | continue 117 | 118 | # handle rate-limits 119 | total_requests_made += len(image_strings[i : i + chunk_size]) 120 | for response in responses: 121 | batch_total_tokens += response["usage"]["total_tokens"] 122 | 123 | with open(f"gpt4_evals_{i}_to_{(i + chunk_size) - 1}.json", "w") as f: 124 | ratings = [eval(get_rating(response)) for response in responses] # assumes each response is a dict literal 125 | json.dump(ratings, f, indent=4) 126 | 127 | if total_requests_made > per_day_request_limit: 128 | total_requests_made = 0 129 | time.sleep(86400) # wait a day! 130 | elif batch_total_tokens > per_min_token_limit: 131 | batch_total_tokens = 0 132 | time.sleep(1800) # wait for half an hour to prevent per_min_request_limit 133 | -------------------------------------------------------------------------------- /eval/gpt4/push_eval_dataset_to_hub.py: -------------------------------------------------------------------------------- 1 | import json 2 | import os 3 | 4 | from datasets import Dataset, Features, Value 5 | from datasets import Image as ImageFeature 6 | 7 | 8 | final_dict_path = "final_data_dict.json" 9 | with open(final_dict_path, "r") as f: 10 | final_dict = json.load(f) 11 | 12 | 13 | root_path = "/home/jupyter/test-images-spatial/human_eval_subset/eval" 14 | 15 | 16 | def generation_fn(): 17 | for k in final_dict: 18 | yield { 19 | "image": os.path.join(root_path, k), 20 | "spatial_caption": final_dict[k], 21 | "subset": "SA" if "sa" in k else "CC12M", 22 | } 23 | 24 | 25 | ds = Dataset.from_generator( 26 | generation_fn, 27 | features=Features( 28 | image=ImageFeature(), 29 | spatial_caption=Value("string"), 30 | subset=Value("string"), 31 | ), 32 | ) 33 | ds_id = "ASU-HF/100-images-for-eval" 34 | ds.push_to_hub(ds_id) 35 | -------------------------------------------------------------------------------- /eval/gpt4/requirements.txt: -------------------------------------------------------------------------------- 1 | datasets 2 | openai 3 | Pillow 4 | tqdm -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | 2 | [tool.ruff] 3 | # Never enforce `E501` (line length violations). 4 | ignore = ["C901", "E501", "E741", "F402", "F823"] 5 | select = ["C", "E", "F", "I", "W"] 6 | line-length = 119 7 | 8 | # Ignore import violations in all `__init__.py` files.
9 | [tool.ruff.per-file-ignores] 10 | "__init__.py" = ["E402", "F401", "F403", "F811"] 11 | "src/diffusers/utils/dummy_*.py" = ["F401"] 12 | 13 | [tool.ruff.isort] 14 | lines-after-imports = 2 15 | known-first-party = ["diffusers"] 16 | 17 | [tool.ruff.format] 18 | # Like Black, use double quotes for strings. 19 | quote-style = "double" 20 | 21 | # Like Black, indent with spaces, rather than tabs. 22 | indent-style = "space" 23 | 24 | # Like Black, respect magic trailing commas. 25 | skip-magic-trailing-comma = false 26 | 27 | # Like Black, automatically detect the appropriate line ending. 28 | line-ending = "auto" 29 | -------------------------------------------------------------------------------- /training/README.md: -------------------------------------------------------------------------------- 1 | ## Training with the SPRIGHT dataset 2 | 3 | If you're on CUDA, then make sure it's properly set up and install PyTorch following instructions from its official documentation. 4 | 5 | If you have access to Habana Gaudi accelerators and wish to use them for training, then first get Habana set up by following the [official website](https://docs.habana.ai/en/latest/Installation_Guide/index.html#gaudi-installation-guide). Then install `optimum-habana`: 6 | 7 | ```bash 8 | pip install git+https://github.com/huggingface/optimum-habana.git 9 | ``` 10 | 11 | Other training-related Python dependencies are found in [`requirements.txt`](./requirements.txt). 12 | 13 | ### Data preparation 14 | 15 | In order to work on our dataset, 16 | 17 | - Download the dataset from [here](https://huggingface.co/datasets/SPRIGHT-T2I/spright) and place it under `/path/to/spright` 18 | - The structure of the downloaded repository is as follows: 19 | 20 | ```plaintext 21 | /path/to/spright/ 22 | ├── data/ 23 | │ └── *.tar 24 | ├── metadata.json 25 | ├── load_data.py 26 | └── robust_upload.py 27 | ``` 28 | - Each .tar file contains around 10k images with associated general and spatial captions. 29 | - `metadata.json` contains the nature of the split for each tar file, as well as the number of samples per .tar file. 30 | 31 | ### Example training command 32 | #### Multiple GPUs 33 | 34 | 1. In order to finetune our model using the train and validation splits as set by [SPRIGHT data](https://github.com/SPRIGHT-T2I/SPRIGHT#data-preparation) in `metadata.json`: 35 | ```bash 36 | export MODEL_NAME="SPRIGHT-T2I/spright-t2i-sd2" 37 | export OUTDIR="path/to/outdir" 38 | export SPRIGHT_SPLIT="path/to/spright/metadata.json" # download from: https://huggingface.co/datasets/SPRIGHT-T2I/spright/blob/main/metadata.json 39 | 40 | accelerate launch --mixed_precision="fp16" train.py \ 41 | --pretrained_model_name_or_path=$MODEL_NAME \ 42 | --use_ema \ 43 | --resolution=768 --center_crop --random_flip \ 44 | --train_batch_size=4 \ 45 | --gradient_accumulation_steps=1 \ 46 | --max_train_steps=15000 \ 47 | --learning_rate=5e-05 \ 48 | --max_grad_norm=1 \ 49 | --lr_scheduler="constant" \ 50 | --lr_warmup_steps=0 \ 51 | --output_dir=$OUTDIR \ 52 | --validation_epochs 1 \ 53 | --checkpointing_steps=1500 \ 54 | --freeze_text_encoder_steps 0 \ 55 | --train_text_encoder \ 56 | --text_encoder_lr=1e-06 \ 57 | --spright_splits $SPRIGHT_SPLIT 58 | ``` 59 | 2. It is possible to set the train/val splits manually, by specifying the particular *.tar files using `--spright_train_costum` for training and `--spright_val_costum` for validation.
`metadata.json` should also be passed to the training command, as it provides the count of samples in each .tar file: 60 | ```bash 61 | export MODEL_NAME="SPRIGHT-T2I/spright-t2i-sd2" 62 | export OUTDIR="path/to/outdir" 63 | export WEBDATA_TRAIN="path/to/spright/data/{00000..00004}.tar" 64 | export WEBDATA_VAL="path/to/spright/data/{00004..00005}.tar" 65 | export SPRIGHT_SPLIT="path/to/spright/metadata.json" # download from: https://huggingface.co/datasets/SPRIGHT-T2I/spright/blob/main/metadata.json 66 | 67 | accelerate launch --mixed_precision="fp16" train.py \ 68 | --pretrained_model_name_or_path=$MODEL_NAME \ 69 | --use_ema \ 70 | --resolution=768 --center_crop --random_flip \ 71 | --train_batch_size=4 \ 72 | --gradient_accumulation_steps=1 \ 73 | --max_train_steps=15000 \ 74 | --learning_rate=5e-05 \ 75 | --max_grad_norm=1 \ 76 | --lr_scheduler="constant" \ 77 | --lr_warmup_steps=0 \ 78 | --output_dir=$OUTDIR \ 79 | --validation_epochs 1 \ 80 | --checkpointing_steps=1500 \ 81 | --freeze_text_encoder_steps 0 \ 82 | --train_text_encoder \ 83 | --text_encoder_lr=1e-06 \ 84 | --spright_splits $SPRIGHT_SPLIT \ 85 | --spright_train_costum $WEBDATA_TRAIN \ 86 | --spright_val_costum $WEBDATA_VAL 87 | ``` 88 | To train the text encoder, set `--train_text_encoder`. The point at which text encoder training begins is determined by `--freeze_text_encoder_steps`, where 0 indicates that training for both the U-Net and text encoder starts simultaneously at the outset. It's possible to set different learning rates for the text encoder and the U-Net; these are configured through `--text_encoder_lr` for the text encoder and `--learning_rate` for the U-Net, respectively. 89 | 90 | ### Multiple Nodes 91 | In order to train on multiple nodes using SLURM, please refer to the [`spright_t2i_multinode_example.sh`](./spright_t2i_multinode_example.sh). 92 | 93 | ### Good to know 94 | 95 | Our training script supports experimentation tracking with Weights and Biases. If you wish to do so pass `--report="wandb"` in your training command. Make sure you install `wandb` before that. 96 | 97 | If you're on CUDA, you can push the training artifacts stored under `output_dir` to the Hugging Face Hub. Pass `--push_to_hub` if you wish to do so. You'd need to run `huggingface-cli login` before that. 98 | -------------------------------------------------------------------------------- /training/requirements.txt: -------------------------------------------------------------------------------- 1 | diffusers 2 | transformers 3 | accelerate 4 | datasets 5 | webdataset 6 | torchvision -------------------------------------------------------------------------------- /training/spright_t2i_multinode_example.sh: -------------------------------------------------------------------------------- 1 | #! 
/bin/bash 2 | 3 | #SBATCH -p PP 4 | #SBATCH --gres=gpu:8 5 | #SBATCH --ntasks-per-node=1 6 | #SBATCH --cpus-per-task=80 7 | #SBATCH -N 4 8 | #SBATCH --job-name=spatial_finetuning_stable_diffusion 9 | 10 | 11 | conda activate env_name 12 | cd /path/to/training/script 13 | 14 | export MODEL_NAME="SPRIGHT-T2I/spright-t2i-sd2" 15 | export OUTDIR="/path/to/output/dir" 16 | export SPRIGHT_SPLIT="path/to/spright/metadata.json" 17 | 18 | ACCELERATE_CONFIG_FILE="$OUTDIR/${SLURM_JOB_ID}_accelerate_config.yaml.autogenerated" 19 | 20 | 21 | GPUS_PER_NODE=8 22 | NNODES=$SLURM_NNODES 23 | NUM_GPUS=$((GPUS_PER_NODE*SLURM_NNODES)) 24 | 25 | MASTER_ADDR=$(scontrol show hostnames $SLURM_JOB_NODELIST | head -n 1) 26 | MASTER_PORT=25215 27 | 28 | # Auto-generate the accelerate config 29 | cat << EOT > $ACCELERATE_CONFIG_FILE 30 | compute_environment: LOCAL_MACHINE 31 | deepspeed_config: {} 32 | distributed_type: MULTI_GPU 33 | fsdp_config: {} 34 | machine_rank: 0 35 | main_process_ip: $MASTER_ADDR 36 | main_process_port: $MASTER_PORT 37 | main_training_function: main 38 | num_machines: $SLURM_NNODES 39 | num_processes: $NUM_GPUS 40 | use_cpu: false 41 | EOT 42 | 43 | # accelerate settings 44 | # Note: it is important to escape `$SLURM_PROCID` since we want the srun on each node to evaluate this variable 45 | export LAUNCHER="accelerate launch \ 46 | --rdzv_conf "rdzv_backend=c10d,rdzv_endpoint=$MASTER_ADDR:$MASTER_PORT,max_restarts=0,tee=3" \ 47 | --config_file $ACCELERATE_CONFIG_FILE \ 48 | --main_process_ip $MASTER_ADDR \ 49 | --main_process_port $MASTER_PORT \ 50 | --num_processes $NUM_GPUS \ 51 | --machine_rank \$SLURM_PROCID \ 52 | " 53 | 54 | # train 55 | PROGRAM="train.py \ 56 | --pretrained_model_name_or_path=$MODEL_NAME \ 57 | --use_ema \ 58 | --seed 42 \ 59 | --mixed_precision="fp16" \ 60 | --resolution=768 --center_crop --random_flip \ 61 | --train_batch_size=4 \ 62 | --gradient_accumulation_steps=1 \ 63 | --max_train_steps=15000 \ 64 | --learning_rate=5e-06 \ 65 | --max_grad_norm=1 \ 66 | --lr_scheduler="constant" \ 67 | --lr_warmup_steps=0 \ 68 | --output_dir=$OUTDIR \ 69 | --train_metadata_dir=$TRAIN_METADIR \ 70 | --dataloader=$DATA_LOADER \ 71 | --checkpointing_steps=1500 \ 72 | --freeze_text_encoder_steps=0 \ 73 | --train_text_encoder \ 74 | --text_encoder_lr=1e-06 \ 75 | --spright_splits $SPRIGHT_SPLIT 76 | " 77 | 78 | 79 | # srun error handling: 80 | # --wait=60: wait 60 sec after the first task terminates before terminating all remaining tasks 81 | # --kill-on-bad-exit=1: terminate a step if any task exits with a non-zero exit code 82 | SRUN_ARGS=" \ 83 | --wait=60 \ 84 | --kill-on-bad-exit=1 \ 85 | " 86 | 87 | export CMD="$LAUNCHER $PROGRAM" 88 | echo $CMD 89 | 90 | srun $SRUN_ARGS --jobid $SLURM_JOB_ID bash -c "$CMD" 91 | 92 | -------------------------------------------------------------------------------- /training/train.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # coding=utf-8 3 | # Copyright 2024 SPRIGHT authors and The HuggingFace Inc. team. All rights reserved. 4 | # 5 | # Licensed under the Apache License, Version 2.0 (the "License"); 6 | # you may not use this file except in compliance with the License. 
7 | # You may obtain a copy of the License at 8 | # 9 | # http://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | # See the License for the specific language governing permissions and 15 | 16 | import argparse 17 | 18 | # added for gaudi 19 | import json 20 | import logging 21 | import math 22 | import os 23 | import random 24 | import shutil 25 | import time 26 | from pathlib import Path 27 | 28 | import accelerate 29 | import datasets 30 | import numpy as np 31 | import torch 32 | import torch.nn.functional as F 33 | import torch.utils.checkpoint 34 | import transformers 35 | from accelerate import Accelerator 36 | 37 | 38 | try: 39 | from optimum.habana import GaudiConfig 40 | from optimum.habana.accelerate import GaudiAccelerator 41 | except: 42 | GaudiConfig = None 43 | GaudiAccelerator = None 44 | 45 | from accelerate.logging import get_logger 46 | from accelerate.state import AcceleratorState 47 | from accelerate.utils import ProjectConfiguration 48 | 49 | 50 | try: 51 | from optimum.habana.utils import set_seed 52 | except: 53 | from accelerate.utils import set_seed 54 | import datetime 55 | 56 | from datasets import DownloadMode, load_dataset 57 | from huggingface_hub import create_repo, upload_folder 58 | from packaging import version 59 | from torchvision import transforms 60 | from tqdm.auto import tqdm 61 | from transformers import CLIPTextModel, CLIPTokenizer 62 | from transformers.utils import ContextManagers 63 | 64 | import diffusers 65 | from diffusers import AutoencoderKL, UNet2DConditionModel 66 | 67 | 68 | try: 69 | from optimum.habana.diffusers import GaudiDDIMScheduler, GaudiStableDiffusionPipeline 70 | except: 71 | from diffusers import DDPMScheduler, StableDiffusionPipeline 72 | from diffusers.optimization import get_scheduler 73 | from diffusers.training_utils import EMAModel, compute_snr 74 | from diffusers.utils import deprecate, is_wandb_available, make_image_grid 75 | 76 | 77 | try: 78 | # memory stats 79 | import habana_frameworks.torch as htorch 80 | import habana_frameworks.torch.core as htcore 81 | import habana_frameworks.torch.hpu as hthpu 82 | except: 83 | from diffusers.utils.import_utils import is_xformers_available 84 | htorch = None 85 | hthpu = None 86 | htcore = None 87 | import sys 88 | 89 | 90 | sys.path.append(os.path.dirname(os.getcwd())) 91 | import itertools 92 | import warnings 93 | 94 | import webdataset as wds 95 | from transformers import PretrainedConfig 96 | 97 | from diffusers.utils.torch_utils import is_compiled_module 98 | 99 | 100 | if is_wandb_available(): 101 | import wandb 102 | 103 | debug = False 104 | 105 | # Will error if the minimal version of diffusers is not installed. Remove at your own risks. 
106 | #check_min_version("0.23.0.dev0") 107 | 108 | logger = get_logger(__name__, log_level="INFO") 109 | 110 | 111 | def save_model_card( 112 | args, 113 | repo_id=None, 114 | images=None, 115 | train_text_encoder=False, 116 | repo_folder=None, 117 | ): 118 | img_str = "" 119 | if len(images) > 0: 120 | image_grid = make_image_grid(images, 1, len(args.validation_prompts)) 121 | image_grid.save(os.path.join(repo_folder, "val_imgs_grid.png")) 122 | img_str += "![val_imgs_grid](./val_imgs_grid.png)\n" 123 | 124 | yaml = f""" 125 | --- 126 | license: creativeml-openrail-m 127 | base_model: {args.pretrained_model_name_or_path} 128 | datasets: 129 | - {args.dataset_name} 130 | tags: 131 | - stable-diffusion 132 | - stable-diffusion-diffusers 133 | - text-to-image 134 | - diffusers 135 | inference: true 136 | --- 137 | """ 138 | model_card = f""" 139 | # Text-to-image finetuning - {repo_id} 140 | Fine-tuning for the text encoder was enabled: {train_text_encoder}. 141 | 142 | This pipeline was finetuned from **{args.pretrained_model_name_or_path}** on the **{args.dataset_name}** dataset. Below are some example images generated with the finetuned pipeline using the following prompts: {args.validation_prompts}: \n 143 | {img_str} 144 | 145 | ## Pipeline usage 146 | 147 | You can use the pipeline like so: 148 | 149 | ```python 150 | from diffusers import DiffusionPipeline 151 | import torch 152 | 153 | pipeline = DiffusionPipeline.from_pretrained("{repo_id}", torch_dtype=torch.float16) 154 | prompt = "{args.validation_prompts[0]}" 155 | image = pipeline(prompt).images[0] 156 | image.save("my_image.png") 157 | ``` 158 | 159 | ## Training info 160 | 161 | These are the key hyperparameters used during training: 162 | 163 | * Epochs: {args.num_train_epochs} 164 | * Learning rate: {args.learning_rate} 165 | * Batch size: {args.train_batch_size} 166 | * Gradient accumulation steps: {args.gradient_accumulation_steps} 167 | * Image resolution: {args.resolution} 168 | * Mixed-precision: {args.mixed_precision} 169 | 170 | """ 171 | wandb_info = "" 172 | if is_wandb_available(): 173 | wandb_run_url = None 174 | if wandb.run is not None: 175 | wandb_run_url = wandb.run.url 176 | 177 | if wandb_run_url is not None: 178 | wandb_info = f""" 179 | More information on all the CLI arguments and the environment are available on your [`wandb` run page]({wandb_run_url}). 180 | """ 181 | 182 | model_card += wandb_info 183 | 184 | with open(os.path.join(repo_folder, "README.md"), "w") as f: 185 | f.write(yaml + model_card) 186 | 187 | def compute_validation_loss(val_dataloader, vae, text_encoder, noise_scheduler, unet, args, weight_dtype): 188 | val_loss = 0 189 | num_steps= math.ceil(len(val_dataloader)) 190 | progress_bar = tqdm( 191 | range(0, num_steps), 192 | initial=0, 193 | desc="Steps", 194 | # Only show the progress bar once on each machine. 
195 | disable=True, 196 | ) 197 | for step, batch in enumerate(val_dataloader): 198 | progress_bar.update(1) 199 | # Convert images to latent space 200 | latents = vae.encode(batch["pixel_values"].to(weight_dtype)).latent_dist.sample() 201 | latents = latents * vae.config.scaling_factor 202 | 203 | # Sample noise that we'll add to the latents 204 | noise = torch.randn_like(latents) 205 | if args.noise_offset: 206 | # https://www.crosslabs.org//blog/diffusion-with-offset-noise 207 | noise += args.noise_offset * torch.randn( 208 | (latents.shape[0], latents.shape[1], 1, 1), device=latents.device 209 | ) 210 | if args.input_perturbation: 211 | new_noise = noise + args.input_perturbation * torch.randn_like(noise) 212 | bsz = latents.shape[0] 213 | # Sample a random timestep for each image 214 | timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, (bsz,), device=latents.device) 215 | timesteps = timesteps.long() 216 | 217 | # Add noise to the latents according to the noise magnitude at each timestep 218 | # (this is the forward diffusion process) 219 | if args.input_perturbation: 220 | noisy_latents = noise_scheduler.add_noise(latents, new_noise, timesteps) 221 | else: 222 | noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps) 223 | 224 | # Get the text embedding for conditioning 225 | encoder_hidden_states = text_encoder(batch["input_ids"])[0] 226 | 227 | # Get the target for loss depending on the prediction type 228 | if args.prediction_type is not None: 229 | # set prediction_type of scheduler if defined 230 | noise_scheduler.register_to_config(prediction_type=args.prediction_type) 231 | 232 | if noise_scheduler.config.prediction_type == "epsilon": 233 | target = noise 234 | elif noise_scheduler.config.prediction_type == "v_prediction": 235 | target = noise_scheduler.get_velocity(latents, noise, timesteps) 236 | else: 237 | raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}") 238 | 239 | # Predict the noise residual and compute loss 240 | model_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample 241 | 242 | if args.snr_gamma is None: 243 | loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean") 244 | else: 245 | # Compute loss-weights as per Section 3.4 of https://arxiv.org/abs/2303.09556. 246 | # Since we predict the noise instead of x_0, the original formulation is slightly changed. 247 | # This is discussed in Section 4.2 of the same paper. 248 | snr = compute_snr(noise_scheduler, timesteps) 249 | if noise_scheduler.config.prediction_type == "v_prediction": 250 | # Velocity objective requires that we add one to SNR values before we divide by them. 251 | snr = snr + 1 252 | mse_loss_weights = ( 253 | torch.stack([snr, args.snr_gamma * torch.ones_like(timesteps)], dim=1).min(dim=1)[0] / snr 254 | ) 255 | 256 | loss = F.mse_loss(model_pred.float(), target.float(), reduction="none") 257 | loss = loss.mean(dim=list(range(1, len(loss.shape)))) * mse_loss_weights 258 | loss = loss.mean() 259 | 260 | logs = {"step_val_loss": loss.detach().item()} 261 | progress_bar.set_postfix(**logs) 262 | val_loss += loss.item() 263 | val_loss /= (step+1) 264 | return val_loss 265 | 266 | def log_validation(vae, text_encoder, tokenizer, unet, args, accelerator, weight_dtype, epoch, val_dataloader, noise_scheduler): 267 | logger.info("Running validation... 
") 268 | 269 | 270 | if args.validation_prompts is not None: 271 | 272 | if args.device == "hpu": 273 | pipeline = GaudiStableDiffusionPipeline.from_pretrained( 274 | args.pretrained_model_name_or_path, 275 | text_encoder=accelerator.unwrap_model(text_encoder), 276 | tokenizer=tokenizer, 277 | vae=accelerator.unwrap_model(vae), 278 | unet=accelerator.unwrap_model(unet), 279 | safety_checker=None, 280 | revision=args.revision, 281 | use_habana=True, 282 | use_hpu_graphs=True, 283 | gaudi_config=args.gaudi_config_name, 284 | ) 285 | else: 286 | pipeline = StableDiffusionPipeline.from_pretrained( 287 | args.pretrained_model_name_or_path, 288 | vae=accelerator.unwrap_model(vae), 289 | text_encoder=accelerator.unwrap_model(text_encoder), 290 | tokenizer=tokenizer, 291 | unet=accelerator.unwrap_model(unet), 292 | safety_checker=None, 293 | revision=args.revision, 294 | torch_dtype=weight_dtype, 295 | ) 296 | pipeline = pipeline.to(accelerator.device) 297 | pipeline.set_progress_bar_config(disable=True) 298 | 299 | if args.enable_xformers_memory_efficient_attention: 300 | pipeline.enable_xformers_memory_efficient_attention() 301 | 302 | if args.seed is None: 303 | generator = None 304 | else: 305 | generator = torch.Generator(device=accelerator.device).manual_seed(args.seed) 306 | 307 | images = [] 308 | for i in range(len(args.validation_prompts)): 309 | if args.device == "hpu": 310 | image = pipeline(args.validation_prompts[i], num_inference_steps=50, generator=generator).images[0] 311 | else: 312 | with torch.autocast("cuda"): 313 | image = pipeline(args.validation_prompts[i], num_inference_steps=50, generator=generator).images[0] 314 | 315 | images.append(image) 316 | 317 | for tracker in accelerator.trackers: 318 | if tracker.name == "tensorboard": 319 | if args.validation_prompts is not None: 320 | np_images = np.stack([np.asarray(img) for img in images]) 321 | tracker.writer.add_images("validation/images", np_images, epoch, dataformats="NHWC") 322 | elif tracker.name == "wandb": 323 | if args.validation_prompts is not None: 324 | tracker.log( 325 | { 326 | "validation/images": [ 327 | wandb.Image(image, caption=f"{i}: {args.validation_prompts[i]}") 328 | for i, image in enumerate(images) 329 | ] 330 | } 331 | ) 332 | 333 | else: 334 | if args.device == "hpu": 335 | logger.warning(f"image logging not implemented for {tracker.name}") 336 | else: 337 | logger.warn(f"image logging not implemented for {tracker.name}") 338 | 339 | del pipeline 340 | if args.device != "hpu": 341 | torch.cuda.empty_cache() 342 | 343 | return images 344 | 345 | def import_model_class_from_model_name_or_path(pretrained_model_name_or_path: str, revision: str): 346 | text_encoder_config = PretrainedConfig.from_pretrained( 347 | pretrained_model_name_or_path, 348 | subfolder="text_encoder", 349 | revision=revision, 350 | ) 351 | model_class = text_encoder_config.architectures[0] 352 | 353 | if model_class == "CLIPTextModel": 354 | from transformers import CLIPTextModel 355 | 356 | return CLIPTextModel 357 | elif model_class == "RobertaSeriesModelWithTransformation": 358 | from diffusers.pipelines.alt_diffusion.modeling_roberta_series import RobertaSeriesModelWithTransformation 359 | 360 | return RobertaSeriesModelWithTransformation 361 | elif model_class == "T5EncoderModel": 362 | from transformers import T5EncoderModel 363 | 364 | return T5EncoderModel 365 | else: 366 | raise ValueError(f"{model_class} is not supported.") 367 | 368 | def parse_args(): 369 | parser = argparse.ArgumentParser(description="Simple example 
of a training script.") 370 | parser.add_argument( 371 | "--input_perturbation", type=float, default=0, help="The scale of input perturbation. Recommended 0.1." 372 | ) 373 | parser.add_argument( 374 | "--pretrained_model_name_or_path", 375 | type=str, 376 | default=None, 377 | required=True, 378 | help="Path to pretrained model or model identifier from huggingface.co/models.", 379 | ) 380 | parser.add_argument( 381 | "--revision", 382 | type=str, 383 | default=None, 384 | required=False, 385 | help="Revision of pretrained model identifier from huggingface.co/models.", 386 | ) 387 | parser.add_argument( 388 | "--dataset_name", 389 | type=str, 390 | default=None, 391 | help=( 392 | "The name of the Dataset (from the HuggingFace hub) to train on (could be your own, possibly private," 393 | " dataset). It can also be a path pointing to a local copy of a dataset in your filesystem," 394 | " or to a folder containing files that 🤗 Datasets can understand." 395 | ), 396 | ) 397 | parser.add_argument( 398 | "--dataset_config_name", 399 | type=str, 400 | default=None, 401 | help="The config of the Dataset, leave as None if there's only one config.", 402 | ) 403 | parser.add_argument( 404 | "--train_data_dir", 405 | type=str, 406 | default=None, 407 | help=( 408 | "A folder containing the training data. Folder contents must follow the structure described in" 409 | " https://huggingface.co/docs/datasets/image_dataset#imagefolder. In particular, a `metadata.jsonl` file" 410 | " must exist to provide the captions for the images. Ignored if `dataset_name` is specified." 411 | ), 412 | ) 413 | parser.add_argument( 414 | "--spright_splits", 415 | type=str, 416 | default="split.json", 417 | help=( 418 | "A url containing the json file that defines the splits (https://huggingface.co/datasets/ASU-HF/spright/blob/main/split.json). The webdataset should contain the metadata as tar files." 419 | ), 420 | ) 421 | parser.add_argument( 422 | "--spright_train_costum", 423 | type=str, 424 | default=None, 425 | help=( 426 | "A url containing the webdataset train split. The webdataset should contain the metadata as tar files." 427 | ), 428 | ) 429 | parser.add_argument( 430 | "--spright_val_costum", 431 | type=str, 432 | default=None, 433 | help=( 434 | "A url containing the webdataset validation split. The webdataset should contain the metadata as tar files." 435 | ), 436 | ) 437 | parser.add_argument( 438 | "--webdataset_buffer_size", 439 | type=int, 440 | default=1000, 441 | help=( 442 | "buffer size of webdataset." 443 | ), 444 | ) 445 | parser.add_argument( 446 | "--dataset_size", 447 | type=float, 448 | default=None, 449 | help="dataset size to use. If set, the dataset will be truncated to this size.", 450 | ) 451 | parser.add_argument( 452 | "--val_split", 453 | type=float, 454 | default=0.1, 455 | help="ratio of validation size out of the entire dataset " 456 | ) 457 | parser.add_argument( 458 | "--train_metadata_dir", 459 | type=str, 460 | default=None, 461 | help=( 462 | "A folder containing subfolders: train, val, test with the metadata as jsonl files." 463 | " jsonl files provide the general and spatial captions for the images." 464 | ), 465 | ) 466 | parser.add_argument( 467 | "--dataloader", 468 | type=str, 469 | default=None, 470 | help=( 471 | "A python script with custom dataloader." 472 | ), 473 | ) 474 | parser.add_argument( 475 | "--image_column", type=str, default="image", help="The column of the dataset containing an image." 
476 | ) 477 | parser.add_argument( 478 | "--caption_column", 479 | type=str, 480 | default="text", 481 | help="The column of the dataset containing a caption or a list of captions.", 482 | ) 483 | parser.add_argument( 484 | "--max_train_samples", 485 | type=int, 486 | default=None, 487 | help=( 488 | "For debugging purposes or quicker training, truncate the number of training examples to this " 489 | "value if set." 490 | ), 491 | ) 492 | parser.add_argument( 493 | "--validation_prompts", 494 | type=str, 495 | default=["The city is located behind the water, and the pier is relatively small in comparison to the expanse of the water and the city", "The bed is positioned in the center of the frame, with two red pillows on the left side", "The houses are located on the left side of the street, while the park is on the right side", "The spoon is located on the left side of the shelf, while the bowl is positioned in the center", "The room has a red carpet, and there is a chandelier hanging from the ceiling above the bed"], 496 | nargs="+", 497 | help=("A set of prompts evaluated every `--validation_epochs` and logged to `--report_to`."), 498 | ) 499 | parser.add_argument( 500 | "--output_dir", 501 | type=str, 502 | default="", 503 | required=True, 504 | help="The output directory where the model predictions and checkpoints will be written.", 505 | ) 506 | parser.add_argument( 507 | "--cache_dir", 508 | type=str, 509 | default=None, 510 | help="The directory where the downloaded models and datasets will be stored.", 511 | ) 512 | parser.add_argument("--seed", type=int, default=None, help="A seed for reproducible training.") 513 | parser.add_argument( 514 | "--resolution", 515 | type=int, 516 | default=512, 517 | help=( 518 | "The resolution for input images, all the images in the train/validation dataset will be resized to this" 519 | " resolution" 520 | ), 521 | ) 522 | parser.add_argument( 523 | "--pre_crop_resolution", 524 | type=int, 525 | default=768, 526 | help=( 527 | "The resolution for input images, all the images in the train/validation dataset will be resized to this" 528 | " resolution before being randomly cropped to the final `resolution`." 529 | ), 530 | ) 531 | parser.add_argument( 532 | "--center_crop", 533 | default=False, 534 | action="store_true", 535 | help=( 536 | "Whether to center crop the input images to the resolution. If not set, the images will be randomly" 537 | " cropped. The images will be resized to the resolution first before cropping." 538 | ), 539 | ) 540 | parser.add_argument( 541 | "--random_flip", 542 | action="store_true", 543 | help="whether to randomly flip images horizontally", 544 | ) 545 | parser.add_argument( 546 | "--train_batch_size", type=int, default=16, help="Batch size (per device) for the training dataloader." 547 | ) 548 | parser.add_argument("--num_train_epochs", type=int, default=100) 549 | parser.add_argument( 550 | "--max_train_steps", 551 | type=int, 552 | default=None, 553 | help="Total number of training steps to perform. 
If provided, overrides num_train_epochs.", 554 | ) 555 | parser.add_argument( 556 | "--gradient_accumulation_steps", 557 | type=int, 558 | default=1, 559 | help="Number of updates steps to accumulate before performing a backward/update pass.", 560 | ) 561 | parser.add_argument( 562 | "--gradient_checkpointing", 563 | action="store_true", 564 | help="Whether or not to use gradient checkpointing to save memory at the expense of slower backward pass.", 565 | ) 566 | parser.add_argument( 567 | "--learning_rate", 568 | type=float, 569 | default=1e-4, 570 | help="Initial learning rate (after the potential warmup period) to use.", 571 | ) 572 | parser.add_argument( 573 | "--text_encoder_lr", 574 | type=float, 575 | default=None, 576 | help="Initial learning rate for the text encoder - should usually be samller than unet_lr(after the potential warmup period) to use. When set to None, it will be set to the same value as learning_rate.", 577 | ) 578 | parser.add_argument( 579 | "--scale_lr", 580 | action="store_true", 581 | default=False, 582 | help="Scale the learning rate by the number of GPUs, gradient accumulation steps, and batch size.", 583 | ) 584 | parser.add_argument( 585 | "--lr_scheduler", 586 | type=str, 587 | default="constant", 588 | help=( 589 | 'The scheduler type to use. Choose between ["linear", "cosine", "cosine_with_restarts", "polynomial",' 590 | ' "constant", "constant_with_warmup"]' 591 | ), 592 | ) 593 | parser.add_argument( 594 | "--lr_warmup_steps", type=int, default=500, help="Number of steps for the warmup in the lr scheduler." 595 | ) 596 | parser.add_argument( 597 | "--snr_gamma", 598 | type=float, 599 | default=None, 600 | help="SNR weighting gamma to be used if rebalancing the loss. Recommended value is 5.0. " 601 | "More details here: https://arxiv.org/abs/2303.09556.", 602 | ) 603 | parser.add_argument( 604 | "--use_8bit_adam", action="store_true", help="Whether or not to use 8-bit Adam from bitsandbytes." 605 | ) 606 | parser.add_argument( 607 | "--allow_tf32", 608 | action="store_true", 609 | help=( 610 | "Whether or not to allow TF32 on Ampere GPUs. Can be used to speed up training. For more information, see" 611 | " https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices" 612 | ), 613 | ) 614 | parser.add_argument("--use_ema", action="store_true", help="Whether to use EMA model.") 615 | parser.add_argument( 616 | "--non_ema_revision", 617 | type=str, 618 | default=None, 619 | required=False, 620 | help=( 621 | "Revision of pretrained non-ema model identifier. Must be a branch, tag or git identifier of the local or" 622 | " remote repository specified with --pretrained_model_name_or_path." 623 | ), 624 | ) 625 | parser.add_argument( 626 | "--dataloader_num_workers", 627 | type=int, 628 | default=0, 629 | help=( 630 | "Number of subprocesses to use for data loading. 0 means that the data will be loaded in the main process." 
631 | ), 632 | ) 633 | parser.add_argument("--adam_beta1", type=float, default=0.9, help="The beta1 parameter for the Adam optimizer.") 634 | parser.add_argument("--adam_beta2", type=float, default=0.999, help="The beta2 parameter for the Adam optimizer.") 635 | parser.add_argument("--adam_weight_decay", type=float, default=1e-2, help="Weight decay to use.") 636 | parser.add_argument("--adam_epsilon", type=float, default=1e-08, help="Epsilon value for the Adam optimizer") 637 | parser.add_argument("--max_grad_norm", default=1.0, type=float, help="Max gradient norm.") 638 | parser.add_argument("--push_to_hub", action="store_true", help="Whether or not to push the model to the Hub.") 639 | parser.add_argument("--hub_token", type=str, default=None, help="The token to use to push to the Model Hub.") 640 | parser.add_argument( 641 | "--prediction_type", 642 | type=str, 643 | default=None, 644 | help="The prediction_type that shall be used for training. Choose between 'epsilon' or 'v_prediction' or leave `None`. If left to `None` the default prediction type of the scheduler: `noise_scheduler.config.prediciton_type` is chosen.", 645 | ) 646 | parser.add_argument( 647 | "--hub_model_id", 648 | type=str, 649 | default=None, 650 | help="The name of the repository to keep in sync with the local `output_dir`.", 651 | ) 652 | parser.add_argument( 653 | "--logging_dir", 654 | type=str, 655 | default="logs", 656 | help=( 657 | "[TensorBoard](https://www.tensorflow.org/tensorboard) log directory. Will default to" 658 | " *output_dir/runs/**CURRENT_DATETIME_HOSTNAME***." 659 | ), 660 | ) 661 | parser.add_argument( 662 | "--mixed_precision", 663 | type=str, 664 | default=None, 665 | choices=["no", "fp16", "bf16"], 666 | help=( 667 | "Whether to use mixed precision. Choose between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >=" 668 | " 1.10.and an Nvidia Ampere GPU. Default to the value of accelerate config of the current system or the" 669 | " flag passed with the `accelerate.launch` command. Use this argument to override the accelerate config." 670 | ), 671 | ) 672 | parser.add_argument( 673 | "--report_to", 674 | type=str, 675 | default="tensorboard", 676 | help=( 677 | 'The integration to report the results and logs to. Supported platforms are `"tensorboard"`' 678 | ' (default), `"wandb"` and `"comet_ml"`. Use `"all"` to report to all integrations.' 679 | ), 680 | ) 681 | parser.add_argument("--local_rank", type=int, default=-1, help="For distributed training: local_rank") 682 | parser.add_argument( 683 | "--checkpointing_steps", 684 | type=int, 685 | default=500, 686 | help=( 687 | "Save a checkpoint of the training state every X updates. These checkpoints are only suitable for resuming" 688 | " training using `--resume_from_checkpoint`." 689 | ), 690 | ) 691 | parser.add_argument( 692 | "--checkpoints_total_limit", 693 | type=int, 694 | default=None, 695 | help=("Max number of checkpoints to store."), 696 | ) 697 | parser.add_argument( 698 | "--resume_from_checkpoint", 699 | type=str, 700 | default=None, 701 | help=( 702 | "Whether training should be resumed from a previous checkpoint. Use a path saved by" 703 | ' `--checkpointing_steps`, or `"latest"` to automatically select the last available checkpoint.' 704 | ), 705 | ) 706 | parser.add_argument( 707 | "--enable_xformers_memory_efficient_attention", action="store_true", help="Whether or not to use xformers." 
708 | ) 709 | parser.add_argument("--noise_offset", type=float, default=0, help="The scale of noise offset.") 710 | parser.add_argument( 711 | "--validation_epochs", 712 | type=int, 713 | default=5, 714 | help="Run validation every X epochs.", 715 | ) 716 | parser.add_argument( 717 | "--tracker_project_name", 718 | type=str, 719 | default="text2image-fine-tune", 720 | help=( 721 | "The `project_name` argument passed to Accelerator.init_trackers for" 722 | " more information see https://huggingface.co/docs/accelerate/v0.17.0/en/package_reference/accelerator#accelerate.Accelerator" 723 | ), 724 | ) 725 | parser.add_argument( 726 | "--gaudi_config_name", 727 | type=str, 728 | default=None, 729 | help="Local path to the Gaudi configuration file or its name on the Hugging Face Hub.", 730 | ) 731 | parser.add_argument( 732 | "--throughput_warmup_steps", 733 | type=int, 734 | default=0, 735 | help=( 736 | "Number of steps to ignore for throughput calculation. For example, with throughput_warmup_steps=N, the" 737 | " first N steps will not be considered in the calculation of the throughput. This is especially useful in" 738 | " lazy mode." 739 | ), 740 | ) 741 | parser.add_argument( 742 | "--bf16", 743 | action="store_true", 744 | default=False, 745 | help=("Whether to use bf16 mixed precision."), 746 | ) 747 | parser.add_argument( 748 | "--device", 749 | type=str, 750 | default=None, 751 | help=("hpu, cuda or cpu."), 752 | ) 753 | parser.add_argument( 754 | "--train_text_encoder", 755 | action="store_true", 756 | help="Whether to train the text encoder. If set, the text encoder should be float32 precision.", 757 | ) 758 | parser.add_argument( 759 | "--tokenizer_name", 760 | type=str, 761 | default=None, 762 | help="Pretrained tokenizer name or path if not the same as model_name", 763 | ) 764 | parser.add_argument( 765 | "--freeze_text_encoder_steps", 766 | type=int, 767 | default=0, 768 | help="Start text_encoder training after freeze_text_encoder_steps steps.", 769 | ) 770 | parser.add_argument("--comment", type=str, default="used long sentences generated by llava - spatial and general.", help="Comment that should appear in the run config") 771 | parser.add_argument("--git_token", type=str, default=None, help="If provided will enable to save the git sha to replicate") 772 | parser.add_argument("--general_caption", type=str, default="original_caption", choices = ["coca_caption", "original_caption"], 773 | help="Original are the oned from the original dataset, coca_caption is the one generated by COCA" \ 774 | "in case original is chosen, the original caption will be preffered as general_caption, if it does not exist than the general caption will be the coca caption") 775 | parser.add_argument("--spatial_caption_type", type=str, default="long", choices = ["short", "long", "short_negative"], help="Wheter to use long or short spatial captions") 776 | parser.add_argument( 777 | "--spatial_percent", 778 | type=float, 779 | default=50.0, 780 | help="approximately precentage of the time that spatial captions is chosen.", 781 | ) 782 | 783 | 784 | args = parser.parse_args() 785 | if args.resume_from_checkpoint is None: 786 | args.output_dir = os.path.join(args.output_dir , f"run_{datetime.datetime.now().strftime('%Y-%m-%d_%H-%M-%S')}") 787 | env_local_rank = int(os.environ.get("LOCAL_RANK", -1)) 788 | if env_local_rank != -1 and env_local_rank != args.local_rank: 789 | args.local_rank = env_local_rank 790 | 791 | # default to using the same revision for the non-ema model if not specified 792 | if 
args.non_ema_revision is None: 793 | args.non_ema_revision = args.revision 794 | 795 | return args 796 | 797 | 798 | def main(): 799 | args = parse_args() 800 | 801 | url_train = None 802 | 803 | if args.spright_splits is not None and args.spright_train_costum is not None: 804 | warnings.warn("You cannot specify the splits by both spright_splits and spright_train_costum. " \ 805 | "The custom split will be used. If you want to use the SPRIGHT splits, remove the spright_train_costum argument.") 806 | 807 | 808 | if args.report_to == "wandb" and args.hub_token is not None: 809 | raise ValueError( 810 | "You cannot use both --report_to=wandb and --hub_token due to a security risk of exposing your token." 811 | " Please use `huggingface-cli login` to authenticate with the Hub." 812 | ) 813 | 814 | # set device 815 | if hthpu and hthpu.is_available(): 816 | args.device = "hpu" 817 | logger.info("Using HPU") 818 | elif torch.cuda.is_available(): 819 | args.device = "cuda" 820 | logger.info("Using GPU") 821 | else: 822 | args.device = "cpu" 823 | logger.info("Using CPU") 824 | 825 | # set precision: 826 | if args.device == "hpu": 827 | if args.mixed_precision == "bf16": 828 | args.bf16 = True 829 | else: 830 | args.bf16 = False 831 | 832 | # set args for gaudi: 833 | assert not args.enable_xformers_memory_efficient_attention, "xformers is not supported on gaudi" 834 | assert not args.allow_tf32, "tf32 is not supported on gaudi" 835 | assert not args.gradient_checkpointing, "gradient_checkpointing is not supported on gaudi locally" 836 | assert not args.push_to_hub, "push_to_hub is not supported on gaudi locally" 837 | 838 | else: 839 | assert args.gaudi_config_name is None, "gaudi_config_name is only supported on gaudi" 840 | assert args.throughput_warmup_steps == 0, "throughput_warmup_steps is only supported on gaudi" 841 | 842 | 843 | if args.non_ema_revision is not None: 844 | deprecate( 845 | "non_ema_revision!=None", 846 | "0.15.0", 847 | message=( 848 | "Downloading 'non_ema' weights from revision branches of the Hub is deprecated. Please make sure to" 849 | " use `--variant=non_ema` instead." 850 | ), 851 | ) 852 | logging_dir = os.path.join(args.output_dir, args.logging_dir) 853 | 854 | accelerator_project_config = ProjectConfiguration(project_dir=args.output_dir, logging_dir=logging_dir) 855 | 856 | if args.device == "hpu": 857 | gaudi_config = GaudiConfig.from_pretrained(args.gaudi_config_name) 858 | if args.use_8bit_adam: 859 | gaudi_config.use_fused_adam = True 860 | args.use_8bit_adam = False 861 | 862 | accelerator = GaudiAccelerator( 863 | gradient_accumulation_steps=args.gradient_accumulation_steps, 864 | mixed_precision="bf16" if gaudi_config.use_torch_autocast or args.bf16 else "no", 865 | log_with=args.report_to, 866 | project_config=accelerator_project_config, 867 | force_autocast=gaudi_config.use_torch_autocast or args.bf16, 868 | ) 869 | else: 870 | accelerator = Accelerator( 871 | gradient_accumulation_steps=args.gradient_accumulation_steps, 872 | mixed_precision=args.mixed_precision, 873 | log_with=args.report_to, 874 | project_config=accelerator_project_config, 875 | ) 876 | 877 | # Make one log on every process with the configuration for debugging.
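    # The logging setup below emits the accelerator state once per process; the main
    # process keeps `datasets`/`transformers` at warning level and `diffusers` at info,
    # while all other ranks are reduced to errors only, so multi-GPU runs do not repeat
    # every message once per rank.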
878 | logging.basicConfig( 879 | format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", 880 | datefmt="%m/%d/%Y %H:%M:%S", 881 | level=logging.INFO, 882 | ) 883 | logger.info(accelerator.state, main_process_only=False) 884 | if accelerator.is_local_main_process: 885 | datasets.utils.logging.set_verbosity_warning() 886 | transformers.utils.logging.set_verbosity_warning() 887 | diffusers.utils.logging.set_verbosity_info() 888 | else: 889 | datasets.utils.logging.set_verbosity_error() 890 | transformers.utils.logging.set_verbosity_error() 891 | diffusers.utils.logging.set_verbosity_error() 892 | 893 | # If passed along, set the training seed now. 894 | if args.seed is not None: 895 | set_seed(args.seed) 896 | 897 | # Handle the repository creation 898 | if accelerator.is_main_process: 899 | if args.output_dir is not None: 900 | os.makedirs(args.output_dir, exist_ok=True) 901 | 902 | if args.push_to_hub: 903 | repo_id = create_repo( 904 | repo_id=args.hub_model_id or Path(args.output_dir).name, exist_ok=True, token=args.hub_token 905 | ).repo_id 906 | 907 | # Load scheduler, tokenizer and models. 908 | if args.device == "hpu": 909 | noise_scheduler = GaudiDDIMScheduler.from_pretrained(args.pretrained_model_name_or_path, subfolder="scheduler") 910 | else: 911 | noise_scheduler = DDPMScheduler.from_pretrained(args.pretrained_model_name_or_path, subfolder="scheduler") 912 | 913 | tokenizer = CLIPTokenizer.from_pretrained( 914 | args.pretrained_model_name_or_path, subfolder="tokenizer", revision=args.revision 915 | ) 916 | 917 | def deepspeed_zero_init_disabled_context_manager(): 918 | """ 919 | returns either a context list that includes one that will disable zero.Init or an empty context list 920 | """ 921 | deepspeed_plugin = AcceleratorState().deepspeed_plugin if accelerate.state.is_initialized() else None 922 | if deepspeed_plugin is None: 923 | return [] 924 | 925 | return [deepspeed_plugin.zero3_init_context_manager(enable=False)] 926 | 927 | # Currently Accelerate doesn't know how to handle multiple models under Deepspeed ZeRO stage 3. 928 | # For this to work properly all models must be run through `accelerate.prepare`. But accelerate 929 | # will try to assign the same optimizer with the same weights to all models during 930 | # `deepspeed.initialize`, which of course doesn't work. 931 | # 932 | # For now the following workaround will partially support Deepspeed ZeRO-3, by excluding the 2 933 | # frozen models from being partitioned during `zero.Init` which gets called during 934 | # `from_pretrained` So CLIPTextModel and AutoencoderKL will not enjoy the parameter sharding 935 | # across multiple gpus and only UNet2DConditionModel will get ZeRO sharded. 
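    # When no DeepSpeed plugin is active, `deepspeed_zero_init_disabled_context_manager()`
    # returns an empty list, so the `ContextManagers([...])` wrapper below is a no-op and
    # plain DDP / single-device runs load the frozen text encoder and VAE as usual.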
936 | if not args.train_text_encoder: 937 | with ContextManagers(deepspeed_zero_init_disabled_context_manager()): 938 | text_encoder = CLIPTextModel.from_pretrained( 939 | args.pretrained_model_name_or_path, subfolder="text_encoder", revision=args.revision 940 | ).to(accelerator.device) 941 | vae = AutoencoderKL.from_pretrained( 942 | args.pretrained_model_name_or_path, subfolder="vae", revision=args.revision 943 | ) 944 | else: 945 | # import correct text encoder class 946 | text_encoder_cls = import_model_class_from_model_name_or_path(args.pretrained_model_name_or_path, args.revision) 947 | text_encoder = text_encoder_cls.from_pretrained( 948 | args.pretrained_model_name_or_path, subfolder="text_encoder", revision=args.revision) 949 | vae = AutoencoderKL.from_pretrained( 950 | args.pretrained_model_name_or_path, subfolder="vae", revision=args.revision 951 | ) 952 | 953 | 954 | unet = UNet2DConditionModel.from_pretrained( 955 | args.pretrained_model_name_or_path, subfolder="unet", revision=args.non_ema_revision 956 | ) 957 | 958 | # Freeze vae and text_encoder and set unet to trainable 959 | vae.requires_grad_(False) 960 | if not args.train_text_encoder: 961 | text_encoder.requires_grad_(False) 962 | unet.train() 963 | 964 | # Create EMA for the unet. 965 | if args.use_ema: 966 | ema_unet = UNet2DConditionModel.from_pretrained( 967 | args.pretrained_model_name_or_path, subfolder="unet", revision=args.revision 968 | ) 969 | ema_unet = EMAModel(ema_unet.parameters(), model_cls=UNet2DConditionModel, model_config=ema_unet.config) 970 | 971 | if args.enable_xformers_memory_efficient_attention: 972 | if is_xformers_available(): 973 | import xformers 974 | 975 | xformers_version = version.parse(xformers.__version__) 976 | if xformers_version == version.parse("0.0.16"): 977 | logger.warn( 978 | "xFormers 0.0.16 cannot be used for training in some GPUs. If you observe problems during training, please update xFormers to at least 0.0.17. See https://huggingface.co/docs/diffusers/main/en/optimization/xformers for more details." 979 | ) 980 | unet.enable_xformers_memory_efficient_attention() 981 | else: 982 | raise ValueError("xformers is not available. 
Make sure it is installed correctly") 983 | 984 | def unwrap_model(model): 985 | model = accelerator.unwrap_model(model) 986 | model = model._orig_mod if is_compiled_module(model) else model 987 | return model 988 | 989 | # `accelerate` 0.16.0 will have better support for customized saving 990 | if version.parse(accelerate.__version__) >= version.parse("0.16.0"): 991 | # create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format 992 | def save_model_hook(models, weights, output_dir): 993 | if accelerator.is_main_process: 994 | if args.use_ema: 995 | ema_unet.save_pretrained(os.path.join(output_dir, "unet_ema")) 996 | 997 | for model in models: 998 | sub_dir = "unet" if isinstance(model, type(unwrap_model(unet))) else "text_encoder" 999 | model.save_pretrained(os.path.join(output_dir, sub_dir)) 1000 | 1001 | # make sure to pop weight so that corresponding model is not saved again 1002 | weights.pop() 1003 | 1004 | def load_model_hook(models, input_dir): 1005 | if args.use_ema: 1006 | load_model = EMAModel.from_pretrained(os.path.join(input_dir, "unet_ema"), UNet2DConditionModel) 1007 | ema_unet.load_state_dict(load_model.state_dict()) 1008 | ema_unet.to(accelerator.device) 1009 | del load_model 1010 | 1011 | for i in range(len(models)): 1012 | # pop models so that they are not loaded again 1013 | model = models.pop() 1014 | 1015 | if isinstance(model, type(unwrap_model(text_encoder))): 1016 | # load transformers style into model 1017 | load_model = text_encoder_cls.from_pretrained(input_dir, subfolder="text_encoder") 1018 | model.config = load_model.config 1019 | else: 1020 | # load diffusers style into model 1021 | load_model = UNet2DConditionModel.from_pretrained(input_dir, subfolder="unet") 1022 | model.register_to_config(**load_model.config) 1023 | 1024 | model.load_state_dict(load_model.state_dict()) 1025 | del load_model 1026 | 1027 | accelerator.register_save_state_pre_hook(save_model_hook) 1028 | accelerator.register_load_state_pre_hook(load_model_hook) 1029 | 1030 | if args.gradient_checkpointing: 1031 | unet.enable_gradient_checkpointing() 1032 | if args.train_text_encoder: 1033 | text_encoder.gradient_checkpointing_enable() 1034 | 1035 | # Check that all trainable models are in full precision 1036 | low_precision_error_string = ( 1037 | "Please make sure to always have all model weights in full float32 precision when starting training - even if" 1038 | " doing mixed precision training. copy of the weights should still be float32." 1039 | ) 1040 | 1041 | if unwrap_model(unet).dtype != torch.float32: 1042 | raise ValueError(f"Unet loaded as datatype {unwrap_model(unet).dtype}. {low_precision_error_string}") 1043 | 1044 | if args.train_text_encoder and unwrap_model(text_encoder).dtype != torch.float32: 1045 | raise ValueError( 1046 | f"Text encoder loaded as datatype {unwrap_model(text_encoder).dtype}." 
f" {low_precision_error_string}" 1047 | ) 1048 | 1049 | # Enable TF32 for faster training on Ampere GPUs, 1050 | # cf https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices 1051 | if args.allow_tf32: 1052 | torch.backends.cuda.matmul.allow_tf32 = True 1053 | 1054 | if args.scale_lr: 1055 | args.learning_rate = ( 1056 | args.learning_rate * args.gradient_accumulation_steps * args.train_batch_size * accelerator.num_processes 1057 | ) 1058 | 1059 | # Initialize the optimizer 1060 | if args.use_8bit_adam: 1061 | try: 1062 | import bitsandbytes as bnb 1063 | except ImportError: 1064 | raise ImportError( 1065 | "Please install bitsandbytes to use 8-bit Adam. You can do so by running `pip install bitsandbytes`" 1066 | ) 1067 | 1068 | optimizer_cls = bnb.optim.AdamW8bit 1069 | elif args.device == "hpu" and gaudi_config.use_fused_adam: 1070 | from habana_frameworks.torch.hpex.optimizers import FusedAdamW 1071 | 1072 | optimizer_cls = FusedAdamW 1073 | else: 1074 | optimizer_cls = torch.optim.AdamW 1075 | # setting different lr for text encoder and unet: 1076 | if args.train_text_encoder: 1077 | unet_parameters_with_lr = {"params": unet.parameters(), "lr": args.learning_rate} 1078 | text_encoder_params_with_lr = { 1079 | "params": text_encoder.parameters(), 1080 | "lr": args.text_encoder_lr if args.text_encoder_lr else args.learning_rate 1081 | } 1082 | params_to_optimize = [ 1083 | unet_parameters_with_lr, 1084 | text_encoder_params_with_lr 1085 | ] 1086 | print(f"Learning rates were provided both for the unet and the text encoder, e.g. text_encoder_lr:" 1087 | f" {args.text_encoder_lr} and learning_rate: {args.learning_rate}. ") 1088 | 1089 | 1090 | else: 1091 | params_to_optimize = [{"params": unet.parameters(), "lr": args.learning_rate}] 1092 | optimizer = optimizer_cls( 1093 | params_to_optimize, 1094 | betas=(args.adam_beta1, args.adam_beta2), 1095 | weight_decay=args.adam_weight_decay, 1096 | eps=args.adam_epsilon, 1097 | ) 1098 | 1099 | 1100 | 1101 | # Get the datasets: you can either provide your own training and evaluation files (see below) 1102 | # or specify a Dataset from the hub (the dataset will be downloaded automatically from the datasets Hub). 1103 | 1104 | # In distributed training, the load_dataset function guarantees that only one local process can concurrently 1105 | # download the dataset. 1106 | if args.dataset_name is not None: 1107 | # Downloading and loading a dataset from the hub.
1108 | dataset = load_dataset( 1109 | args.dataset_name, 1110 | args.dataset_config_name, 1111 | cache_dir=args.cache_dir, 1112 | data_dir=args.train_data_dir, 1113 | download_mode=DownloadMode.FORCE_REDOWNLOAD 1114 | ) 1115 | elif args.train_data_dir is not None: 1116 | data_files = {} 1117 | data_files["train"] = os.path.join(args.train_data_dir, "**") 1118 | dataset = load_dataset( 1119 | "imagefolder", 1120 | data_files=data_files, 1121 | cache_dir=args.cache_dir, 1122 | ) 1123 | # See more about loading custom images at 1124 | # https://huggingface.co/docs/datasets/v2.4.0/en/image_load#imagefolder 1125 | elif args.dataloader is not None: 1126 | dataset = load_dataset( 1127 | args.dataloader, 1128 | data_dir=args.train_metadata_dir, 1129 | cache_dir=args.cache_dir 1130 | ) 1131 | elif args.spright_splits is not None: 1132 | # load the json files and read the train and val splits 1133 | if args.spright_splits.endswith(".json"): 1134 | with open(args.spright_splits, 'r') as f: 1135 | data = json.load(f) 1136 | # Filter the entries where split is 'train' 1137 | train_files = [] 1138 | train_len = 0 1139 | val_files = [] 1140 | val_len = 0 1141 | test_files = [] 1142 | test_len = 0 1143 | if args.spright_train_costum is not None: 1144 | url_train = args.spright_train_costum 1145 | start, end = map(int, url_train.split('/')[-1].strip('.tar{}').split('..')) 1146 | # Filter the data for the files in the range and sum the sizes 1147 | train_len = sum(item['size'] for item in data if start <= int(item['file'].strip('.tar')) <= end) 1148 | if args.spright_val_costum is not None: 1149 | url_val = args.spright_val_costum 1150 | start_val, end_val = map(int, url_val.split('/')[-1].strip('.tar{}').split('..')) 1151 | val_len = sum(item['size'] for item in data if start_val <= int(item['file'].strip('.tar')) <= end_val) 1152 | else: 1153 | url_val = None 1154 | else: 1155 | for item in data: 1156 | if item['split'] == 'train': 1157 | train_files.append(item['file'].split(".")[0]) 1158 | train_len += item['size'] 1159 | elif item['split'] == 'val': 1160 | val_files.append(item['file'].split(".")[0]) 1161 | val_len += item['size'] 1162 | elif item['split'] == 'test': 1163 | test_files.append(item['file'].split(".")[0]) 1164 | test_len += item['size'] 1165 | ext = data[0]['file'].split(".")[-1] 1166 | # Construct the url_train string 1167 | if len(train_files) != 0: 1168 | url_train = "/export/share/projects/mcai/spatial_data/spright/data/{" + ','.join(train_files) + "}" + f".{ext}" 1169 | if len(val_files) != 0: 1170 | url_val = "/export/share/projects/mcai/spatial_data/spright/data/{" + ','.join(val_files) + "}" + f".{ext}" 1171 | else: 1172 | url_val = None 1173 | else: 1174 | raise ValueError("'webdataset' should be a json file containing the train and val splits and their sizes") 1175 | 1176 | # Preprocessing the datasets. 1177 | # We need to tokenize inputs and targets. 1178 | # column_names = dataset["train"].column_names 1179 | # 6. Get the column names for input/target. 1180 | if args.spatial_caption_type == "short": 1181 | spatial_caption = 'short_spatial_caption' 1182 | elif args.spatial_caption_type == "long": 1183 | spatial_caption = 'spatial_caption' 1184 | elif args.spatial_caption_type == "short_negative": 1185 | spatial_caption = 'short_spatial_caption_negation' 1186 | 1187 | # Preprocessing the datasets. 1188 | # We need to tokenize input captions and transform the images.
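    # Caption selection (implemented in `tokenize_captions` below): each example uses either
    # the general caption or the spatial caption, with the spatial column drawn roughly
    # `args.spatial_percent` percent of the time via
    # `random.choices(..., weights=[100 - spatial_percent, spatial_percent])`.
    # When `--spatial_caption_type short` is selected, `get_random_sentence` is intended to
    # sample a single sentence (preferring one after the first) out of the long spatial caption.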
1189 | 1190 | def get_random_sentence(caption): 1191 | if isinstance(caption, list): 1192 | caption = caption[0] 1193 | sentences_split = caption.split(".") 1194 | first_sentence = sentences_split[0] 1195 | sentences_split = sentences_split[1:] 1196 | if len(sentences_split)>1: 1197 | random_sentence = random.choice(sentences_split) 1198 | if len(random_sentence) == 0: 1199 | random_sentence = sentences_split[0] 1200 | else: 1201 | random_sentence = first_sentence 1202 | return random_sentence 1203 | 1204 | 1205 | def tokenize_captions(examples, is_train=True): 1206 | # check if its a webdataset 1207 | if url_train is not None: 1208 | examples = examples["captions"] 1209 | 1210 | column_lists = [args.general_caption, 'spatial_caption'] 1211 | 1212 | caption_column = random.choices( 1213 | column_lists, 1214 | weights=[100-args.spatial_percent, args.spatial_percent], 1215 | k=1 1216 | )[0] 1217 | 1218 | # check if the caption exists 1219 | if args.general_caption=="original_caption" and caption_column==args.general_caption: 1220 | if caption_column not in examples.keys(): 1221 | caption_column = "coca_caption" 1222 | elif args.spatial_caption_type == "short" and caption_column==spatial_caption: 1223 | examples['short_spatial_caption'] = get_random_sentence(examples[caption_column]) 1224 | caption_column = 'short_spatial_caption' 1225 | 1226 | captions = [] 1227 | # check if the caption is a list 1228 | if not isinstance(examples[caption_column], list): 1229 | examples[caption_column] = [examples[caption_column]] 1230 | 1231 | for caption in examples[caption_column]: 1232 | if isinstance(caption, str): 1233 | captions.append(caption) 1234 | if debug: 1235 | with open(os.path.join(args.output_dir, f"selected_captions{args.spatial_percent}.txt"), "a") as f: 1236 | f.write(f"{caption}\n") 1237 | elif isinstance(caption, (list, np.ndarray)): 1238 | # take a random caption if there are multiple 1239 | captions.append(random.choice(caption) if is_train else caption[0]) 1240 | else: 1241 | raise ValueError( 1242 | f"Caption column `{caption_column}` should contain either strings or lists of strings." 1243 | ) 1244 | inputs = tokenizer( 1245 | captions, max_length=tokenizer.model_max_length, padding="max_length", truncation=True, return_tensors="pt" 1246 | ) 1247 | return inputs.input_ids 1248 | 1249 | # Preprocessing the datasets. 
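    # Image preprocessing: when `--pre_crop_resolution` is larger than `--resolution`,
    # images are first resized to the pre-crop size and then cropped (center or random)
    # down to the training resolution, so the crop keeps more of the original context;
    # otherwise a plain resize-then-crop pipeline is used. Both variants optionally apply
    # a horizontal flip and normalize pixel values to [-1, 1].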
1250 | if args.pre_crop_resolution and args.pre_crop_resolution > args.resolution: 1251 | train_transforms = transforms.Compose( 1252 | [ 1253 | transforms.Resize(args.pre_crop_resolution, interpolation=transforms.InterpolationMode.BILINEAR), 1254 | transforms.CenterCrop(args.resolution) if args.center_crop else transforms.RandomCrop(args.resolution), 1255 | transforms.Resize(args.resolution, interpolation=transforms.InterpolationMode.BILINEAR), 1256 | transforms.RandomHorizontalFlip() if args.random_flip else transforms.Lambda(lambda x: x), 1257 | transforms.ToTensor(), 1258 | transforms.Normalize([0.5], [0.5]), 1259 | ] 1260 | ) 1261 | else: 1262 | train_transforms = transforms.Compose( 1263 | [ 1264 | transforms.Resize(args.resolution, interpolation=transforms.InterpolationMode.BILINEAR), 1265 | transforms.CenterCrop(args.resolution) if args.center_crop else transforms.RandomCrop(args.resolution), 1266 | transforms.RandomHorizontalFlip() if args.random_flip else transforms.Lambda(lambda x: x), 1267 | transforms.ToTensor(), 1268 | transforms.Normalize([0.5], [0.5]), 1269 | ] 1270 | ) 1271 | 1272 | def preprocess_train(examples): 1273 | if url_train is not None: 1274 | examples["image"] = [examples["image"]] 1275 | images = [image.convert("RGB") for image in examples["image"]] 1276 | examples["pixel_values"] = [train_transforms(image) for image in images] 1277 | examples["input_ids"] = tokenize_captions(examples) 1278 | return examples 1279 | 1280 | def preprocess_train_back(examples): 1281 | images = [image.convert("RGB") for image in examples["image"]] 1282 | examples["pixel_values"] = [train_transforms(image) for image in images] 1283 | examples["input_ids"] = tokenize_captions(examples) 1284 | return examples 1285 | 1286 | def collate_fn(examples): 1287 | pixel_values = torch.stack([example["pixel_values"] for example in examples]) 1288 | pixel_values = pixel_values.to(memory_format=torch.contiguous_format).float() 1289 | input_ids = torch.stack([example["input_ids"] for example in examples]) 1290 | return {"pixel_values": pixel_values, "input_ids": input_ids} 1291 | 1292 | def collate_fn_webdataset(examples): 1293 | examples = [preprocess_train(example) for example in examples] 1294 | pixel_values = torch.stack([example["pixel_values"][0] for example in examples]) 1295 | pixel_values = pixel_values.to(memory_format=torch.contiguous_format).float() 1296 | input_ids = torch.stack([example["input_ids"] for example in examples]) 1297 | return {"pixel_values": pixel_values, "input_ids": input_ids} 1298 | 1299 | with accelerator.main_process_first(): 1300 | if url_train is None: 1301 | if args.max_train_samples is not None: 1302 | dataset["train"] = dataset["train"].shuffle(seed=args.seed).select(range(args.max_train_samples)) 1303 | train_dataset = dataset["train"].with_transform(preprocess_train) 1304 | train_dataloader = torch.utils.data.DataLoader( 1305 | train_dataset, 1306 | shuffle=True, 1307 | collate_fn=collate_fn, 1308 | batch_size=args.train_batch_size, 1309 | num_workers=args.dataloader_num_workers, 1310 | ) 1311 | if "validation" in dataset: 1312 | val_dataloader = torch.utils.data.DataLoader( 1313 | dataset["validation"].with_transform(preprocess_train), 1314 | shuffle=True, 1315 | collate_fn=collate_fn, 1316 | batch_size=args.train_batch_size, 1317 | num_workers=0, 1318 | ) 1319 | else: 1320 | val_dataloader = None 1321 | else: 1322 | dataset = {"train": wds.WebDataset(url_train).shuffle(args.webdataset_buffer_size).decode("pil", 
handler=wds.warn_and_continue).rename(captions="json",image="jpg",metadata="metadata.json",handler=wds.warn_and_continue,)} 1323 | train_dataloader = torch.utils.data.DataLoader(dataset["train"], collate_fn=collate_fn_webdataset, batch_size=args.train_batch_size, num_workers=args.dataloader_num_workers) 1324 | if url_val: 1325 | dataset["val"] = wds.WebDataset(url_val).shuffle(args.webdataset_buffer_size).decode("pil", handler=wds.warn_and_continue). rename(captions="json",image="jpg",metadata="metadata.json",handler=wds.warn_and_continue,) 1326 | val_dataloader = torch.utils.data.DataLoader(dataset["val"], collate_fn=collate_fn_webdataset, batch_size=args.train_batch_size, num_workers=0) 1327 | else: 1328 | val_dataloader = None 1329 | if args.max_train_samples is not None: 1330 | dataset = ( 1331 | wds.WebDataset(url_train, shardshuffle=True) 1332 | .shuffle(args.webdataset_buffer_size) 1333 | .decode() 1334 | ) 1335 | # Set the training transforms 1336 | train_dataset = dataset["train"] 1337 | 1338 | # Scheduler and math around the number of training steps. 1339 | overrode_max_train_steps = False 1340 | if url_train is None: 1341 | train_len = len(train_dataset) 1342 | train_dataloader_len = len(train_dataloader) 1343 | assert len(train_dataloader) == len(train_dataset) / args.train_batch_size, (len(train_dataloader), len(train_dataset)) 1344 | else: 1345 | train_dataloader_len = train_len / args.train_batch_size 1346 | 1347 | num_update_steps_per_epoch = math.ceil(train_dataloader_len / args.gradient_accumulation_steps) 1348 | 1349 | if args.max_train_steps is None: 1350 | args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch 1351 | overrode_max_train_steps = True 1352 | 1353 | lr_scheduler = get_scheduler( 1354 | args.lr_scheduler, 1355 | optimizer=optimizer, 1356 | num_warmup_steps=args.lr_warmup_steps * accelerator.num_processes, 1357 | num_training_steps=args.max_train_steps * accelerator.num_processes, 1358 | ) 1359 | 1360 | if not args.train_text_encoder: 1361 | unet.to(accelerator.device) 1362 | 1363 | # Prepare everything with our `accelerator`. 1364 | if val_dataloader is not None: 1365 | if args.train_text_encoder: 1366 | unet, text_encoder, optimizer, train_dataloader, val_dataloader, lr_scheduler = accelerator.prepare( 1367 | unet, text_encoder, optimizer, train_dataloader,val_dataloader, lr_scheduler 1368 | ) 1369 | else: 1370 | unet, optimizer, train_dataloader, val_dataloader, lr_scheduler = accelerator.prepare( 1371 | unet, optimizer, train_dataloader, val_dataloader, lr_scheduler 1372 | ) 1373 | else: 1374 | if args.train_text_encoder: 1375 | unet, text_encoder, optimizer, train_dataloader, lr_scheduler = accelerator.prepare( 1376 | unet, text_encoder, optimizer, train_dataloader, lr_scheduler 1377 | ) 1378 | else: 1379 | unet, optimizer, train_dataloader, lr_scheduler = accelerator.prepare( 1380 | unet, optimizer, train_dataloader, lr_scheduler 1381 | ) 1382 | 1383 | if args.use_ema: 1384 | ema_unet.to(accelerator.device) 1385 | 1386 | # For mixed precision training we cast all non-trainable weigths (vae, non-lora text_encoder and non-lora unet) to half-precision 1387 | # as these weights are only used for inference, keeping weights in full precision is not required. 
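    # Dtype selection for the frozen modules (see below): float32 by default, float16 when
    # the accelerator runs fp16 mixed precision off-HPU, and bfloat16 either when the
    # accelerator is configured for bf16 or when running on HPU with torch autocast /
    # `--bf16` enabled.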
1388 | weight_dtype = torch.float32 1389 | if args.device != "hpu" and accelerator.mixed_precision == "fp16": 1390 | weight_dtype = torch.float16 1391 | args.mixed_precision = accelerator.mixed_precision 1392 | elif accelerator.mixed_precision == "bf16": 1393 | weight_dtype = torch.bfloat16 1394 | args.mixed_precision = accelerator.mixed_precision 1395 | elif args.device == "hpu" and gaudi_config.use_torch_autocast or args.bf16: 1396 | weight_dtype = torch.bfloat16 1397 | 1398 | # Move text_encode and vae to gpu and cast to weight_dtype 1399 | if not args.train_text_encoder and text_encoder is not None: 1400 | text_encoder.to(accelerator.device, dtype=weight_dtype) 1401 | vae.to(accelerator.device, dtype=weight_dtype) 1402 | 1403 | # We need to recalculate our total training steps as the size of the training dataloader may have changed. 1404 | num_update_steps_per_epoch = math.ceil(train_dataloader_len / args.gradient_accumulation_steps) 1405 | if overrode_max_train_steps: 1406 | args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch 1407 | # Afterwards we recalculate our number of training epochs 1408 | args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch) 1409 | 1410 | # We need to initialize the trackers we use, and also store our configuration. 1411 | # The trackers initializes automatically on the main process. 1412 | if accelerator.is_main_process: 1413 | tracker_config = dict(vars(args)) 1414 | tracker_config.pop("validation_prompts") 1415 | accelerator.init_trackers(args.tracker_project_name, tracker_config) 1416 | 1417 | # Train! 1418 | total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps 1419 | 1420 | logger.info("***** Running training *****") 1421 | logger.info(f" Num examples = {train_len}") 1422 | logger.info(f" Num Epochs = {args.num_train_epochs}") 1423 | logger.info(f" Instantaneous batch size per device = {args.train_batch_size}") 1424 | logger.info(f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}") 1425 | logger.info(f" Gradient Accumulation steps = {args.gradient_accumulation_steps}") 1426 | logger.info(f" Total optimization steps = {args.max_train_steps}") 1427 | global_step = 0 1428 | first_epoch = 0 1429 | 1430 | if accelerator.is_main_process: 1431 | run_config = vars(args) 1432 | with open(os.path.join(args.output_dir, "run_config.jsonl"), 'a') as f: 1433 | for key, value in run_config.items(): 1434 | json.dump({key: value}, f) 1435 | f.write('\n') 1436 | 1437 | # Potentially load in the weights and states from a previous save 1438 | if args.resume_from_checkpoint: 1439 | print(f"output_dir used for resuming is: {args.output_dir}") 1440 | if args.resume_from_checkpoint != "latest": 1441 | path = os.path.basename(args.resume_from_checkpoint) 1442 | else: 1443 | # Get the most recent checkpoint 1444 | dirs = os.listdir(args.output_dir) 1445 | dirs = [d for d in dirs if d.startswith("checkpoint")] 1446 | dirs = sorted(dirs, key=lambda x: int(x.split("-")[1])) 1447 | path = dirs[-1] if len(dirs) > 0 else None 1448 | 1449 | if path is None: 1450 | accelerator.print( 1451 | f"Checkpoint '{args.resume_from_checkpoint}' does not exist. Starting a new training run." 
1452 | ) 1453 | args.resume_from_checkpoint = None 1454 | initial_global_step = 0 1455 | else: 1456 | accelerator.print(f"Resuming from checkpoint {path}") 1457 | accelerator.load_state(os.path.join(args.output_dir, path)) 1458 | global_step = int(path.split("-")[1]) 1459 | 1460 | initial_global_step = global_step 1461 | first_epoch = global_step // num_update_steps_per_epoch 1462 | 1463 | else: 1464 | initial_global_step = 0 1465 | 1466 | progress_bar = tqdm( 1467 | range(0, args.max_train_steps), 1468 | initial=initial_global_step, 1469 | desc="Steps", 1470 | # Only show the progress bar once on each machine. 1471 | disable=not accelerator.is_local_main_process, 1472 | ) 1473 | 1474 | if args.device == "hpu": 1475 | t0 = None 1476 | 1477 | # saving the model before training 1478 | if args.device == "hpu": 1479 | pipeline = GaudiStableDiffusionPipeline.from_pretrained( 1480 | args.pretrained_model_name_or_path, 1481 | text_encoder=unwrap_model(text_encoder) if args.train_text_encoder else text_encoder, 1482 | vae=vae, 1483 | unet=unwrap_model(unet), 1484 | revision=args.revision, 1485 | scheduler=noise_scheduler, 1486 | ) 1487 | else: 1488 | pipeline = StableDiffusionPipeline.from_pretrained( 1489 | args.pretrained_model_name_or_path, 1490 | text_encoder=unwrap_model(text_encoder) if args.train_text_encoder else text_encoder, 1491 | vae=vae, 1492 | unet= unwrap_model(unet), 1493 | revision=args.revision, 1494 | ) 1495 | pipeline.save_pretrained(args.output_dir) 1496 | if accelerator.is_main_process: 1497 | log_validation( 1498 | vae, 1499 | text_encoder, 1500 | tokenizer, 1501 | unet, 1502 | args, 1503 | accelerator, 1504 | weight_dtype, 1505 | global_step, 1506 | val_dataloader, 1507 | noise_scheduler, 1508 | ) 1509 | text_train_active = False 1510 | for epoch in range(first_epoch, args.num_train_epochs): 1511 | unet.train() 1512 | if args.train_text_encoder and global_step > args.freeze_text_encoder_steps: 1513 | text_encoder.train() 1514 | train_loss = 0.0 1515 | print("epoch: ", epoch) 1516 | for step, batch in enumerate(train_dataloader): 1517 | if args.train_text_encoder and global_step > args.freeze_text_encoder_steps and not text_train_active: 1518 | text_encoder.train() 1519 | text_train_active = True 1520 | print("Text encoder training started at {} steps".format(global_step)) 1521 | 1522 | if args.device == "hpu": 1523 | if t0 is None and global_step == args.throughput_warmup_steps: 1524 | t0 = time.perf_counter() 1525 | 1526 | with accelerator.accumulate(unet): 1527 | # Convert images to latent space 1528 | latents = vae.encode(batch["pixel_values"].to(weight_dtype)).latent_dist.sample() 1529 | latents = latents * vae.config.scaling_factor 1530 | 1531 | # Sample noise that we'll add to the latents 1532 | noise = torch.randn_like(latents) 1533 | if args.noise_offset: 1534 | # https://www.crosslabs.org//blog/diffusion-with-offset-noise 1535 | noise += args.noise_offset * torch.randn( 1536 | (latents.shape[0], latents.shape[1], 1, 1), device=latents.device 1537 | ) 1538 | if args.input_perturbation: 1539 | new_noise = noise + args.input_perturbation * torch.randn_like(noise) 1540 | bsz = latents.shape[0] 1541 | # Sample a random timestep for each image 1542 | timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, (bsz,), device=latents.device) 1543 | timesteps = timesteps.long() 1544 | 1545 | # Add noise to the latents according to the noise magnitude at each timestep 1546 | # (this is the forward diffusion process) 1547 | if args.input_perturbation: 1548 | 
noisy_latents = noise_scheduler.add_noise(latents, new_noise, timesteps) 1549 | else: 1550 | noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps) 1551 | 1552 | # Get the text embedding for conditioning 1553 | encoder_hidden_states = text_encoder(batch["input_ids"])[0] 1554 | 1555 | # Get the target for loss depending on the prediction type 1556 | if args.prediction_type is not None: 1557 | # set prediction_type of scheduler if defined 1558 | noise_scheduler.register_to_config(prediction_type=args.prediction_type) 1559 | 1560 | if noise_scheduler.config.prediction_type == "epsilon": 1561 | target = noise 1562 | elif noise_scheduler.config.prediction_type == "v_prediction": 1563 | target = noise_scheduler.get_velocity(latents, noise, timesteps) 1564 | else: 1565 | raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}") 1566 | 1567 | # write prediction_type to run_config.json 1568 | if accelerator.is_main_process and epoch == 0 and step == 0: 1569 | with open(os.path.join(args.output_dir, "run_config.jsonl"), 'a') as f: 1570 | f.write(json.dumps({"prediction_type": noise_scheduler.config.prediction_type}) + '\n') 1571 | 1572 | # Predict the noise residual and compute loss 1573 | model_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample 1574 | 1575 | if args.snr_gamma is None: 1576 | loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean") 1577 | else: 1578 | # Compute loss-weights as per Section 3.4 of https://arxiv.org/abs/2303.09556. 1579 | # Since we predict the noise instead of x_0, the original formulation is slightly changed. 1580 | # This is discussed in Section 4.2 of the same paper. 1581 | snr = compute_snr(noise_scheduler, timesteps) 1582 | if noise_scheduler.config.prediction_type == "v_prediction": 1583 | # Velocity objective requires that we add one to SNR values before we divide by them. 1584 | snr = snr + 1 1585 | mse_loss_weights = ( 1586 | torch.stack([snr, args.snr_gamma * torch.ones_like(timesteps)], dim=1).min(dim=1)[0] / snr 1587 | ) 1588 | 1589 | loss = F.mse_loss(model_pred.float(), target.float(), reduction="none") 1590 | loss = loss.mean(dim=list(range(1, len(loss.shape)))) * mse_loss_weights 1591 | loss = loss.mean() 1592 | 1593 | # Gather the losses across all processes for logging (if we use distributed training). 
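                # `loss.repeat(args.train_batch_size)` followed by `accelerator.gather` collects
                # a copy of the per-process loss from every rank; the mean is accumulated into
                # `train_loss` and divided by `gradient_accumulation_steps`, so the value logged
                # at each optimizer step approximates the average loss over the effective batch.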
1594 | avg_loss = accelerator.gather(loss.repeat(args.train_batch_size)).mean() 1595 | train_loss += avg_loss.item() / args.gradient_accumulation_steps 1596 | 1597 | # Backpropagate 1598 | accelerator.backward(loss) 1599 | if accelerator.sync_gradients: 1600 | params_to_clip = ( 1601 | itertools.chain(unet.parameters(), text_encoder.parameters()) 1602 | if args.train_text_encoder and global_step > args.freeze_text_encoder_steps 1603 | else unet.parameters() 1604 | ) 1605 | accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm) 1606 | optimizer.step() 1607 | lr_scheduler.step() 1608 | 1609 | if args.device == "hpu": 1610 | optimizer.zero_grad(set_to_none=True) 1611 | htcore.mark_step() 1612 | else: 1613 | optimizer.zero_grad() 1614 | # Checks if the accelerator has performed an optimization step behind the scenes 1615 | if accelerator.sync_gradients: 1616 | if args.use_ema: 1617 | ema_unet.step(unet.parameters()) 1618 | progress_bar.update(1) 1619 | global_step += 1 1620 | accelerator.log({"training/train_loss": train_loss}, step=global_step) 1621 | accelerator.log({"hyperparameters/batch_size": args.train_batch_size}, step=global_step) 1622 | accelerator.log({"hyperparameters/effective_batch_size": total_batch_size}, step=global_step) 1623 | accelerator.log({"hyperparameters/learning_rate": lr_scheduler.get_last_lr()[0]}, step=global_step) 1624 | train_loss = 0.0 1625 | 1626 | if global_step % args.checkpointing_steps == 0: 1627 | if accelerator.is_main_process: 1628 | # _before_ saving state, check if this save would set us over the `checkpoints_total_limit` 1629 | if args.checkpoints_total_limit is not None: 1630 | checkpoints = os.listdir(args.output_dir) 1631 | checkpoints = [d for d in checkpoints if d.startswith("checkpoint")] 1632 | checkpoints = sorted(checkpoints, key=lambda x: int(x.split("-")[1])) 1633 | 1634 | # before we save the new checkpoint, we need to have at _most_ `checkpoints_total_limit - 1` checkpoints 1635 | if len(checkpoints) >= args.checkpoints_total_limit: 1636 | num_to_remove = len(checkpoints) - args.checkpoints_total_limit + 1 1637 | removing_checkpoints = checkpoints[0:num_to_remove] 1638 | 1639 | logger.info( 1640 | f"{len(checkpoints)} checkpoints already exist, removing {len(removing_checkpoints)} checkpoints" 1641 | ) 1642 | logger.info(f"removing checkpoints: {', '.join(removing_checkpoints)}") 1643 | 1644 | for removing_checkpoint in removing_checkpoints: 1645 | removing_checkpoint = os.path.join(args.output_dir, removing_checkpoint) 1646 | shutil.rmtree(removing_checkpoint) 1647 | 1648 | save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}") 1649 | accelerator.save_state(save_path) 1650 | logger.info(f"Saved state to {save_path}") 1651 | 1652 | logs = {"step_loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0]} 1653 | progress_bar.set_postfix(**logs) 1654 | 1655 | if global_step >= args.max_train_steps: 1656 | break 1657 | 1658 | if args.device == "hpu": 1659 | #gg gaudi addition 1660 | duration = time.perf_counter() - t0 1661 | throughput = args.max_train_steps * total_batch_size / duration 1662 | 1663 | if accelerator.is_main_process: 1664 | if args.device == "hpu": 1665 | logger.info(f"Throughput = {throughput} samples/s") 1666 | logger.info(f"Train runtime = {duration} seconds") 1667 | metrics = { 1668 | "train_samples_per_second": throughput, 1669 | "train_runtime": duration, 1670 | } 1671 | with open(f"{args.output_dir}/speed_metrics.json", mode="w") as file: 1672 | json.dump(metrics, file) 1673 | 1674 | if 
epoch % args.validation_epochs == 0: 1675 | if args.use_ema: 1676 | # Store the UNet parameters temporarily and load the EMA parameters to perform inference. 1677 | ema_unet.store(unet.parameters()) 1678 | ema_unet.copy_to(unet.parameters()) 1679 | log_validation( 1680 | vae, 1681 | text_encoder, 1682 | tokenizer, 1683 | unet, 1684 | args, 1685 | accelerator, 1686 | weight_dtype, 1687 | global_step, 1688 | val_dataloader, 1689 | noise_scheduler, 1690 | ) 1691 | if args.use_ema: 1692 | # Switch back to the original UNet parameters. 1693 | ema_unet.restore(unet.parameters()) 1694 | 1695 | # Create the pipeline using the trained modules and save it. 1696 | accelerator.wait_for_everyone() 1697 | if accelerator.is_main_process: 1698 | if args.device == "hpu": 1699 | logger.info(f"Throughput = {throughput} samples/s") 1700 | logger.info(f"Train runtime = {duration} seconds") 1701 | metrics = { 1702 | "train_samples_per_second": throughput, 1703 | "train_runtime": duration, 1704 | } 1705 | with open(f"{args.output_dir}/speed_metrics.json", mode="w") as file: 1706 | json.dump(metrics, file) 1707 | 1708 | unet = accelerator.unwrap_model(unet) 1709 | if args.train_text_encoder: 1710 | text_encoder = unwrap_model(text_encoder) 1711 | if args.use_ema: 1712 | ema_unet.copy_to(unet.parameters()) 1713 | 1714 | if args.device == "hpu": 1715 | pipeline = GaudiStableDiffusionPipeline.from_pretrained( 1716 | args.pretrained_model_name_or_path, 1717 | text_encoder=text_encoder, 1718 | vae=vae, 1719 | unet=unet, 1720 | revision=args.revision, 1721 | scheduler=noise_scheduler, 1722 | ) 1723 | else: 1724 | pipeline = StableDiffusionPipeline.from_pretrained( 1725 | args.pretrained_model_name_or_path, 1726 | text_encoder=text_encoder, 1727 | vae=vae, 1728 | unet=unet, 1729 | revision=args.revision, 1730 | ) 1731 | pipeline.save_pretrained(args.output_dir) 1732 | if args.use_ema: 1733 | ema_unet.save_pretrained(os.path.join(args.output_dir, "unet_ema")) 1734 | 1735 | # Run a final round of inference. 
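    # The block below re-runs the validation prompts through the saved pipeline
    # (20 inference steps here, versus 50 inside `log_validation`) and, when
    # `--push_to_hub` is set, uploads the run directory under `gpu_runs/<run_name>`.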
1736 | images = [] 1737 | if args.validation_prompts is not None: 1738 | logger.info("Running inference for collecting generated images...") 1739 | pipeline = pipeline.to(accelerator.device) 1740 | pipeline.torch_dtype = weight_dtype 1741 | pipeline.set_progress_bar_config(disable=True) 1742 | 1743 | if args.enable_xformers_memory_efficient_attention: 1744 | pipeline.enable_xformers_memory_efficient_attention() 1745 | 1746 | if args.seed is None: 1747 | generator = None 1748 | else: 1749 | generator = torch.Generator(device=accelerator.device).manual_seed(args.seed) 1750 | 1751 | for i in range(len(args.validation_prompts)): 1752 | if args.device == "hpu": 1753 | image = pipeline(args.validation_prompts[i], num_inference_steps=20, generator=generator).images[0] 1754 | else: 1755 | with torch.autocast("cuda"): 1756 | image = pipeline(args.validation_prompts[i], num_inference_steps=20, generator=generator).images[0] 1757 | images.append(image) 1758 | 1759 | # name of last directory in output path 1760 | run_name = os.path.basename(os.path.normpath(args.output_dir)) 1761 | #save_model_card(args, None, images, repo_folder=args.output_dir) 1762 | if args.push_to_hub: 1763 | #save_model_card(args, repo_id, images, repo_folder=args.output_dir) 1764 | upload_folder( 1765 | repo_id=repo_id, 1766 | folder_path=args.output_dir, 1767 | path_in_repo=f"gpu_runs/{run_name}", 1768 | commit_message="End of training", 1769 | ignore_patterns=["step_*", "epoch_*", "checkpoint-*"], 1770 | allow_patterns=["checkpoint-15000"] 1771 | ) 1772 | 1773 | accelerator.end_training() 1774 | 1775 | 1776 | if __name__ == "__main__": 1777 | main() 1778 | --------------------------------------------------------------------------------