├── LICENSE
├── Makefile
├── README.md
├── assets
│   └── spright_good-1.png
├── eval
│   ├── README.md
│   └── gpt4
│       ├── README.md
│       ├── collate_gpt4_results.py
│       ├── create_eval_dataset.py
│       ├── eval_with_gpt4.py
│       ├── push_eval_dataset_to_hub.py
│       └── requirements.txt
├── pyproject.toml
└── training
    ├── README.md
    ├── requirements.txt
    ├── spright_t2i_multinode_example.sh
    └── train.py
/LICENSE:
--------------------------------------------------------------------------------
1 | Apache License
2 | Version 2.0, January 2004
3 | http://www.apache.org/licenses/
4 |
5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
6 |
7 | 1. Definitions.
8 |
9 | "License" shall mean the terms and conditions for use, reproduction,
10 | and distribution as defined by Sections 1 through 9 of this document.
11 |
12 | "Licensor" shall mean the copyright owner or entity authorized by
13 | the copyright owner that is granting the License.
14 |
15 | "Legal Entity" shall mean the union of the acting entity and all
16 | other entities that control, are controlled by, or are under common
17 | control with that entity. For the purposes of this definition,
18 | "control" means (i) the power, direct or indirect, to cause the
19 | direction or management of such entity, whether by contract or
20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the
21 | outstanding shares, or (iii) beneficial ownership of such entity.
22 |
23 | "You" (or "Your") shall mean an individual or Legal Entity
24 | exercising permissions granted by this License.
25 |
26 | "Source" form shall mean the preferred form for making modifications,
27 | including but not limited to software source code, documentation
28 | source, and configuration files.
29 |
30 | "Object" form shall mean any form resulting from mechanical
31 | transformation or translation of a Source form, including but
32 | not limited to compiled object code, generated documentation,
33 | and conversions to other media types.
34 |
35 | "Work" shall mean the work of authorship, whether in Source or
36 | Object form, made available under the License, as indicated by a
37 | copyright notice that is included in or attached to the work
38 | (an example is provided in the Appendix below).
39 |
40 | "Derivative Works" shall mean any work, whether in Source or Object
41 | form, that is based on (or derived from) the Work and for which the
42 | editorial revisions, annotations, elaborations, or other modifications
43 | represent, as a whole, an original work of authorship. For the purposes
44 | of this License, Derivative Works shall not include works that remain
45 | separable from, or merely link (or bind by name) to the interfaces of,
46 | the Work and Derivative Works thereof.
47 |
48 | "Contribution" shall mean any work of authorship, including
49 | the original version of the Work and any modifications or additions
50 | to that Work or Derivative Works thereof, that is intentionally
51 | submitted to Licensor for inclusion in the Work by the copyright owner
52 | or by an individual or Legal Entity authorized to submit on behalf of
53 | the copyright owner. For the purposes of this definition, "submitted"
54 | means any form of electronic, verbal, or written communication sent
55 | to the Licensor or its representatives, including but not limited to
56 | communication on electronic mailing lists, source code control systems,
57 | and issue tracking systems that are managed by, or on behalf of, the
58 | Licensor for the purpose of discussing and improving the Work, but
59 | excluding communication that is conspicuously marked or otherwise
60 | designated in writing by the copyright owner as "Not a Contribution."
61 |
62 | "Contributor" shall mean Licensor and any individual or Legal Entity
63 | on behalf of whom a Contribution has been received by Licensor and
64 | subsequently incorporated within the Work.
65 |
66 | 2. Grant of Copyright License. Subject to the terms and conditions of
67 | this License, each Contributor hereby grants to You a perpetual,
68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
69 | copyright license to reproduce, prepare Derivative Works of,
70 | publicly display, publicly perform, sublicense, and distribute the
71 | Work and such Derivative Works in Source or Object form.
72 |
73 | 3. Grant of Patent License. Subject to the terms and conditions of
74 | this License, each Contributor hereby grants to You a perpetual,
75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
76 | (except as stated in this section) patent license to make, have made,
77 | use, offer to sell, sell, import, and otherwise transfer the Work,
78 | where such license applies only to those patent claims licensable
79 | by such Contributor that are necessarily infringed by their
80 | Contribution(s) alone or by combination of their Contribution(s)
81 | with the Work to which such Contribution(s) was submitted. If You
82 | institute patent litigation against any entity (including a
83 | cross-claim or counterclaim in a lawsuit) alleging that the Work
84 | or a Contribution incorporated within the Work constitutes direct
85 | or contributory patent infringement, then any patent licenses
86 | granted to You under this License for that Work shall terminate
87 | as of the date such litigation is filed.
88 |
89 | 4. Redistribution. You may reproduce and distribute copies of the
90 | Work or Derivative Works thereof in any medium, with or without
91 | modifications, and in Source or Object form, provided that You
92 | meet the following conditions:
93 |
94 | (a) You must give any other recipients of the Work or
95 | Derivative Works a copy of this License; and
96 |
97 | (b) You must cause any modified files to carry prominent notices
98 | stating that You changed the files; and
99 |
100 | (c) You must retain, in the Source form of any Derivative Works
101 | that You distribute, all copyright, patent, trademark, and
102 | attribution notices from the Source form of the Work,
103 | excluding those notices that do not pertain to any part of
104 | the Derivative Works; and
105 |
106 | (d) If the Work includes a "NOTICE" text file as part of its
107 | distribution, then any Derivative Works that You distribute must
108 | include a readable copy of the attribution notices contained
109 | within such NOTICE file, excluding those notices that do not
110 | pertain to any part of the Derivative Works, in at least one
111 | of the following places: within a NOTICE text file distributed
112 | as part of the Derivative Works; within the Source form or
113 | documentation, if provided along with the Derivative Works; or,
114 | within a display generated by the Derivative Works, if and
115 | wherever such third-party notices normally appear. The contents
116 | of the NOTICE file are for informational purposes only and
117 | do not modify the License. You may add Your own attribution
118 | notices within Derivative Works that You distribute, alongside
119 | or as an addendum to the NOTICE text from the Work, provided
120 | that such additional attribution notices cannot be construed
121 | as modifying the License.
122 |
123 | You may add Your own copyright statement to Your modifications and
124 | may provide additional or different license terms and conditions
125 | for use, reproduction, or distribution of Your modifications, or
126 | for any such Derivative Works as a whole, provided Your use,
127 | reproduction, and distribution of the Work otherwise complies with
128 | the conditions stated in this License.
129 |
130 | 5. Submission of Contributions. Unless You explicitly state otherwise,
131 | any Contribution intentionally submitted for inclusion in the Work
132 | by You to the Licensor shall be under the terms and conditions of
133 | this License, without any additional terms or conditions.
134 | Notwithstanding the above, nothing herein shall supersede or modify
135 | the terms of any separate license agreement you may have executed
136 | with Licensor regarding such Contributions.
137 |
138 | 6. Trademarks. This License does not grant permission to use the trade
139 | names, trademarks, service marks, or product names of the Licensor,
140 | except as required for reasonable and customary use in describing the
141 | origin of the Work and reproducing the content of the NOTICE file.
142 |
143 | 7. Disclaimer of Warranty. Unless required by applicable law or
144 | agreed to in writing, Licensor provides the Work (and each
145 | Contributor provides its Contributions) on an "AS IS" BASIS,
146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
147 | implied, including, without limitation, any warranties or conditions
148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
149 | PARTICULAR PURPOSE. You are solely responsible for determining the
150 | appropriateness of using or redistributing the Work and assume any
151 | risks associated with Your exercise of permissions under this License.
152 |
153 | 8. Limitation of Liability. In no event and under no legal theory,
154 | whether in tort (including negligence), contract, or otherwise,
155 | unless required by applicable law (such as deliberate and grossly
156 | negligent acts) or agreed to in writing, shall any Contributor be
157 | liable to You for damages, including any direct, indirect, special,
158 | incidental, or consequential damages of any character arising as a
159 | result of this License or out of the use or inability to use the
160 | Work (including but not limited to damages for loss of goodwill,
161 | work stoppage, computer failure or malfunction, or any and all
162 | other commercial damages or losses), even if such Contributor
163 | has been advised of the possibility of such damages.
164 |
165 | 9. Accepting Warranty or Additional Liability. While redistributing
166 | the Work or Derivative Works thereof, You may choose to offer,
167 | and charge a fee for, acceptance of support, warranty, indemnity,
168 | or other liability obligations and/or rights consistent with this
169 | License. However, in accepting such obligations, You may act only
170 | on Your own behalf and on Your sole responsibility, not on behalf
171 | of any other Contributor, and only if You agree to indemnify,
172 | defend, and hold each Contributor harmless for any liability
173 | incurred by, or claims asserted against, such Contributor by reason
174 | of your accepting any such warranty or additional liability.
175 |
176 | END OF TERMS AND CONDITIONS
177 |
178 | APPENDIX: How to apply the Apache License to your work.
179 |
180 | To apply the Apache License to your work, attach the following
181 | boilerplate notice, with the fields enclosed by brackets "[]"
182 | replaced with your own identifying information. (Don't include
183 | the brackets!) The text should be enclosed in the appropriate
184 | comment syntax for the file format. We also recommend that a
185 | file or class name and description of purpose be included on the
186 | same "printed page" as the copyright notice for easier
187 | identification within third-party archives.
188 |
189 | Copyright [yyyy] [name of copyright owner]
190 |
191 | Licensed under the Apache License, Version 2.0 (the "License");
192 | you may not use this file except in compliance with the License.
193 | You may obtain a copy of the License at
194 |
195 | http://www.apache.org/licenses/LICENSE-2.0
196 |
197 | Unless required by applicable law or agreed to in writing, software
198 | distributed under the License is distributed on an "AS IS" BASIS,
199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
200 | See the License for the specific language governing permissions and
201 | limitations under the License.
202 |
--------------------------------------------------------------------------------
/Makefile:
--------------------------------------------------------------------------------
1 |
2 | check_dirs := .
3 |
4 | quality:
5 | ruff check $(check_dirs)
6 | ruff format --check $(check_dirs)
7 |
8 | style:
9 | ruff check $(check_dirs) --fix
10 | ruff format $(check_dirs)
11 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # SPRIGHT 🖼️✨
2 |
3 | Welcome to the official GitHub repository for our paper titled "Getting it Right: Improving Spatial Consistency in Text-to-Image Models". Our work introduces a simple approach to enhance spatial consistency in text-to-image diffusion models, alongside a high-quality dataset designed for this purpose.
4 |
5 | **_Getting it Right: Improving Spatial Consistency in Text-to-Image Models_** by Agneet Chatterjee$, Gabriela Ben Melech Stan$, Estelle Aflalo, Sayak Paul, Dhruba Ghosh, Tejas Gokhale, Ludwig Schmidt, Hannaneh Hajishirzi, Vasudev Lal, Chitta Baral, Yezhou Yang.
6 |
7 | $ denotes equal contribution.
8 |
9 |
10 | 🤗 Models & Datasets | 📃 Paper |
11 | ⚙️ Demo |
12 | 🎮 Project Website
13 |
14 |
15 | Update July 05, 2024: We got accepted to ECCV'24 🥳
16 |
17 | ## 📄 Abstract
18 | _One of the key shortcomings in current text-to-image (T2I) models is their inability to consistently generate images which faithfully follow the spatial relationships specified in the text prompt. In this paper, we offer a comprehensive investigation of this limitation, while also developing datasets and methods that achieve state-of-the-art performance. First, we find that current vision-language datasets do not represent spatial relationships well enough; to alleviate this bottleneck, we create SPRIGHT, the first spatially-focused, large scale dataset, by re-captioning 6 million images from 4 widely used vision datasets. Through a 3-fold evaluation and analysis pipeline, we find that SPRIGHT largely improves upon existing datasets in capturing spatial relationships. To demonstrate its efficacy, we leverage only \~0.25% of SPRIGHT and achieve a 22% improvement in generating spatially accurate images while also improving the FID and CMMD scores. Secondly, we find that training on images containing a large number of objects results in substantial improvements in spatial consistency. Notably, we attain state-of-the-art on T2I-CompBench with a spatial score of 0.2133, by fine-tuning on <500 images. Finally, through a set of controlled experiments and ablations, we document multiple findings that we believe will enhance the understanding of factors that affect spatial consistency in text-to-image models. We publicly release our dataset and
19 | model to foster further research in this area._
20 |
21 | ## 📚 Contents
22 | - [Installation](#installation)
23 | - [Training](#training)
24 | - [Inference](#inference)
25 | - [The SPRIGHT Dataset](#the-spright-dataset)
26 | - [Eval](#evaluation)
27 | - [Citing](#citing)
28 | - [Acknowledgments](#ack)
29 |
30 |
31 | ## 💾 Installation
32 |
33 | Make sure you have CUDA and PyTorch set up; the PyTorch [official documentation](https://pytorch.org/) is the best place to refer to for that. The rest of the installation instructions are provided in the respective sections; component-specific Python dependencies are listed in [`training/requirements.txt`](./training/requirements.txt) and [`eval/gpt4/requirements.txt`](./eval/gpt4/requirements.txt).
34 |
35 | If you have access to the Habana Gaudi accelerators, you can benefit from them as our training script supports them.
36 |
37 |
38 | ## 🔍 Training
39 |
40 | Refer to [`training/`](./training).
41 |
42 |
43 | ## 🌺 Inference
44 |
45 | ```python
46 | from diffusers import DiffusionPipeline
47 | import torch
48 |
49 | spright_id = "SPRIGHT-T2I/spright-t2i-sd2"
50 | pipe = DiffusionPipeline.from_pretrained(spright_id, torch_dtype=torch.float16).to("cuda")
51 |
52 | image = pipe("A horse above a pizza").images[0]
53 | image
54 | ```
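
For reproducible results, you can continue from the snippet above with a seeded generator and save the output to disk; a minimal sketch (the seed value and filename below are arbitrary):

```python
# Reuses `pipe` and `torch` from the snippet above.
generator = torch.Generator(device="cuda").manual_seed(2024)
image = pipe("A horse above a pizza", generator=generator).images[0]
image.save("horse_above_pizza.png")  # write the generated image to disk
```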
55 |
56 | You can also run [the demo](https://huggingface.co/spaces/SPRIGHT-T2I/SPRIGHT-T2I) locally:
57 |
58 | ```bash
59 | git clone https://huggingface.co/spaces/SPRIGHT-T2I/SPRIGHT-T2I
60 | cd SPRIGHT-T2I
61 | python app.py
62 | ```
63 |
64 | Make sure `gradio` and other dependencies are installed in your environment.
65 |
66 |
67 | ## 🖼️ The SPRIGHT Dataset
68 |
69 | Refer to our [paper](https://arxiv.org/abs/2404.01197) and [the dataset page](https://huggingface.co/datasets/SPRIGHT-T2I/spright) for more details. Below are some examples from the SPRIGHT dataset:
70 |
71 | ![Examples from the SPRIGHT dataset](assets/spright_good-1.png)
72 |
73 |
74 |
75 |
76 | ## 📊 Evaluation
77 |
78 | In the [`eval/`](./eval) directory, we provide details about the various evaluation methods we use in our work.
79 |
80 |
81 | ## 📜 Citing
82 |
83 | ```bibtex
84 | @misc{chatterjee2024getting,
85 | title={Getting it Right: Improving Spatial Consistency in Text-to-Image Models},
86 | author={Agneet Chatterjee and Gabriela Ben Melech Stan and Estelle Aflalo and Sayak Paul and Dhruba Ghosh and Tejas Gokhale and Ludwig Schmidt and Hannaneh Hajishirzi and Vasudev Lal and Chitta Baral and Yezhou Yang},
87 | year={2024},
88 | eprint={2404.01197},
89 | archivePrefix={arXiv},
90 | primaryClass={cs.CV}
91 | }
92 | ```
93 |
94 |
95 | ## 🙏 Acknowledgments
96 |
97 | We thank Lucain Pouget for helping us in uploading the dataset to the Hugging Face Hub and the Hugging Face team for providing computing resources to host our demo. The authors acknowledge resources and support from the Research Computing facilities at Arizona State University. This work was supported by NSF RI grants \#1750082 and \#2132724. The views and opinions of the authors expressed herein do not necessarily state or reflect those of the funding agencies and employers.
98 |
--------------------------------------------------------------------------------
/assets/spright_good-1.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/SPRIGHT-T2I/SPRIGHT/1ce84339c7aceaf502a3d0959d612de8b8ac0171/assets/spright_good-1.png
--------------------------------------------------------------------------------
/eval/README.md:
--------------------------------------------------------------------------------
1 | ## Quantitative evals
2 |
3 | Here we list out the evaluations we carried out in our work and their corresponding code implementations.
4 |
5 | | **Eval** | **Implementation** |
6 | |:------------:|:------------:|
7 | | GPT-4 (V) | [`gpt4/`](./gpt4/) |
8 | | FAITHScore | [Official implementation](https://github.com/bcdnlp/FAITHSCORE) |
9 | | VISOR | [Official implementation](https://github.com/microsoft/VISOR) |
10 | | T2I-CompBench | [Official implementation](https://github.com/Karine-Huang/T2I-CompBench) |
11 | | GenEval | [Official implementation](https://github.com/djghosh13/geneval) |
12 | | CMMD | [Official implementation](https://github.com/google-research/google-research/tree/master/cmmd) |
13 | | FID | [Implementation](https://github.com/mseitzer/pytorch-fid) |
14 | | CKA | [Implementation](https://github.com/AntixK/PyTorch-Model-Compare) |
15 | | Attention Maps | [Official implementation](https://github.com/yuval-alaluf/Attend-and-Excite) |
16 |
--------------------------------------------------------------------------------
/eval/gpt4/README.md:
--------------------------------------------------------------------------------
1 | # Automated quality evaluation with GPT-4
2 |
3 | This directory provides scripts for performing an automated quality evaluation with GPT-4. You need an OpenAI API key (with at least $5 of credit) to run the `eval_with_gpt4.py` script; the script reads the key from the `OPENAI_API_KEY` environment variable.
4 |
5 | ## Dataset
6 |
7 | The dataset consists of 100 images which can be found [here](https://hf.co/datasets/SPRIGHT-T2I/100-images-for-eval). It was created from the [SAM](https://segment-anything.com/) and [CC12M](https://github.com/google-research-datasets/conceptual-12m) datasets.
8 |
9 | To create the dataset, run the following (make sure the dependencies from [`requirements.txt`](./requirements.txt) are installed first):
10 |
11 | ```bash
12 | python create_eval_dataset.py
13 | python push_eval_dataset_to_hub.py
14 | ```
15 |
16 | _(Run `huggingface-cli login` before running `python push_eval_dataset_to_hub.py`. You might have to change the `ds_id` in the script as well.)_
17 |
18 | ## Evaluation
19 |
20 | ```bash
21 | python eval_with_gpt4.py
22 | ```
23 |
24 | The script comes with limited support for handling rate-limiting issues.
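
Each output file holds a list of `{"rating": ..., "explanation": ...}` records, one per image in that chunk. A minimal sketch for spot-checking the first chunk (with the script's `chunk_size` of 8, the first file is `gpt4_evals_0_to_7.json`):

```python
import json

# Print the rating and explanation GPT-4 produced for each image in the first chunk.
with open("gpt4_evals_0_to_7.json") as f:
    ratings = json.load(f)

for entry in ratings:
    print(entry["rating"], "-", entry["explanation"])
```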
25 |
26 | ## Collating GPT-4 results
27 |
28 | Once `python eval_with_gpt4.py` has been run, it should produce JSON files prefixed with `gpt4_evals`. You can then run the following to collate the results and push the final dataset to the HF Hub for auditing:
29 |
30 | ```bash
31 | python collate_gpt4_results.py
32 | ```
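
To quickly audit the collated results, you can load the pushed dataset and summarize the ratings; a minimal sketch, assuming the `ds_id` used in `collate_gpt4_results.py` (change it if you pushed to a different repo):

```python
from datasets import load_dataset

ds = load_dataset("ASU-HF/gpt4-evaluation", split="train")
ratings = ds["gpt4_rating"]
print(f"{len(ratings)} samples, mean GPT-4 rating: {sum(ratings) / len(ratings):.2f}")
```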
33 |
--------------------------------------------------------------------------------
/eval/gpt4/collate_gpt4_results.py:
--------------------------------------------------------------------------------
1 | import glob
2 | import json
3 |
4 | from datasets import Dataset, Features, Value, load_dataset
5 | from datasets import Image as ImageFeature
6 |
7 |
8 | def sort_file_paths(file_paths):
9 |     # Sort by the starting index embedded in the filename (gpt4_evals_{start}_to_{end}.json)
10 | sorted_paths = sorted(file_paths, key=lambda x: int(x.split("_")[2]))
11 | return sorted_paths
12 |
13 |
14 | def get_ratings_from_json(json_path):
15 | all_ratings = []
16 | with open(json_path, "r") as f:
17 | json_dict = json.load(f)
18 | for i in range(len(json_dict)):
19 | all_ratings.append(json_dict[i])
20 | return all_ratings
21 |
22 |
23 | all_jsons = sorted(glob.glob("*.json"))
24 | sorted_all_jsons = sort_file_paths(all_jsons)
25 |
26 | all_ratings = []
27 | for json_path in sorted_all_jsons:
28 | try:
29 | all_ratings.extend(get_ratings_from_json(json_path))
30 |     except Exception:
31 |         print(f"Could not parse ratings from {json_path}")
32 |
33 | eval_dataset = load_dataset("ASU-HF/100-images-for-eval", split="train")
34 |
35 |
36 | def generation_fn():
37 | for i in range(len(eval_dataset)):
38 | yield {
39 | "image": eval_dataset[i]["image"],
40 | "spatial_caption": eval_dataset[i]["spatial_caption"],
41 | "gpt4_rating": all_ratings[i]["rating"],
42 | "gpt4_explanation": all_ratings[i]["explanation"],
43 | }
44 |
45 |
46 | ds = Dataset.from_generator(
47 | generation_fn,
48 | features=Features(
49 | image=ImageFeature(),
50 | spatial_caption=Value("string"),
51 | gpt4_rating=Value("int32"),
52 | gpt4_explanation=Value("string"),
53 | ),
54 | )
55 | ds_id = "ASU-HF/gpt4-evaluation"
56 | ds.push_to_hub(ds_id)
57 |
--------------------------------------------------------------------------------
/eval/gpt4/create_eval_dataset.py:
--------------------------------------------------------------------------------
1 | import json
2 | import os
3 | import random
4 | import shutil
5 |
6 | from tqdm.auto import tqdm
7 |
8 |
9 | random.seed(2024)
10 |
11 |
12 | JSON_PATHS = ["cc12m/spatial_prompts_cc_res768.jsonl", "sa/spatial_prompts_sa_res768.jsonl"]
13 | CUT_OFF_FOR_EACH = 50
14 | SUBSET_DIR = "eval"
15 | ROOT_PATH = "/home/jupyter/test-images-spatial/human_eval_subset"
16 |
17 |
18 | def copy_images(tuple_entries, subset):
19 | final_dict = {}
20 | for entry in tqdm(tuple_entries):
21 | image_name = entry[0].split("/")[-1]
22 | image_to_copy_from = os.path.join(ROOT_PATH, subset, "images", image_name)
23 | image_to_copy_to = os.path.join(ROOT_PATH, SUBSET_DIR)
24 | shutil.copy(image_to_copy_from, image_to_copy_to)
25 | final_dict[image_name] = entry[1]
26 | return final_dict
27 |
28 |
29 | # Load the JSON files.
30 | cc12m_entries = []
31 | with open(JSON_PATHS[0], "rb") as json_list:
32 | for json_str in json_list:
33 | cc12m_entries.append(json.loads(json_str))
34 |
35 | sa_entries = []
36 | with open(JSON_PATHS[1], "rb") as json_list:
37 | for json_str in json_list:
38 | sa_entries.append(json.loads(json_str))
39 |
40 | # Prepare tuples and shuffle them for random sampling.
41 | print(len(cc12m_entries), len(sa_entries))
42 | cc12m_tuples = [(line["file_name"], line["spatial_caption"]) for line in cc12m_entries]
43 | sa_tuples = [(line["file_name"], line["spatial_caption"]) for line in sa_entries]
44 | filtered_cc12m_tuples = [
45 | (line[0], line[1])
46 | for line in cc12m_tuples
47 | if os.path.exists(os.path.join(ROOT_PATH, "cc12m", "images", line[0].split("/")[-1]))
48 | ]
49 |
50 | # Keep paths that exist.
51 | filtered_sa_tuples = [
52 | (line[0], line[1])
53 | for line in sa_tuples
54 | if os.path.exists(os.path.join(ROOT_PATH, "sa", "images", line[0].split("/")[-1]))
55 | ]
56 | print(len(filtered_cc12m_tuples), len(filtered_sa_tuples))
57 | random.shuffle(filtered_cc12m_tuples)
58 | random.shuffle(filtered_sa_tuples)
59 |
60 | # Cut off for subsets.
61 | subset_cc12m_tuples = filtered_cc12m_tuples[:CUT_OFF_FOR_EACH]
62 | subset_sa_tuples = filtered_sa_tuples[:CUT_OFF_FOR_EACH]
63 |
64 | # Copy over the images.
65 | if not os.path.exists(SUBSET_DIR):
66 | os.makedirs(SUBSET_DIR, exist_ok=True)
67 |
68 | final_data_dict = {}
69 | cc12m_dict = copy_images(subset_cc12m_tuples, "cc12m")
70 | sa_dict = copy_images(subset_sa_tuples, "sa")
71 | print(len(cc12m_dict), len(sa_dict))
72 | final_data_dict = {**cc12m_dict, **sa_dict}
73 |
74 | # Create a json file to record metadata.
75 | with open("final_data_dict.json", "w") as f:
76 | json.dump(final_data_dict, f)
77 |
--------------------------------------------------------------------------------
/eval/gpt4/eval_with_gpt4.py:
--------------------------------------------------------------------------------
1 | import base64
2 | import json
3 | import os
4 | import time
5 | from concurrent.futures import ThreadPoolExecutor
6 | from io import BytesIO
7 |
8 | import requests
9 | from datasets import load_dataset
10 |
11 |
12 | api_key = os.getenv("OPENAI_API_KEY")
13 | headers = {"Content-Type": "application/json", "Authorization": f"Bearer {api_key}"}
14 |
15 |
16 | def encode_image(image):
17 | buffered = BytesIO()
18 | image.save(buffered, format="JPEG")
19 | img_str = base64.b64encode(buffered.getvalue())
20 | return img_str.decode("utf-8")
21 |
22 |
23 | def create_payload(image_string, caption):
24 | # System message adapted from DALL-E 3 (as found in the tech report):
25 | # https://cdn.openai.com/papers/dall-e-3.pdf
26 | messages = [
27 | {
28 | "role": "system",
29 | "content": """You are part of a team of bots that evaluates images and their captions. Your job is to come up with a rating in between 1 to 10 to evaluate the provided caption for the provided image. While performing the assessment, consider the correctness of spatial relationships captured in the provided image. You should return the response formatted as a dictionary having two keys: 'rating', denoting the numeric rating and 'explanation', denoting a brief justification for the rating.
30 |
31 | The captions you are judging are designed to stress-test image captioning programs, and may include things such as:
32 | 1. Spatial phrases like above, below, left, right, front, behind, background, foreground (focus most on the correctness of these words)
33 | 2. Relative sizes between objects such as small & large, big & tiny (focus on the correctness of these words)
34 | 3. Scrambled or mis-spelled words (the image generator should generate an image associated with the probable meaning).
35 |
36 | You need to make a decision as to whether or not the caption is correct, given the image.
37 |
38 | A few rules :
39 | 1. It is ok if the caption does not explicitly mention each object in the image; as long as the caption is correct in its entirety, it is fine.
40 | 2. It is also ok if some captions don't have spatial relationships; judge them based on their correctness. A caption not containing spatial relationships should not be penalized.
41 | 3. You will think out loud about your eventual conclusion. Don't include your reasoning in the final output.
42 | 4. You should return the response formatted as a Python-formatted dictionary having
43 | two keys: 'rating', denoting the numeric rating and 'explanation', denoting
44 | a brief justification for the rating.
45 | """,
46 | }
47 | ]
48 | messages.append(
49 | {
50 | "role": "user",
51 | "content": [
52 | {
53 | "type": "text",
54 |                     "text": "Come up with a rating between 1 and 10 to evaluate the provided caption for the provided image."
55 | " While performing the assessment, consider the correctness of spatial relationships captured in the provided image."
56 | f" Caption provided: {caption}",
57 | },
58 | {
59 | "type": "image_url",
60 | "image_url": {"url": f"data:image/jpeg;base64,{image_string}"},
61 | },
62 | ],
63 | }
64 | )
65 |
66 | payload = {
67 | "model": "gpt-4-vision-preview",
68 | "messages": messages,
69 | "max_tokens": 250,
70 | "seed": 2024,
71 | }
72 | return payload
73 |
74 |
75 | def get_response(image_string, caption):
76 | payload = create_payload(image_string, caption)
77 | response = requests.post("https://api.openai.com/v1/chat/completions", headers=headers, json=payload)
78 | return response.json()
79 |
80 |
81 | def get_rating(response):
82 | content = response["choices"][0]["message"]["content"]
83 |     # Best-effort clean-up of the model output; responses that still aren't valid dicts will fail downstream.
84 | cleaned_content = (
85 | content.strip().replace("```", "").replace("json", "").replace("python", "").strip().replace("\n", "")
86 | )
87 | return cleaned_content
88 |
89 |
90 | dataset = load_dataset("ASU-HF/100-images-for-eval", split="train")
91 | image_strings = []
92 | captions = []
93 | for i in range(len(dataset)):
94 | image_strings.append(encode_image(dataset[i]["image"]))
95 | captions.append(dataset[i]["spatial_caption"])
96 |
97 |
98 | chunk_size = 8
99 | json_retry = 4
100 | per_min_token_limit = 10000
101 | per_day_request_limit = 500
102 | total_requests_made = 0
103 | batch_total_tokens = 0
104 |
105 | with ThreadPoolExecutor(chunk_size) as e:
106 | for i in range(0, len(image_strings), chunk_size):
107 | responses = None
108 | cur_retry = 0
109 |
110 | # request handling with retries
111 | while responses is None and cur_retry <= json_retry:
112 | try:
113 | responses = list(e.map(get_response, image_strings[i : i + chunk_size], captions[i : i + chunk_size]))
114 |         except Exception:  # avoid shadowing the ThreadPoolExecutor bound to `e`
115 | cur_retry = cur_retry + 1
116 | continue
117 |
118 | # handle rate-limits
119 | total_requests_made += len(image_strings[i : i + chunk_size])
120 | for response in responses:
121 | batch_total_tokens += response["usage"]["total_tokens"]
122 |
123 | with open(f"gpt4_evals_{i}_to_{(i + chunk_size) - 1}.json", "w") as f:
124 | ratings = [eval(get_rating(response)) for response in responses]
125 | json.dump(ratings, f, indent=4)
126 |
127 | if total_requests_made > per_day_request_limit:
128 | total_requests_made = 0
129 | time.sleep(86400) # wait a day!
130 | elif batch_total_tokens > per_min_token_limit:
131 | batch_total_tokens = 0
132 |         time.sleep(1800)  # wait half an hour to stay under the per-minute token limit
133 |
--------------------------------------------------------------------------------
/eval/gpt4/push_eval_dataset_to_hub.py:
--------------------------------------------------------------------------------
1 | import json
2 | import os
3 |
4 | from datasets import Dataset, Features, Value
5 | from datasets import Image as ImageFeature
6 |
7 |
8 | final_dict_path = "final_data_dict.json"
9 | with open(final_dict_path, "r") as f:
10 | final_dict = json.load(f)
11 |
12 |
13 | root_path = "/home/jupyter/test-images-spatial/human_eval_subset/eval"
14 |
15 |
16 | def generation_fn():
17 | for k in final_dict:
18 | yield {
19 | "image": os.path.join(root_path, k),
20 | "spatial_caption": final_dict[k],
21 | "subset": "SA" if "sa" in k else "CC12M",
22 | }
23 |
24 |
25 | ds = Dataset.from_generator(
26 | generation_fn,
27 | features=Features(
28 | image=ImageFeature(),
29 | spatial_caption=Value("string"),
30 | subset=Value("string"),
31 | ),
32 | )
33 | ds_id = "ASU-HF/100-images-for-eval"
34 | ds.push_to_hub(ds_id)
35 |
--------------------------------------------------------------------------------
/eval/gpt4/requirements.txt:
--------------------------------------------------------------------------------
1 | datasets
2 | openai
3 | Pillow
4 | tqdm
--------------------------------------------------------------------------------
/pyproject.toml:
--------------------------------------------------------------------------------
1 |
2 | [tool.ruff]
3 | # Never enforce `E501` (line length violations).
4 | ignore = ["C901", "E501", "E741", "F402", "F823"]
5 | select = ["C", "E", "F", "I", "W"]
6 | line-length = 119
7 |
8 | # Ignore import violations in all `__init__.py` files.
9 | [tool.ruff.per-file-ignores]
10 | "__init__.py" = ["E402", "F401", "F403", "F811"]
11 | "src/diffusers/utils/dummy_*.py" = ["F401"]
12 |
13 | [tool.ruff.isort]
14 | lines-after-imports = 2
15 | known-first-party = ["diffusers"]
16 |
17 | [tool.ruff.format]
18 | # Like Black, use double quotes for strings.
19 | quote-style = "double"
20 |
21 | # Like Black, indent with spaces, rather than tabs.
22 | indent-style = "space"
23 |
24 | # Like Black, respect magic trailing commas.
25 | skip-magic-trailing-comma = false
26 |
27 | # Like Black, automatically detect the appropriate line ending.
28 | line-ending = "auto"
29 |
--------------------------------------------------------------------------------
/training/README.md:
--------------------------------------------------------------------------------
1 | ## Training with the SPRIGHT dataset
2 |
3 | If you're on CUDA, make sure it's properly set up and install PyTorch by following the [official documentation](https://pytorch.org/).
4 |
5 | If you have access to Habana Gaudi accelerators and wish to use them for training, first set up the Habana software stack by following the [official installation guide](https://docs.habana.ai/en/latest/Installation_Guide/index.html#gaudi-installation-guide). Then install `optimum-habana`:
6 |
7 | ```bash
8 | pip install git+https://github.com/huggingface/optimum-habana.git
9 | ```
10 |
11 | Other training-related Python dependencies are found in [`requirements.txt`](./requirements.txt).
12 |
13 | ### Data preparation
14 |
15 | To work with our dataset:
16 |
17 | - Download the dataset from [here](https://huggingface.co/datasets/SPRIGHT-T2I/spright) and place it under `/path/to/spright`.
18 | - The structure of the downloaded repository is as follows:
19 |
20 | ```plaintext
21 | /path/to/spright/
22 | ├── data/
23 | │ └── *.tar
24 | ├── metadata.json
25 | ├── load_data.py
26 | └── robust_upload.py
27 | ```
28 | - Each .tar file contains around 10k images with their associated general and spatial captions; see the loading sketch after this list.
29 | - `metadata.json` specifies the split each .tar file belongs to, as well as the number of samples per .tar file.
30 |
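Below is a minimal sketch of iterating over a single shard with `webdataset` (listed in [`requirements.txt`](./requirements.txt), and also how `train.py` reads the data). The per-sample keys (`jpg`, `json`, `spatial_caption`) are assumptions here; check `load_data.py` and `metadata.json` in the downloaded dataset for the exact layout.

```python
import webdataset as wds

# Assumed layout: each sample stores the image under "jpg" and a JSON record
# containing a "spatial_caption" field under "json"; verify against load_data.py.
shard = "/path/to/spright/data/00000.tar"
dataset = wds.WebDataset(shard).decode("pil").to_tuple("jpg", "json")

for image, meta in dataset:
    print(image.size, meta.get("spatial_caption", ""))
    break  # inspect just the first sample
```
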
31 | ### Example training command
32 | #### Multiple GPUs
33 |
34 | 1. To finetune our model using the train and validation splits set in `metadata.json` (see [Data preparation](https://github.com/SPRIGHT-T2I/SPRIGHT#data-preparation)):
35 | ```bash
36 | export MODEL_NAME="SPRIGHT-T2I/spright-t2i-sd2"
37 | export OUTDIR="path/to/outdir"
38 | export SPRIGHT_SPLIT="path/to/spright/metadata.json" # download from: https://huggingface.co/datasets/SPRIGHT-T2I/spright/blob/main/metadata.json
39 |
40 | accelerate launch --mixed_precision="fp16" train.py \
41 | --pretrained_model_name_or_path=$MODEL_NAME \
42 | --use_ema \
43 | --resolution=768 --center_crop --random_flip \
44 | --train_batch_size=4 \
45 | --gradient_accumulation_steps=1 \
46 | --max_train_steps=15000 \
47 | --learning_rate=5e-05 \
48 | --max_grad_norm=1 \
49 | --lr_scheduler="constant" \
50 | --lr_warmup_steps=0 \
51 | --output_dir=$OUTDIR \
52 | --validation_epochs 1 \
53 | --checkpointing_steps=1500 \
54 | --freeze_text_encoder_steps 0 \
55 | --train_text_encoder \
56 | --text_encoder_lr=1e-06 \
57 | --spright_splits $SPRIGHT_SPLIT
58 | ```
59 | 2. It is also possible to set the train/val splits manually by specifying particular *.tar files via `--spright_train_costum` for training and `--spright_val_costum` for validation. `metadata.json` should still be passed to the training command, as it provides the sample count for each .tar file:
60 | ```bash
61 | export MODEL_NAME="SPRIGHT-T2I/spright-t2i-sd2"
62 | export OUTDIR="path/to/outdir"
63 | export WEBDATA_TRAIN="path/to/spright/data/{00000..00004}.tar"
64 | export WEBDATA_VAL="path/to/spright/data/{00004..00005}.tar"
65 | export SPRIGHT_SPLIT="path/to/spright/metadata.json" # download from: https://huggingface.co/datasets/SPRIGHT-T2I/spright/blob/main/metadata.json
66 |
67 | accelerate launch --mixed_precision="fp16" train.py \
68 | --pretrained_model_name_or_path=$MODEL_NAME \
69 | --use_ema \
70 | --resolution=768 --center_crop --random_flip \
71 | --train_batch_size=4 \
72 | --gradient_accumulation_steps=1 \
73 | --max_train_steps=15000 \
74 | --learning_rate=5e-05 \
75 | --max_grad_norm=1 \
76 | --lr_scheduler="constant" \
77 | --lr_warmup_steps=0 \
78 | --output_dir=$OUTDIR \
79 | --validation_epochs 1 \
80 | --checkpointing_steps=1500 \
81 | --freeze_text_encoder_steps 0 \
82 | --train_text_encoder \
83 | --text_encoder_lr=1e-06 \
84 | --spright_splits $SPRIGHT_SPLIT \
85 | --spright_train_costum $WEBDATA_TRAIN \
86 | --spright_val_costum $WEBDATA_VAL
87 | ```
88 | To train the text encoder, set `--train_text_encoder`. The point at which text encoder training begins is controlled by `--freeze_text_encoder_steps`; a value of 0 means the U-Net and the text encoder start training together from the outset. Different learning rates can be set for the text encoder and the U-Net via `--text_encoder_lr` and `--learning_rate`, respectively.
89 |
90 | #### Multiple Nodes
91 | To train on multiple nodes using SLURM, refer to [`spright_t2i_multinode_example.sh`](./spright_t2i_multinode_example.sh).
92 |
93 | ### Good to know
94 |
95 | Our training script supports experiment tracking with Weights & Biases. If you wish to use it, pass `--report_to="wandb"` in your training command. Make sure you install `wandb` before that.
96 |
97 | If you're on CUDA, you can push the training artifacts stored under `output_dir` to the Hugging Face Hub. Pass `--push_to_hub` if you wish to do so. You'd need to run `huggingface-cli login` before that.
98 |
--------------------------------------------------------------------------------
/training/requirements.txt:
--------------------------------------------------------------------------------
1 | diffusers
2 | transformers
3 | accelerate
4 | datasets
5 | webdataset
6 | torchvision
--------------------------------------------------------------------------------
/training/spright_t2i_multinode_example.sh:
--------------------------------------------------------------------------------
1 | #! /bin/bash
2 |
3 | #SBATCH -p PP
4 | #SBATCH --gres=gpu:8
5 | #SBATCH --ntasks-per-node=1
6 | #SBATCH --cpus-per-task=80
7 | #SBATCH -N 4
8 | #SBATCH --job-name=spatial_finetuning_stable_diffusion
9 |
10 |
11 | conda activate env_name
12 | cd /path/to/training/script
13 |
14 | export MODEL_NAME="SPRIGHT-T2I/spright-t2i-sd2"
15 | export OUTDIR="/path/to/output/dir"
16 | export SPRIGHT_SPLIT="path/to/spright/metadata.json"
17 |
18 | ACCELERATE_CONFIG_FILE="$OUTDIR/${SLURM_JOB_ID}_accelerate_config.yaml.autogenerated"
19 |
20 |
21 | GPUS_PER_NODE=8
22 | NNODES=$SLURM_NNODES
23 | NUM_GPUS=$((GPUS_PER_NODE*SLURM_NNODES))
24 |
25 | MASTER_ADDR=$(scontrol show hostnames $SLURM_JOB_NODELIST | head -n 1)
26 | MASTER_PORT=25215
27 |
28 | # Auto-generate the accelerate config
29 | cat << EOT > $ACCELERATE_CONFIG_FILE
30 | compute_environment: LOCAL_MACHINE
31 | deepspeed_config: {}
32 | distributed_type: MULTI_GPU
33 | fsdp_config: {}
34 | machine_rank: 0
35 | main_process_ip: $MASTER_ADDR
36 | main_process_port: $MASTER_PORT
37 | main_training_function: main
38 | num_machines: $SLURM_NNODES
39 | num_processes: $NUM_GPUS
40 | use_cpu: false
41 | EOT
42 |
43 | # accelerate settings
44 | # Note: it is important to escape `$SLURM_PROCID` since we want the srun on each node to evaluate this variable
45 | export LAUNCHER="accelerate launch \
46 | --rdzv_conf "rdzv_backend=c10d,rdzv_endpoint=$MASTER_ADDR:$MASTER_PORT,max_restarts=0,tee=3" \
47 | --config_file $ACCELERATE_CONFIG_FILE \
48 | --main_process_ip $MASTER_ADDR \
49 | --main_process_port $MASTER_PORT \
50 | --num_processes $NUM_GPUS \
51 | --machine_rank \$SLURM_PROCID \
52 | "
53 |
54 | # train
55 | PROGRAM="train.py \
56 | --pretrained_model_name_or_path=$MODEL_NAME \
57 | --use_ema \
58 | --seed 42 \
59 | --mixed_precision="fp16" \
60 | --resolution=768 --center_crop --random_flip \
61 | --train_batch_size=4 \
62 | --gradient_accumulation_steps=1 \
63 | --max_train_steps=15000 \
64 | --learning_rate=5e-06 \
65 | --max_grad_norm=1 \
66 | --lr_scheduler="constant" \
67 | --lr_warmup_steps=0 \
68 | --output_dir=$OUTDIR \
69 | --train_metadata_dir=$TRAIN_METADIR \
70 | --dataloader=$DATA_LOADER \
71 | --checkpointing_steps=1500 \
72 | --freeze_text_encoder_steps=0 \
73 | --train_text_encoder \
74 | --text_encoder_lr=1e-06 \
75 | --spright_splits $SPRIGHT_SPLIT
76 | "
77 |
78 |
79 | # srun error handling:
80 | # --wait=60: wait 60 sec after the first task terminates before terminating all remaining tasks
81 | # --kill-on-bad-exit=1: terminate a step if any task exits with a non-zero exit code
82 | SRUN_ARGS=" \
83 | --wait=60 \
84 | --kill-on-bad-exit=1 \
85 | "
86 |
87 | export CMD="$LAUNCHER $PROGRAM"
88 | echo $CMD
89 |
90 | srun $SRUN_ARGS --jobid $SLURM_JOB_ID bash -c "$CMD"
91 |
92 |
--------------------------------------------------------------------------------
/training/train.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | # coding=utf-8
3 | # Copyright 2024 SPRIGHT authors and The HuggingFace Inc. team. All rights reserved.
4 | #
5 | # Licensed under the Apache License, Version 2.0 (the "License");
6 | # you may not use this file except in compliance with the License.
7 | # You may obtain a copy of the License at
8 | #
9 | # http://www.apache.org/licenses/LICENSE-2.0
10 | #
11 | # Unless required by applicable law or agreed to in writing, software
12 | # distributed under the License is distributed on an "AS IS" BASIS,
13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 | # See the License for the specific language governing permissions and
15 |
16 | import argparse
17 |
18 | # added for gaudi
19 | import json
20 | import logging
21 | import math
22 | import os
23 | import random
24 | import shutil
25 | import time
26 | from pathlib import Path
27 |
28 | import accelerate
29 | import datasets
30 | import numpy as np
31 | import torch
32 | import torch.nn.functional as F
33 | import torch.utils.checkpoint
34 | import transformers
35 | from accelerate import Accelerator
36 |
37 |
38 | try:
39 | from optimum.habana import GaudiConfig
40 | from optimum.habana.accelerate import GaudiAccelerator
41 | except ImportError:
42 | GaudiConfig = None
43 | GaudiAccelerator = None
44 |
45 | from accelerate.logging import get_logger
46 | from accelerate.state import AcceleratorState
47 | from accelerate.utils import ProjectConfiguration
48 |
49 |
50 | try:
51 | from optimum.habana.utils import set_seed
52 | except ImportError:
53 | from accelerate.utils import set_seed
54 | import datetime
55 |
56 | from datasets import DownloadMode, load_dataset
57 | from huggingface_hub import create_repo, upload_folder
58 | from packaging import version
59 | from torchvision import transforms
60 | from tqdm.auto import tqdm
61 | from transformers import CLIPTextModel, CLIPTokenizer
62 | from transformers.utils import ContextManagers
63 |
64 | import diffusers
65 | from diffusers import AutoencoderKL, UNet2DConditionModel
66 |
67 |
68 | try:
69 | from optimum.habana.diffusers import GaudiDDIMScheduler, GaudiStableDiffusionPipeline
70 | except ImportError:
71 | from diffusers import DDPMScheduler, StableDiffusionPipeline
72 | from diffusers.optimization import get_scheduler
73 | from diffusers.training_utils import EMAModel, compute_snr
74 | from diffusers.utils import deprecate, is_wandb_available, make_image_grid
75 |
76 |
77 | try:
78 | # memory stats
79 | import habana_frameworks.torch as htorch
80 | import habana_frameworks.torch.core as htcore
81 | import habana_frameworks.torch.hpu as hthpu
82 | except ImportError:
83 | from diffusers.utils.import_utils import is_xformers_available
84 | htorch = None
85 | hthpu = None
86 | htcore = None
87 | import sys
88 |
89 |
90 | sys.path.append(os.path.dirname(os.getcwd()))
91 | import itertools
92 | import warnings
93 |
94 | import webdataset as wds
95 | from transformers import PretrainedConfig
96 |
97 | from diffusers.utils.torch_utils import is_compiled_module
98 |
99 |
100 | if is_wandb_available():
101 | import wandb
102 |
103 | debug = False
104 |
105 | # Will error if the minimal version of diffusers is not installed. Remove at your own risks.
106 | #check_min_version("0.23.0.dev0")
107 |
108 | logger = get_logger(__name__, log_level="INFO")
109 |
110 |
111 | def save_model_card(
112 | args,
113 | repo_id=None,
114 | images=None,
115 | train_text_encoder=False,
116 | repo_folder=None,
117 | ):
118 | img_str = ""
119 | if len(images) > 0:
120 | image_grid = make_image_grid(images, 1, len(args.validation_prompts))
121 | image_grid.save(os.path.join(repo_folder, "val_imgs_grid.png"))
122 | img_str += "\n"
123 |
124 | yaml = f"""
125 | ---
126 | license: creativeml-openrail-m
127 | base_model: {args.pretrained_model_name_or_path}
128 | datasets:
129 | - {args.dataset_name}
130 | tags:
131 | - stable-diffusion
132 | - stable-diffusion-diffusers
133 | - text-to-image
134 | - diffusers
135 | inference: true
136 | ---
137 | """
138 | model_card = f"""
139 | # Text-to-image finetuning - {repo_id}
140 | Fine-tuning for the text encoder was enabled: {train_text_encoder}.
141 |
142 | This pipeline was finetuned from **{args.pretrained_model_name_or_path}** on the **{args.dataset_name}** dataset. Below are some example images generated with the finetuned pipeline using the following prompts: {args.validation_prompts}: \n
143 | {img_str}
144 |
145 | ## Pipeline usage
146 |
147 | You can use the pipeline like so:
148 |
149 | ```python
150 | from diffusers import DiffusionPipeline
151 | import torch
152 |
153 | pipeline = DiffusionPipeline.from_pretrained("{repo_id}", torch_dtype=torch.float16)
154 | prompt = "{args.validation_prompts[0]}"
155 | image = pipeline(prompt).images[0]
156 | image.save("my_image.png")
157 | ```
158 |
159 | ## Training info
160 |
161 | These are the key hyperparameters used during training:
162 |
163 | * Epochs: {args.num_train_epochs}
164 | * Learning rate: {args.learning_rate}
165 | * Batch size: {args.train_batch_size}
166 | * Gradient accumulation steps: {args.gradient_accumulation_steps}
167 | * Image resolution: {args.resolution}
168 | * Mixed-precision: {args.mixed_precision}
169 |
170 | """
171 | wandb_info = ""
172 | if is_wandb_available():
173 | wandb_run_url = None
174 | if wandb.run is not None:
175 | wandb_run_url = wandb.run.url
176 |
177 | if wandb_run_url is not None:
178 | wandb_info = f"""
179 | More information on all the CLI arguments and the environment are available on your [`wandb` run page]({wandb_run_url}).
180 | """
181 |
182 | model_card += wandb_info
183 |
184 | with open(os.path.join(repo_folder, "README.md"), "w") as f:
185 | f.write(yaml + model_card)
186 |
187 | def compute_validation_loss(val_dataloader, vae, text_encoder, noise_scheduler, unet, args, weight_dtype):
188 | val_loss = 0
189 |     num_steps = math.ceil(len(val_dataloader))
190 | progress_bar = tqdm(
191 | range(0, num_steps),
192 | initial=0,
193 | desc="Steps",
194 | # Only show the progress bar once on each machine.
195 | disable=True,
196 | )
197 | for step, batch in enumerate(val_dataloader):
198 | progress_bar.update(1)
199 | # Convert images to latent space
200 | latents = vae.encode(batch["pixel_values"].to(weight_dtype)).latent_dist.sample()
201 | latents = latents * vae.config.scaling_factor
202 |
203 | # Sample noise that we'll add to the latents
204 | noise = torch.randn_like(latents)
205 | if args.noise_offset:
206 | # https://www.crosslabs.org//blog/diffusion-with-offset-noise
207 | noise += args.noise_offset * torch.randn(
208 | (latents.shape[0], latents.shape[1], 1, 1), device=latents.device
209 | )
210 | if args.input_perturbation:
211 | new_noise = noise + args.input_perturbation * torch.randn_like(noise)
212 | bsz = latents.shape[0]
213 | # Sample a random timestep for each image
214 | timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, (bsz,), device=latents.device)
215 | timesteps = timesteps.long()
216 |
217 | # Add noise to the latents according to the noise magnitude at each timestep
218 | # (this is the forward diffusion process)
219 | if args.input_perturbation:
220 | noisy_latents = noise_scheduler.add_noise(latents, new_noise, timesteps)
221 | else:
222 | noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)
223 |
224 | # Get the text embedding for conditioning
225 | encoder_hidden_states = text_encoder(batch["input_ids"])[0]
226 |
227 | # Get the target for loss depending on the prediction type
228 | if args.prediction_type is not None:
229 | # set prediction_type of scheduler if defined
230 | noise_scheduler.register_to_config(prediction_type=args.prediction_type)
231 |
232 | if noise_scheduler.config.prediction_type == "epsilon":
233 | target = noise
234 | elif noise_scheduler.config.prediction_type == "v_prediction":
235 | target = noise_scheduler.get_velocity(latents, noise, timesteps)
236 | else:
237 | raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}")
238 |
239 | # Predict the noise residual and compute loss
240 | model_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample
241 |
242 | if args.snr_gamma is None:
243 | loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean")
244 | else:
245 | # Compute loss-weights as per Section 3.4 of https://arxiv.org/abs/2303.09556.
246 | # Since we predict the noise instead of x_0, the original formulation is slightly changed.
247 | # This is discussed in Section 4.2 of the same paper.
248 | snr = compute_snr(noise_scheduler, timesteps)
249 | if noise_scheduler.config.prediction_type == "v_prediction":
250 | # Velocity objective requires that we add one to SNR values before we divide by them.
251 | snr = snr + 1
252 | mse_loss_weights = (
253 | torch.stack([snr, args.snr_gamma * torch.ones_like(timesteps)], dim=1).min(dim=1)[0] / snr
254 | )
255 |
256 | loss = F.mse_loss(model_pred.float(), target.float(), reduction="none")
257 | loss = loss.mean(dim=list(range(1, len(loss.shape)))) * mse_loss_weights
258 | loss = loss.mean()
259 |
260 | logs = {"step_val_loss": loss.detach().item()}
261 | progress_bar.set_postfix(**logs)
262 | val_loss += loss.item()
263 | val_loss /= (step+1)
264 | return val_loss
265 |
266 | def log_validation(vae, text_encoder, tokenizer, unet, args, accelerator, weight_dtype, epoch, val_dataloader, noise_scheduler):
267 | logger.info("Running validation... ")
268 |
269 |
270 | if args.validation_prompts is not None:
271 |
272 | if args.device == "hpu":
273 | pipeline = GaudiStableDiffusionPipeline.from_pretrained(
274 | args.pretrained_model_name_or_path,
275 | text_encoder=accelerator.unwrap_model(text_encoder),
276 | tokenizer=tokenizer,
277 | vae=accelerator.unwrap_model(vae),
278 | unet=accelerator.unwrap_model(unet),
279 | safety_checker=None,
280 | revision=args.revision,
281 | use_habana=True,
282 | use_hpu_graphs=True,
283 | gaudi_config=args.gaudi_config_name,
284 | )
285 | else:
286 | pipeline = StableDiffusionPipeline.from_pretrained(
287 | args.pretrained_model_name_or_path,
288 | vae=accelerator.unwrap_model(vae),
289 | text_encoder=accelerator.unwrap_model(text_encoder),
290 | tokenizer=tokenizer,
291 | unet=accelerator.unwrap_model(unet),
292 | safety_checker=None,
293 | revision=args.revision,
294 | torch_dtype=weight_dtype,
295 | )
296 | pipeline = pipeline.to(accelerator.device)
297 | pipeline.set_progress_bar_config(disable=True)
298 |
299 | if args.enable_xformers_memory_efficient_attention:
300 | pipeline.enable_xformers_memory_efficient_attention()
301 |
302 | if args.seed is None:
303 | generator = None
304 | else:
305 | generator = torch.Generator(device=accelerator.device).manual_seed(args.seed)
306 |
307 | images = []
308 | for i in range(len(args.validation_prompts)):
309 | if args.device == "hpu":
310 | image = pipeline(args.validation_prompts[i], num_inference_steps=50, generator=generator).images[0]
311 | else:
312 | with torch.autocast("cuda"):
313 | image = pipeline(args.validation_prompts[i], num_inference_steps=50, generator=generator).images[0]
314 |
315 | images.append(image)
316 |
317 | for tracker in accelerator.trackers:
318 | if tracker.name == "tensorboard":
319 | if args.validation_prompts is not None:
320 | np_images = np.stack([np.asarray(img) for img in images])
321 | tracker.writer.add_images("validation/images", np_images, epoch, dataformats="NHWC")
322 | elif tracker.name == "wandb":
323 | if args.validation_prompts is not None:
324 | tracker.log(
325 | {
326 | "validation/images": [
327 | wandb.Image(image, caption=f"{i}: {args.validation_prompts[i]}")
328 | for i, image in enumerate(images)
329 | ]
330 | }
331 | )
332 |
333 | else:
334 | if args.device == "hpu":
335 | logger.warning(f"image logging not implemented for {tracker.name}")
336 | else:
337 | logger.warn(f"image logging not implemented for {tracker.name}")
338 |
339 | del pipeline
340 | if args.device != "hpu":
341 | torch.cuda.empty_cache()
342 |
343 | return images
344 |
345 | def import_model_class_from_model_name_or_path(pretrained_model_name_or_path: str, revision: str):
346 | text_encoder_config = PretrainedConfig.from_pretrained(
347 | pretrained_model_name_or_path,
348 | subfolder="text_encoder",
349 | revision=revision,
350 | )
351 | model_class = text_encoder_config.architectures[0]
352 |
353 | if model_class == "CLIPTextModel":
354 | from transformers import CLIPTextModel
355 |
356 | return CLIPTextModel
357 | elif model_class == "RobertaSeriesModelWithTransformation":
358 | from diffusers.pipelines.alt_diffusion.modeling_roberta_series import RobertaSeriesModelWithTransformation
359 |
360 | return RobertaSeriesModelWithTransformation
361 | elif model_class == "T5EncoderModel":
362 | from transformers import T5EncoderModel
363 |
364 | return T5EncoderModel
365 | else:
366 | raise ValueError(f"{model_class} is not supported.")
367 |
368 | def parse_args():
369 | parser = argparse.ArgumentParser(description="Simple example of a training script.")
370 | parser.add_argument(
371 | "--input_perturbation", type=float, default=0, help="The scale of input perturbation. Recommended 0.1."
372 | )
373 | parser.add_argument(
374 | "--pretrained_model_name_or_path",
375 | type=str,
376 | default=None,
377 | required=True,
378 | help="Path to pretrained model or model identifier from huggingface.co/models.",
379 | )
380 | parser.add_argument(
381 | "--revision",
382 | type=str,
383 | default=None,
384 | required=False,
385 | help="Revision of pretrained model identifier from huggingface.co/models.",
386 | )
387 | parser.add_argument(
388 | "--dataset_name",
389 | type=str,
390 | default=None,
391 | help=(
392 | "The name of the Dataset (from the HuggingFace hub) to train on (could be your own, possibly private,"
393 | " dataset). It can also be a path pointing to a local copy of a dataset in your filesystem,"
394 | " or to a folder containing files that 🤗 Datasets can understand."
395 | ),
396 | )
397 | parser.add_argument(
398 | "--dataset_config_name",
399 | type=str,
400 | default=None,
401 | help="The config of the Dataset, leave as None if there's only one config.",
402 | )
403 | parser.add_argument(
404 | "--train_data_dir",
405 | type=str,
406 | default=None,
407 | help=(
408 | "A folder containing the training data. Folder contents must follow the structure described in"
409 | " https://huggingface.co/docs/datasets/image_dataset#imagefolder. In particular, a `metadata.jsonl` file"
410 | " must exist to provide the captions for the images. Ignored if `dataset_name` is specified."
411 | ),
412 | )
413 | parser.add_argument(
414 | "--spright_splits",
415 | type=str,
416 | default="split.json",
417 | help=(
418 | "A url containing the json file that defines the splits (https://huggingface.co/datasets/ASU-HF/spright/blob/main/split.json). The webdataset should contain the metadata as tar files."
419 | ),
420 | )
421 | parser.add_argument(
422 | "--spright_train_costum",
423 | type=str,
424 | default=None,
425 | help=(
426 | "A url containing the webdataset train split. The webdataset should contain the metadata as tar files."
427 | ),
428 | )
429 | parser.add_argument(
430 | "--spright_val_costum",
431 | type=str,
432 | default=None,
433 | help=(
434 | "A url containing the webdataset validation split. The webdataset should contain the metadata as tar files."
435 | ),
436 | )
437 | parser.add_argument(
438 | "--webdataset_buffer_size",
439 | type=int,
440 | default=1000,
441 | help=(
442 | "buffer size of webdataset."
443 | ),
444 | )
445 | parser.add_argument(
446 | "--dataset_size",
447 | type=float,
448 | default=None,
449 | help="dataset size to use. If set, the dataset will be truncated to this size.",
450 | )
451 | parser.add_argument(
452 | "--val_split",
453 | type=float,
454 | default=0.1,
455 | help="ratio of validation size out of the entire dataset "
456 | )
457 | parser.add_argument(
458 | "--train_metadata_dir",
459 | type=str,
460 | default=None,
461 | help=(
462 | "A folder containing subfolders: train, val, test with the metadata as jsonl files."
463 | " jsonl files provide the general and spatial captions for the images."
464 | ),
465 | )
466 | parser.add_argument(
467 | "--dataloader",
468 | type=str,
469 | default=None,
470 | help=(
471 | "A python script with custom dataloader."
472 | ),
473 | )
474 | parser.add_argument(
475 | "--image_column", type=str, default="image", help="The column of the dataset containing an image."
476 | )
477 | parser.add_argument(
478 | "--caption_column",
479 | type=str,
480 | default="text",
481 | help="The column of the dataset containing a caption or a list of captions.",
482 | )
483 | parser.add_argument(
484 | "--max_train_samples",
485 | type=int,
486 | default=None,
487 | help=(
488 | "For debugging purposes or quicker training, truncate the number of training examples to this "
489 | "value if set."
490 | ),
491 | )
492 | parser.add_argument(
493 | "--validation_prompts",
494 | type=str,
495 | default=["The city is located behind the water, and the pier is relatively small in comparison to the expanse of the water and the city", "The bed is positioned in the center of the frame, with two red pillows on the left side", "The houses are located on the left side of the street, while the park is on the right side", "The spoon is located on the left side of the shelf, while the bowl is positioned in the center", "The room has a red carpet, and there is a chandelier hanging from the ceiling above the bed"],
496 | nargs="+",
497 | help=("A set of prompts evaluated every `--validation_epochs` and logged to `--report_to`."),
498 | )
499 | parser.add_argument(
500 | "--output_dir",
501 | type=str,
502 | default="",
503 | required=True,
504 | help="The output directory where the model predictions and checkpoints will be written.",
505 | )
506 | parser.add_argument(
507 | "--cache_dir",
508 | type=str,
509 | default=None,
510 | help="The directory where the downloaded models and datasets will be stored.",
511 | )
512 | parser.add_argument("--seed", type=int, default=None, help="A seed for reproducible training.")
513 | parser.add_argument(
514 | "--resolution",
515 | type=int,
516 | default=512,
517 | help=(
518 | "The resolution for input images, all the images in the train/validation dataset will be resized to this"
519 | " resolution"
520 | ),
521 | )
522 | parser.add_argument(
523 | "--pre_crop_resolution",
524 | type=int,
525 | default=768,
526 | help=(
527 | "The resolution for input images, all the images in the train/validation dataset will be resized to this"
528 | " resolution before being randomly cropped to the final `resolution`."
529 | ),
530 | )
531 | parser.add_argument(
532 | "--center_crop",
533 | default=False,
534 | action="store_true",
535 | help=(
536 | "Whether to center crop the input images to the resolution. If not set, the images will be randomly"
537 | " cropped. The images will be resized to the resolution first before cropping."
538 | ),
539 | )
540 | parser.add_argument(
541 | "--random_flip",
542 | action="store_true",
543 | help="whether to randomly flip images horizontally",
544 | )
545 | parser.add_argument(
546 | "--train_batch_size", type=int, default=16, help="Batch size (per device) for the training dataloader."
547 | )
548 | parser.add_argument("--num_train_epochs", type=int, default=100)
549 | parser.add_argument(
550 | "--max_train_steps",
551 | type=int,
552 | default=None,
553 | help="Total number of training steps to perform. If provided, overrides num_train_epochs.",
554 | )
555 | parser.add_argument(
556 | "--gradient_accumulation_steps",
557 | type=int,
558 | default=1,
559 | help="Number of updates steps to accumulate before performing a backward/update pass.",
560 | )
561 | parser.add_argument(
562 | "--gradient_checkpointing",
563 | action="store_true",
564 | help="Whether or not to use gradient checkpointing to save memory at the expense of slower backward pass.",
565 | )
566 | parser.add_argument(
567 | "--learning_rate",
568 | type=float,
569 | default=1e-4,
570 | help="Initial learning rate (after the potential warmup period) to use.",
571 | )
572 | parser.add_argument(
573 | "--text_encoder_lr",
574 | type=float,
575 | default=None,
576 | help="Initial learning rate for the text encoder - should usually be samller than unet_lr(after the potential warmup period) to use. When set to None, it will be set to the same value as learning_rate.",
577 | )
578 | parser.add_argument(
579 | "--scale_lr",
580 | action="store_true",
581 | default=False,
582 | help="Scale the learning rate by the number of GPUs, gradient accumulation steps, and batch size.",
583 | )
584 | parser.add_argument(
585 | "--lr_scheduler",
586 | type=str,
587 | default="constant",
588 | help=(
589 | 'The scheduler type to use. Choose between ["linear", "cosine", "cosine_with_restarts", "polynomial",'
590 | ' "constant", "constant_with_warmup"]'
591 | ),
592 | )
593 | parser.add_argument(
594 | "--lr_warmup_steps", type=int, default=500, help="Number of steps for the warmup in the lr scheduler."
595 | )
596 | parser.add_argument(
597 | "--snr_gamma",
598 | type=float,
599 | default=None,
600 | help="SNR weighting gamma to be used if rebalancing the loss. Recommended value is 5.0. "
601 | "More details here: https://arxiv.org/abs/2303.09556.",
602 | )
603 | parser.add_argument(
604 | "--use_8bit_adam", action="store_true", help="Whether or not to use 8-bit Adam from bitsandbytes."
605 | )
606 | parser.add_argument(
607 | "--allow_tf32",
608 | action="store_true",
609 | help=(
610 | "Whether or not to allow TF32 on Ampere GPUs. Can be used to speed up training. For more information, see"
611 | " https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices"
612 | ),
613 | )
614 | parser.add_argument("--use_ema", action="store_true", help="Whether to use EMA model.")
615 | parser.add_argument(
616 | "--non_ema_revision",
617 | type=str,
618 | default=None,
619 | required=False,
620 | help=(
621 | "Revision of pretrained non-ema model identifier. Must be a branch, tag or git identifier of the local or"
622 | " remote repository specified with --pretrained_model_name_or_path."
623 | ),
624 | )
625 | parser.add_argument(
626 | "--dataloader_num_workers",
627 | type=int,
628 | default=0,
629 | help=(
630 | "Number of subprocesses to use for data loading. 0 means that the data will be loaded in the main process."
631 | ),
632 | )
633 | parser.add_argument("--adam_beta1", type=float, default=0.9, help="The beta1 parameter for the Adam optimizer.")
634 | parser.add_argument("--adam_beta2", type=float, default=0.999, help="The beta2 parameter for the Adam optimizer.")
635 | parser.add_argument("--adam_weight_decay", type=float, default=1e-2, help="Weight decay to use.")
636 | parser.add_argument("--adam_epsilon", type=float, default=1e-08, help="Epsilon value for the Adam optimizer")
637 | parser.add_argument("--max_grad_norm", default=1.0, type=float, help="Max gradient norm.")
638 | parser.add_argument("--push_to_hub", action="store_true", help="Whether or not to push the model to the Hub.")
639 | parser.add_argument("--hub_token", type=str, default=None, help="The token to use to push to the Model Hub.")
640 | parser.add_argument(
641 | "--prediction_type",
642 | type=str,
643 | default=None,
644 | help="The prediction_type that shall be used for training. Choose between 'epsilon' or 'v_prediction' or leave `None`. If left to `None` the default prediction type of the scheduler: `noise_scheduler.config.prediciton_type` is chosen.",
645 | )
646 | parser.add_argument(
647 | "--hub_model_id",
648 | type=str,
649 | default=None,
650 | help="The name of the repository to keep in sync with the local `output_dir`.",
651 | )
652 | parser.add_argument(
653 | "--logging_dir",
654 | type=str,
655 | default="logs",
656 | help=(
657 | "[TensorBoard](https://www.tensorflow.org/tensorboard) log directory. Will default to"
658 | " *output_dir/runs/**CURRENT_DATETIME_HOSTNAME***."
659 | ),
660 | )
661 | parser.add_argument(
662 | "--mixed_precision",
663 | type=str,
664 | default=None,
665 | choices=["no", "fp16", "bf16"],
666 | help=(
667 | "Whether to use mixed precision. Choose between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >="
668 | " 1.10.and an Nvidia Ampere GPU. Default to the value of accelerate config of the current system or the"
669 | " flag passed with the `accelerate.launch` command. Use this argument to override the accelerate config."
670 | ),
671 | )
672 | parser.add_argument(
673 | "--report_to",
674 | type=str,
675 | default="tensorboard",
676 | help=(
677 | 'The integration to report the results and logs to. Supported platforms are `"tensorboard"`'
678 | ' (default), `"wandb"` and `"comet_ml"`. Use `"all"` to report to all integrations.'
679 | ),
680 | )
681 | parser.add_argument("--local_rank", type=int, default=-1, help="For distributed training: local_rank")
682 | parser.add_argument(
683 | "--checkpointing_steps",
684 | type=int,
685 | default=500,
686 | help=(
687 | "Save a checkpoint of the training state every X updates. These checkpoints are only suitable for resuming"
688 | " training using `--resume_from_checkpoint`."
689 | ),
690 | )
691 | parser.add_argument(
692 | "--checkpoints_total_limit",
693 | type=int,
694 | default=None,
695 | help=("Max number of checkpoints to store."),
696 | )
697 | parser.add_argument(
698 | "--resume_from_checkpoint",
699 | type=str,
700 | default=None,
701 | help=(
702 | "Whether training should be resumed from a previous checkpoint. Use a path saved by"
703 | ' `--checkpointing_steps`, or `"latest"` to automatically select the last available checkpoint.'
704 | ),
705 | )
706 | parser.add_argument(
707 | "--enable_xformers_memory_efficient_attention", action="store_true", help="Whether or not to use xformers."
708 | )
709 | parser.add_argument("--noise_offset", type=float, default=0, help="The scale of noise offset.")
710 | parser.add_argument(
711 | "--validation_epochs",
712 | type=int,
713 | default=5,
714 | help="Run validation every X epochs.",
715 | )
716 | parser.add_argument(
717 | "--tracker_project_name",
718 | type=str,
719 | default="text2image-fine-tune",
720 | help=(
721 | "The `project_name` argument passed to Accelerator.init_trackers for"
722 | " more information see https://huggingface.co/docs/accelerate/v0.17.0/en/package_reference/accelerator#accelerate.Accelerator"
723 | ),
724 | )
725 | parser.add_argument(
726 | "--gaudi_config_name",
727 | type=str,
728 | default=None,
729 | help="Local path to the Gaudi configuration file or its name on the Hugging Face Hub.",
730 | )
731 | parser.add_argument(
732 | "--throughput_warmup_steps",
733 | type=int,
734 | default=0,
735 | help=(
736 | "Number of steps to ignore for throughput calculation. For example, with throughput_warmup_steps=N, the"
737 | " first N steps will not be considered in the calculation of the throughput. This is especially useful in"
738 | " lazy mode."
739 | ),
740 | )
741 | parser.add_argument(
742 | "--bf16",
743 | action="store_true",
744 | default=False,
745 | help=("Whether to use bf16 mixed precision."),
746 | )
747 | parser.add_argument(
748 | "--device",
749 | type=str,
750 | default=None,
751 | help=("hpu, cuda or cpu."),
752 | )
753 | parser.add_argument(
754 | "--train_text_encoder",
755 | action="store_true",
756 | help="Whether to train the text encoder. If set, the text encoder should be float32 precision.",
757 | )
758 | parser.add_argument(
759 | "--tokenizer_name",
760 | type=str,
761 | default=None,
762 | help="Pretrained tokenizer name or path if not the same as model_name",
763 | )
764 | parser.add_argument(
765 | "--freeze_text_encoder_steps",
766 | type=int,
767 | default=0,
768 | help="Start text_encoder training after freeze_text_encoder_steps steps.",
769 | )
770 | parser.add_argument("--comment", type=str, default="used long sentences generated by llava - spatial and general.", help="Comment that should appear in the run config")
771 | parser.add_argument("--git_token", type=str, default=None, help="If provided will enable to save the git sha to replicate")
772 | parser.add_argument("--general_caption", type=str, default="original_caption", choices = ["coca_caption", "original_caption"],
773 | help="Original are the oned from the original dataset, coca_caption is the one generated by COCA" \
774 | "in case original is chosen, the original caption will be preffered as general_caption, if it does not exist than the general caption will be the coca caption")
775 | parser.add_argument("--spatial_caption_type", type=str, default="long", choices = ["short", "long", "short_negative"], help="Wheter to use long or short spatial captions")
776 | parser.add_argument(
777 | "--spatial_percent",
778 | type=float,
779 | default=50.0,
780 | help="approximately precentage of the time that spatial captions is chosen.",
781 | )
782 |
783 |
784 | args = parser.parse_args()
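# For a fresh run, nest outputs in a timestamped run_* subfolder; when resuming from a checkpoint, keep output_dir as given so the checkpoint can be located.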
785 | if args.resume_from_checkpoint is None:
786 | args.output_dir = os.path.join(args.output_dir , f"run_{datetime.datetime.now().strftime('%Y-%m-%d_%H-%M-%S')}")
787 | env_local_rank = int(os.environ.get("LOCAL_RANK", -1))
788 | if env_local_rank != -1 and env_local_rank != args.local_rank:
789 | args.local_rank = env_local_rank
790 |
791 | # default to using the same revision for the non-ema model if not specified
792 | if args.non_ema_revision is None:
793 | args.non_ema_revision = args.revision
794 |
795 | return args
796 |
797 |
798 | def main():
799 | args = parse_args()
800 |
801 | url_train = None
802 |
803 | if args.spright_splits is not None and args.spright_train_costum is not None:
804 | warnings.warn("You can not specify the splits by both spright_splits and spright_train_costum." \
805 | "The costum split will be used. If you want to use the SPRIGHT splits, remove the spright_train_costum argument.")
806 |
807 |
808 | if args.report_to == "wandb" and args.hub_token is not None:
809 | raise ValueError(
810 | "You cannot use both --report_to=wandb and --hub_token due to a security risk of exposing your token."
811 | " Please use `huggingface-cli login` to authenticate with the Hub."
812 | )
813 |
814 | # set device
815 | if hthpu and hthpu.is_available():
816 | args.device = "hpu"
817 | logger.info("Using HPU")
818 | elif torch.cuda.is_available():
819 | args.device = "cuda"
820 | logger.info("Using GPU")
821 | else:
822 | args.device = "cpu"
823 | logger.info("Using CPU")
824 |
825 | # set precision:
826 | if args.device == "hpu":
827 | if args.mixed_precision == "bf16":
828 | args.bf16 = True
829 | else:
830 | args.bf16 = False
831 |
832 | # set args for gaudi:
833 | assert not args.enable_xformers_memory_efficient_attention, "xformers is not supported on gaudi"
834 | assert not args.allow_tf32, "tf32 is not supported on gaudi"
835 | assert not args.gradient_checkpointing, "gradient_checkpointing is not supported on gaudi locally"
836 | assert not args.push_to_hub, "push_to_hub is not supported on gaudi locally"
837 |
838 | else:
839 | assert args.gaudi_config_name is None, "gaudi_config_name is only supported on gaudi"
840 | assert args.throughput_warmup_steps == 0, "throughput_warmup_steps is only supported on gaudi"
841 |
842 |
843 | if args.non_ema_revision is not None:
844 | deprecate(
845 | "non_ema_revision!=None",
846 | "0.15.0",
847 | message=(
848 | "Downloading 'non_ema' weights from revision branches of the Hub is deprecated. Please make sure to"
849 | " use `--variant=non_ema` instead."
850 | ),
851 | )
852 | logging_dir = os.path.join(args.output_dir, args.logging_dir)
853 |
854 | accelerator_project_config = ProjectConfiguration(project_dir=args.output_dir, logging_dir=logging_dir)
855 |
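# On Gaudi, the --use_8bit_adam flag is repurposed to enable Habana's fused AdamW optimizer instead of bitsandbytes.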
856 | if args.device == "hpu":
857 | gaudi_config = GaudiConfig.from_pretrained(args.gaudi_config_name)
858 | if args.use_8bit_adam:
859 | gaudi_config.use_fused_adam = True
860 | args.use_8bit_adam = False
861 |
862 | accelerator = GaudiAccelerator(
863 | gradient_accumulation_steps=args.gradient_accumulation_steps,
864 | mixed_precision="bf16" if gaudi_config.use_torch_autocast or args.bf16 else "no",
865 | log_with=args.report_to,
866 | project_config=accelerator_project_config,
867 | force_autocast=gaudi_config.use_torch_autocast or args.bf16,
868 | )
869 | else:
870 | accelerator = Accelerator(
871 | gradient_accumulation_steps=args.gradient_accumulation_steps,
872 | mixed_precision=args.mixed_precision,
873 | log_with=args.report_to,
874 | project_config=accelerator_project_config,
875 | )
876 |
877 | # Make one log on every process with the configuration for debugging.
878 | logging.basicConfig(
879 | format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
880 | datefmt="%m/%d/%Y %H:%M:%S",
881 | level=logging.INFO,
882 | )
883 | logger.info(accelerator.state, main_process_only=False)
884 | if accelerator.is_local_main_process:
885 | datasets.utils.logging.set_verbosity_warning()
886 | transformers.utils.logging.set_verbosity_warning()
887 | diffusers.utils.logging.set_verbosity_info()
888 | else:
889 | datasets.utils.logging.set_verbosity_error()
890 | transformers.utils.logging.set_verbosity_error()
891 | diffusers.utils.logging.set_verbosity_error()
892 |
893 | # If passed along, set the training seed now.
894 | if args.seed is not None:
895 | set_seed(args.seed)
896 |
897 | # Handle the repository creation
898 | if accelerator.is_main_process:
899 | if args.output_dir is not None:
900 | os.makedirs(args.output_dir, exist_ok=True)
901 |
902 | if args.push_to_hub:
903 | repo_id = create_repo(
904 | repo_id=args.hub_model_id or Path(args.output_dir).name, exist_ok=True, token=args.hub_token
905 | ).repo_id
906 |
907 | # Load scheduler, tokenizer and models.
908 | if args.device == "hpu":
909 | noise_scheduler = GaudiDDIMScheduler.from_pretrained(args.pretrained_model_name_or_path, subfolder="scheduler")
910 | else:
911 | noise_scheduler = DDPMScheduler.from_pretrained(args.pretrained_model_name_or_path, subfolder="scheduler")
912 |
913 | tokenizer = CLIPTokenizer.from_pretrained(
914 | args.pretrained_model_name_or_path, subfolder="tokenizer", revision=args.revision
915 | )
916 |
917 | def deepspeed_zero_init_disabled_context_manager():
918 | """
919 | returns either a context list that includes one that will disable zero.Init or an empty context list
920 | """
921 | deepspeed_plugin = AcceleratorState().deepspeed_plugin if accelerate.state.is_initialized() else None
922 | if deepspeed_plugin is None:
923 | return []
924 |
925 | return [deepspeed_plugin.zero3_init_context_manager(enable=False)]
926 |
927 | # Currently Accelerate doesn't know how to handle multiple models under Deepspeed ZeRO stage 3.
928 | # For this to work properly all models must be run through `accelerate.prepare`. But accelerate
929 | # will try to assign the same optimizer with the same weights to all models during
930 | # `deepspeed.initialize`, which of course doesn't work.
931 | #
932 | # For now the following workaround will partially support Deepspeed ZeRO-3, by excluding the 2
933 | # frozen models from being partitioned during `zero.Init` which gets called during
934 | # `from_pretrained` So CLIPTextModel and AutoencoderKL will not enjoy the parameter sharding
935 | # across multiple gpus and only UNet2DConditionModel will get ZeRO sharded.
936 | if not args.train_text_encoder:
937 | with ContextManagers(deepspeed_zero_init_disabled_context_manager()):
938 | text_encoder = CLIPTextModel.from_pretrained(
939 | args.pretrained_model_name_or_path, subfolder="text_encoder", revision=args.revision
940 | ).to(accelerator.device)
941 | vae = AutoencoderKL.from_pretrained(
942 | args.pretrained_model_name_or_path, subfolder="vae", revision=args.revision
943 | )
944 | else:
945 | # import correct text encoder class
946 | text_encoder_cls = import_model_class_from_model_name_or_path(args.pretrained_model_name_or_path, args.revision)
947 | text_encoder = text_encoder_cls.from_pretrained(
948 | args.pretrained_model_name_or_path, subfolder="text_encoder", revision=args.revision)
949 | vae = AutoencoderKL.from_pretrained(
950 | args.pretrained_model_name_or_path, subfolder="vae", revision=args.revision
951 | )
952 |
953 |
954 | unet = UNet2DConditionModel.from_pretrained(
955 | args.pretrained_model_name_or_path, subfolder="unet", revision=args.non_ema_revision
956 | )
957 |
958 | # Freeze vae and text_encoder and set unet to trainable
959 | vae.requires_grad_(False)
960 | if not args.train_text_encoder:
961 | text_encoder.requires_grad_(False)
962 | unet.train()
963 |
964 | # Create EMA for the unet.
965 | if args.use_ema:
966 | ema_unet = UNet2DConditionModel.from_pretrained(
967 | args.pretrained_model_name_or_path, subfolder="unet", revision=args.revision
968 | )
969 | ema_unet = EMAModel(ema_unet.parameters(), model_cls=UNet2DConditionModel, model_config=ema_unet.config)
970 |
971 | if args.enable_xformers_memory_efficient_attention:
972 | if is_xformers_available():
973 | import xformers
974 |
975 | xformers_version = version.parse(xformers.__version__)
976 | if xformers_version == version.parse("0.0.16"):
977 | logger.warning(
978 | "xFormers 0.0.16 cannot be used for training in some GPUs. If you observe problems during training, please update xFormers to at least 0.0.17. See https://huggingface.co/docs/diffusers/main/en/optimization/xformers for more details."
979 | )
980 | unet.enable_xformers_memory_efficient_attention()
981 | else:
982 | raise ValueError("xformers is not available. Make sure it is installed correctly")
983 |
984 | def unwrap_model(model):
985 | model = accelerator.unwrap_model(model)
986 | model = model._orig_mod if is_compiled_module(model) else model
987 | return model
988 |
989 | # `accelerate` 0.16.0 will have better support for customized saving
990 | if version.parse(accelerate.__version__) >= version.parse("0.16.0"):
991 | # create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format
992 | def save_model_hook(models, weights, output_dir):
993 | if accelerator.is_main_process:
994 | if args.use_ema:
995 | ema_unet.save_pretrained(os.path.join(output_dir, "unet_ema"))
996 |
997 | for model in models:
998 | sub_dir = "unet" if isinstance(model, type(unwrap_model(unet))) else "text_encoder"
999 | model.save_pretrained(os.path.join(output_dir, sub_dir))
1000 |
1001 | # make sure to pop weight so that corresponding model is not saved again
1002 | weights.pop()
1003 |
1004 | def load_model_hook(models, input_dir):
1005 | if args.use_ema:
1006 | load_model = EMAModel.from_pretrained(os.path.join(input_dir, "unet_ema"), UNet2DConditionModel)
1007 | ema_unet.load_state_dict(load_model.state_dict())
1008 | ema_unet.to(accelerator.device)
1009 | del load_model
1010 |
1011 | for i in range(len(models)):
1012 | # pop models so that they are not loaded again
1013 | model = models.pop()
1014 |
1015 | if isinstance(model, type(unwrap_model(text_encoder))):
1016 | # load transformers style into model
1017 | load_model = text_encoder_cls.from_pretrained(input_dir, subfolder="text_encoder")
1018 | model.config = load_model.config
1019 | else:
1020 | # load diffusers style into model
1021 | load_model = UNet2DConditionModel.from_pretrained(input_dir, subfolder="unet")
1022 | model.register_to_config(**load_model.config)
1023 |
1024 | model.load_state_dict(load_model.state_dict())
1025 | del load_model
1026 |
1027 | accelerator.register_save_state_pre_hook(save_model_hook)
1028 | accelerator.register_load_state_pre_hook(load_model_hook)
1029 |
1030 | if args.gradient_checkpointing:
1031 | unet.enable_gradient_checkpointing()
1032 | if args.train_text_encoder:
1033 | text_encoder.gradient_checkpointing_enable()
1034 |
1035 | # Check that all trainable models are in full precision
1036 | low_precision_error_string = (
1037 | "Please make sure to always have all model weights in full float32 precision when starting training - even if"
1038 | " doing mixed precision training. copy of the weights should still be float32."
1039 | )
1040 |
1041 | if unwrap_model(unet).dtype != torch.float32:
1042 | raise ValueError(f"Unet loaded as datatype {unwrap_model(unet).dtype}. {low_precision_error_string}")
1043 |
1044 | if args.train_text_encoder and unwrap_model(text_encoder).dtype != torch.float32:
1045 | raise ValueError(
1046 | f"Text encoder loaded as datatype {unwrap_model(text_encoder).dtype}." f" {low_precision_error_string}"
1047 | )
1048 |
1049 | # Enable TF32 for faster training on Ampere GPUs,
1050 | # cf https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices
1051 | if args.allow_tf32:
1052 | torch.backends.cuda.matmul.allow_tf32 = True
1053 |
1054 | if args.scale_lr:
1055 | args.learning_rate = (
1056 | args.learning_rate * args.gradient_accumulation_steps * args.train_batch_size * accelerator.num_processes
1057 | )
1058 |
1059 | # Initialize the optimizer
1060 | if args.use_8bit_adam:
1061 | try:
1062 | import bitsandbytes as bnb
1063 | except ImportError:
1064 | raise ImportError(
1065 | "Please install bitsandbytes to use 8-bit Adam. You can do so by running `pip install bitsandbytes`"
1066 | )
1067 |
1068 | optimizer_cls = bnb.optim.AdamW8bit
1069 | elif args.device == "hpu" and gaudi_config.use_fused_adam:
1070 | from habana_frameworks.torch.hpex.optimizers import FusedAdamW
1071 |
1072 | optimizer_cls = FusedAdamW
1073 | else:
1074 | optimizer_cls = torch.optim.AdamW
1075 | # setting different learning rates for the text encoder and the unet:
1076 | if args.train_text_encoder:
1077 | unet_parameters_with_lr = {"params": unet.parameters(), "lr": args.learning_rate}
1078 | text_encoder_params_with_lr = {
1079 | "params": text_encoder.parameters(),
1080 | "lr": args.text_encoder_lr if args.text_encoder_lr else args.learning_rate
1081 | }
1082 | params_to_optimize = [
1083 | unet_parameters_with_lr,
1084 | text_encoder_params_with_lr
1085 | ]
1086 | print(f"Learning rates were provided both for the unet and the text encoder- e.g. text_encoder_lr:"
1087 | f" {args.text_encoder_lr} and learning_rate: {args.learning_rate}. ")
1088 |
1089 |
1090 | else:
1091 | params_to_optimize = unet_parameters_with_lr
1092 | optimizer = optimizer_cls(
1093 | params_to_optimize,
1094 | betas=(args.adam_beta1, args.adam_beta2),
1095 | weight_decay=args.adam_weight_decay,
1096 | eps=args.adam_epsilon,
1097 | )
1098 |
1099 |
1100 |
1101 | # Get the datasets: you can either provide your own training and evaluation files (see below)
1102 | # or specify a Dataset from the hub (the dataset will be downloaded automatically from the datasets Hub).
1103 |
1104 | # In distributed training, the load_dataset function guarantees that only one local process can concurrently
1105 | # download the dataset.
1106 | if args.dataset_name is not None:
1107 | # Downloading and loading a dataset from the hub.
1108 | dataset = load_dataset(
1109 | args.dataset_name,
1110 | args.dataset_config_name,
1111 | cache_dir=args.cache_dir,
1112 | data_dir=args.train_data_dir,
1113 | download_mode=DownloadMode.FORCE_REDOWNLOAD
1114 | )
1115 | elif args.train_data_dir is not None:
1116 | data_files = {}
1117 | data_files["train"] = os.path.join(args.train_data_dir, "**")
1118 | dataset = load_dataset(
1119 | "imagefolder",
1120 | data_files=data_files,
1121 | cache_dir=args.cache_dir,
1122 | )
1123 | # See more about loading custom images at
1124 | # https://huggingface.co/docs/datasets/v2.4.0/en/image_load#imagefolder
1125 | elif args.dataloader is not None:
1126 | dataset = load_dataset(
1127 | args.dataloader,
1128 | data_dir=args.train_metadata_dir,
1129 | cache_dir=args.cache_dir
1130 | )
1131 | elif args.spright_splits is not None:
1132 | # load the json file and read the train and val splits
1133 | if args.spright_splits.endswith(".json"):
1134 | with open(args.spright_splits, 'r') as f:
1135 | data = json.load(f)
1136 | # Filter the entries where split is 'train'
1137 | train_files = []
1138 | train_len = 0
1139 | val_files = []
1140 | val_len = 0
1141 | test_files = []
1142 | test_len = 0
1143 | if args.spright_train_costum is not None:
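# The custom split is a brace-expansion shard URL (e.g. ".../{00000..00099}.tar"): parse the start/end shard indices and sum the matching shard sizes from the splits json to get the sample count.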
1144 | url_train = args.spright_train_costum
1145 | start, end = map(int, url_train.split('/')[-1].strip('.tar{}').split('..'))
1146 | # Filter the data for the files in the range and sum the sizes
1147 | train_len = sum(item['size'] for item in data if start <= int(item['file'].strip('.tar')) <= end)
1148 | if args.spright_val_costum is not None:
1149 | url_val = args.spright_val_costum
1150 | start_val, end_val = map(int, url_val.split('/')[-1].strip('.tar{}').split('..'))
1151 | val_len = sum(item['size'] for item in data if start_val <= int(item['file'].strip('.tar')) <= end_val)
1152 | else:
1153 | url_val = None
1154 | else:
1155 | for item in data:
1156 | if item['split'] == 'train':
1157 | train_files.append(item['file'].split(".")[0])
1158 | train_len += item['size']
1159 | elif item['split'] == 'val':
1160 | val_files.append(item['file'].split(".")[0])
1161 | val_len += item['size']
1162 | elif item['split'] == 'test':
1163 | test_files.append(item['file'].split(".")[0])
1164 | test_len += item['size']
1165 | ext = data[0]['file'].split(".")[-1]
1166 | # Construct the url_train string
1167 | if len(train_files) != 0:
1168 | url_train = "/export/share/projects/mcai/spatial_data/spright/data/{" + ','.join(train_files) + "}" + f".{ext}"
1169 | if len(val_files) != 0:
1170 | url_val = "/export/share/projects/mcai/spatial_data/spright/data/{" + ','.join(val_files) + "}" + f".{ext}"
1171 | else:
1172 | url_val = None
1173 | else:
1174 | raise ValueError("'spright_splits' should be a json file containing the train and val splits and their sizes")
1175 |
1176 | # Preprocessing the datasets.
1177 | # We need to tokenize inputs and targets.
1178 | # column_names = dataset["train"].column_names
1179 | # 6. Get the column names for input/target.
1180 | if args.spatial_caption_type == "short":
1181 | spatial_caption = 'short_spatial_caption'
1182 | elif args.spatial_caption_type == "long":
1183 | spatial_caption = 'spatial_caption'
1184 | elif args.spatial_caption_type == "short_negative":
1185 | spatial_caption = 'short_spatial_caption_negation'
1186 |
1187 | # Preprocessing the datasets.
1188 | # We need to tokenize input captions and transform the images.
1189 |
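# Returns a random sentence (other than the leading one) from a long caption, falling back to the first sentence when the caption has too few sentences; used to derive short spatial captions.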
1190 | def get_random_sentence(caption):
1191 | if isinstance(caption, list):
1192 | caption = caption[0]
1193 | sentences_split = caption.split(".")
1194 | first_sentence = sentences_split[0]
1195 | sentences_split = sentences_split[1:]
1196 | if len(sentences_split)>1:
1197 | random_sentence = random.choice(sentences_split)
1198 | if len(random_sentence) == 0:
1199 | random_sentence = sentences_split[0]
1200 | else:
1201 | random_sentence = first_sentence
1202 | return random_sentence
1203 |
1204 |
1205 | def tokenize_captions(examples, is_train=True):
1206 | # check if its a webdataset
1207 | if url_train is not None:
1208 | examples = examples["captions"]
1209 |
1210 | column_lists = [args.general_caption, 'spatial_caption']
1211 |
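# Randomly pick the caption source for this example: the spatial caption with probability spatial_percent, the general caption otherwise.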
1212 | caption_column = random.choices(
1213 | column_lists,
1214 | weights=[100-args.spatial_percent, args.spatial_percent],
1215 | k=1
1216 | )[0]
1217 |
1218 | # check if the caption exists
1219 | if args.general_caption=="original_caption" and caption_column==args.general_caption:
1220 | if caption_column not in examples.keys():
1221 | caption_column = "coca_caption"
1222 | elif args.spatial_caption_type == "short" and caption_column == 'spatial_caption':
1223 | examples['short_spatial_caption'] = get_random_sentence(examples[caption_column])
1224 | caption_column = 'short_spatial_caption'
1225 |
1226 | captions = []
1227 | # check if the caption is a list
1228 | if not isinstance(examples[caption_column], list):
1229 | examples[caption_column] = [examples[caption_column]]
1230 |
1231 | for caption in examples[caption_column]:
1232 | if isinstance(caption, str):
1233 | captions.append(caption)
1234 | if debug:
1235 | with open(os.path.join(args.output_dir, f"selected_captions{args.spatial_percent}.txt"), "a") as f:
1236 | f.write(f"{caption}\n")
1237 | elif isinstance(caption, (list, np.ndarray)):
1238 | # take a random caption if there are multiple
1239 | captions.append(random.choice(caption) if is_train else caption[0])
1240 | else:
1241 | raise ValueError(
1242 | f"Caption column `{caption_column}` should contain either strings or lists of strings."
1243 | )
1244 | inputs = tokenizer(
1245 | captions, max_length=tokenizer.model_max_length, padding="max_length", truncation=True, return_tensors="pt"
1246 | )
1247 | return inputs.input_ids
1248 |
1249 | # Preprocessing the datasets.
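# When pre_crop_resolution is larger than the target resolution, resize to it first and then crop down to the training resolution; otherwise resize directly to the training resolution before cropping.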
1250 | if args.pre_crop_resolution and args.pre_crop_resolution > args.resolution:
1251 | train_transforms = transforms.Compose(
1252 | [
1253 | transforms.Resize(args.pre_crop_resolution, interpolation=transforms.InterpolationMode.BILINEAR),
1254 | transforms.CenterCrop(args.resolution) if args.center_crop else transforms.RandomCrop(args.resolution),
1255 | transforms.Resize(args.resolution, interpolation=transforms.InterpolationMode.BILINEAR),
1256 | transforms.RandomHorizontalFlip() if args.random_flip else transforms.Lambda(lambda x: x),
1257 | transforms.ToTensor(),
1258 | transforms.Normalize([0.5], [0.5]),
1259 | ]
1260 | )
1261 | else:
1262 | train_transforms = transforms.Compose(
1263 | [
1264 | transforms.Resize(args.resolution, interpolation=transforms.InterpolationMode.BILINEAR),
1265 | transforms.CenterCrop(args.resolution) if args.center_crop else transforms.RandomCrop(args.resolution),
1266 | transforms.RandomHorizontalFlip() if args.random_flip else transforms.Lambda(lambda x: x),
1267 | transforms.ToTensor(),
1268 | transforms.Normalize([0.5], [0.5]),
1269 | ]
1270 | )
1271 |
1272 | def preprocess_train(examples):
1273 | if url_train is not None:
1274 | examples["image"] = [examples["image"]]
1275 | images = [image.convert("RGB") for image in examples["image"]]
1276 | examples["pixel_values"] = [train_transforms(image) for image in images]
1277 | examples["input_ids"] = tokenize_captions(examples)
1278 | return examples
1279 |
1280 | def preprocess_train_back(examples):
1281 | images = [image.convert("RGB") for image in examples["image"]]
1282 | examples["pixel_values"] = [train_transforms(image) for image in images]
1283 | examples["input_ids"] = tokenize_captions(examples)
1284 | return examples
1285 |
1286 | def collate_fn(examples):
1287 | pixel_values = torch.stack([example["pixel_values"] for example in examples])
1288 | pixel_values = pixel_values.to(memory_format=torch.contiguous_format).float()
1289 | input_ids = torch.stack([example["input_ids"] for example in examples])
1290 | return {"pixel_values": pixel_values, "input_ids": input_ids}
1291 |
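# For the webdataset path, raw samples are preprocessed per example inside the collate function instead of via with_transform.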
1292 | def collate_fn_webdataset(examples):
1293 | examples = [preprocess_train(example) for example in examples]
1294 | pixel_values = torch.stack([example["pixel_values"][0] for example in examples])
1295 | pixel_values = pixel_values.to(memory_format=torch.contiguous_format).float()
1296 | input_ids = torch.stack([example["input_ids"] for example in examples])
1297 | return {"pixel_values": pixel_values, "input_ids": input_ids}
1298 |
1299 | with accelerator.main_process_first():
1300 | if url_train is None:
1301 | if args.max_train_samples is not None:
1302 | dataset["train"] = dataset["train"].shuffle(seed=args.seed).select(range(args.max_train_samples))
1303 | train_dataset = dataset["train"].with_transform(preprocess_train)
1304 | train_dataloader = torch.utils.data.DataLoader(
1305 | train_dataset,
1306 | shuffle=True,
1307 | collate_fn=collate_fn,
1308 | batch_size=args.train_batch_size,
1309 | num_workers=args.dataloader_num_workers,
1310 | )
1311 | if "validation" in dataset:
1312 | val_dataloader = torch.utils.data.DataLoader(
1313 | dataset["validation"].with_transform(preprocess_train),
1314 | shuffle=True,
1315 | collate_fn=collate_fn,
1316 | batch_size=args.train_batch_size,
1317 | num_workers=0,
1318 | )
1319 | else:
1320 | val_dataloader = None
1321 | else:
1322 | dataset = {"train": wds.WebDataset(url_train).shuffle(args.webdataset_buffer_size).decode("pil", handler=wds.warn_and_continue).rename(captions="json",image="jpg",metadata="metadata.json",handler=wds.warn_and_continue,)}
1323 | train_dataloader = torch.utils.data.DataLoader(dataset["train"], collate_fn=collate_fn_webdataset, batch_size=args.train_batch_size, num_workers=args.dataloader_num_workers)
1324 | if url_val:
1325 | dataset["val"] = wds.WebDataset(url_val).shuffle(args.webdataset_buffer_size).decode("pil", handler=wds.warn_and_continue). rename(captions="json",image="jpg",metadata="metadata.json",handler=wds.warn_and_continue,)
1326 | val_dataloader = torch.utils.data.DataLoader(dataset["val"], collate_fn=collate_fn_webdataset, batch_size=args.train_batch_size, num_workers=0)
1327 | else:
1328 | val_dataloader = None
1329 | if args.max_train_samples is not None:
1330 | dataset = (
1331 | wds.WebDataset(url_train, shardshuffle=True)
1332 | .shuffle(args.webdataset_buffer_size)
1333 | .decode()
1334 | )
1335 | # Set the training transforms
1336 | train_dataset = dataset["train"]
1337 |
1338 | # Scheduler and math around the number of training steps.
1339 | overrode_max_train_steps = False
1340 | if url_train is None:
1341 | train_len = len(train_dataset)
1342 | train_dataloader_len = len(train_dataloader)
1343 | assert len(train_dataloader) == math.ceil(len(train_dataset) / args.train_batch_size), (len(train_dataloader), len(train_dataset))
1344 | else:
1345 | train_dataloader_len = math.ceil(train_len / args.train_batch_size)
1346 |
1347 | num_update_steps_per_epoch = math.ceil(train_dataloader_len / args.gradient_accumulation_steps)
1348 |
1349 | if args.max_train_steps is None:
1350 | args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
1351 | overrode_max_train_steps = True
1352 |
1353 | lr_scheduler = get_scheduler(
1354 | args.lr_scheduler,
1355 | optimizer=optimizer,
1356 | num_warmup_steps=args.lr_warmup_steps * accelerator.num_processes,
1357 | num_training_steps=args.max_train_steps * accelerator.num_processes,
1358 | )
1359 |
1360 | if not args.train_text_encoder:
1361 | unet.to(accelerator.device)
1362 |
1363 | # Prepare everything with our `accelerator`.
1364 | if val_dataloader is not None:
1365 | if args.train_text_encoder:
1366 | unet, text_encoder, optimizer, train_dataloader, val_dataloader, lr_scheduler = accelerator.prepare(
1367 | unet, text_encoder, optimizer, train_dataloader,val_dataloader, lr_scheduler
1368 | )
1369 | else:
1370 | unet, optimizer, train_dataloader, val_dataloader, lr_scheduler = accelerator.prepare(
1371 | unet, optimizer, train_dataloader, val_dataloader, lr_scheduler
1372 | )
1373 | else:
1374 | if args.train_text_encoder:
1375 | unet, text_encoder, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
1376 | unet, text_encoder, optimizer, train_dataloader, lr_scheduler
1377 | )
1378 | else:
1379 | unet, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
1380 | unet, optimizer, train_dataloader, lr_scheduler
1381 | )
1382 |
1383 | if args.use_ema:
1384 | ema_unet.to(accelerator.device)
1385 |
1386 | # For mixed precision training we cast all non-trainable weights (vae, non-lora text_encoder and non-lora unet) to half-precision
1387 | # as these weights are only used for inference, keeping weights in full precision is not required.
1388 | weight_dtype = torch.float32
1389 | if args.device != "hpu" and accelerator.mixed_precision == "fp16":
1390 | weight_dtype = torch.float16
1391 | args.mixed_precision = accelerator.mixed_precision
1392 | elif accelerator.mixed_precision == "bf16":
1393 | weight_dtype = torch.bfloat16
1394 | args.mixed_precision = accelerator.mixed_precision
1395 | elif args.device == "hpu" and gaudi_config.use_torch_autocast or args.bf16:
1396 | weight_dtype = torch.bfloat16
1397 |
1398 | # Move text_encoder and vae to the accelerator device and cast to weight_dtype
1399 | if not args.train_text_encoder and text_encoder is not None:
1400 | text_encoder.to(accelerator.device, dtype=weight_dtype)
1401 | vae.to(accelerator.device, dtype=weight_dtype)
1402 |
1403 | # We need to recalculate our total training steps as the size of the training dataloader may have changed.
1404 | num_update_steps_per_epoch = math.ceil(train_dataloader_len / args.gradient_accumulation_steps)
1405 | if overrode_max_train_steps:
1406 | args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
1407 | # Afterwards we recalculate our number of training epochs
1408 | args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)
1409 |
1410 | # We need to initialize the trackers we use, and also store our configuration.
1411 | # The trackers initializes automatically on the main process.
1412 | if accelerator.is_main_process:
1413 | tracker_config = dict(vars(args))
1414 | tracker_config.pop("validation_prompts")
1415 | accelerator.init_trackers(args.tracker_project_name, tracker_config)
1416 |
1417 | # Train!
1418 | total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps
1419 |
1420 | logger.info("***** Running training *****")
1421 | logger.info(f" Num examples = {train_len}")
1422 | logger.info(f" Num Epochs = {args.num_train_epochs}")
1423 | logger.info(f" Instantaneous batch size per device = {args.train_batch_size}")
1424 | logger.info(f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}")
1425 | logger.info(f" Gradient Accumulation steps = {args.gradient_accumulation_steps}")
1426 | logger.info(f" Total optimization steps = {args.max_train_steps}")
1427 | global_step = 0
1428 | first_epoch = 0
1429 |
1430 | if accelerator.is_main_process:
1431 | run_config = vars(args)
1432 | with open(os.path.join(args.output_dir, "run_config.jsonl"), 'a') as f:
1433 | for key, value in run_config.items():
1434 | json.dump({key: value}, f)
1435 | f.write('\n')
1436 |
1437 | # Potentially load in the weights and states from a previous save
1438 | if args.resume_from_checkpoint:
1439 | print(f"output_dir used for resuming is: {args.output_dir}")
1440 | if args.resume_from_checkpoint != "latest":
1441 | path = os.path.basename(args.resume_from_checkpoint)
1442 | else:
1443 | # Get the most recent checkpoint
1444 | dirs = os.listdir(args.output_dir)
1445 | dirs = [d for d in dirs if d.startswith("checkpoint")]
1446 | dirs = sorted(dirs, key=lambda x: int(x.split("-")[1]))
1447 | path = dirs[-1] if len(dirs) > 0 else None
1448 |
1449 | if path is None:
1450 | accelerator.print(
1451 | f"Checkpoint '{args.resume_from_checkpoint}' does not exist. Starting a new training run."
1452 | )
1453 | args.resume_from_checkpoint = None
1454 | initial_global_step = 0
1455 | else:
1456 | accelerator.print(f"Resuming from checkpoint {path}")
1457 | accelerator.load_state(os.path.join(args.output_dir, path))
1458 | global_step = int(path.split("-")[1])
1459 |
1460 | initial_global_step = global_step
1461 | first_epoch = global_step // num_update_steps_per_epoch
1462 |
1463 | else:
1464 | initial_global_step = 0
1465 |
1466 | progress_bar = tqdm(
1467 | range(0, args.max_train_steps),
1468 | initial=initial_global_step,
1469 | desc="Steps",
1470 | # Only show the progress bar once on each machine.
1471 | disable=not accelerator.is_local_main_process,
1472 | )
1473 |
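# On HPU, t0 marks the start of throughput measurement; it is set only once throughput_warmup_steps steps have elapsed.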
1474 | if args.device == "hpu":
1475 | t0 = None
1476 |
1477 | # saving the model before training
1478 | if args.device == "hpu":
1479 | pipeline = GaudiStableDiffusionPipeline.from_pretrained(
1480 | args.pretrained_model_name_or_path,
1481 | text_encoder=unwrap_model(text_encoder) if args.train_text_encoder else text_encoder,
1482 | vae=vae,
1483 | unet=unwrap_model(unet),
1484 | revision=args.revision,
1485 | scheduler=noise_scheduler,
1486 | )
1487 | else:
1488 | pipeline = StableDiffusionPipeline.from_pretrained(
1489 | args.pretrained_model_name_or_path,
1490 | text_encoder=unwrap_model(text_encoder) if args.train_text_encoder else text_encoder,
1491 | vae=vae,
1492 | unet= unwrap_model(unet),
1493 | revision=args.revision,
1494 | )
1495 | pipeline.save_pretrained(args.output_dir)
1496 | if accelerator.is_main_process:
1497 | log_validation(
1498 | vae,
1499 | text_encoder,
1500 | tokenizer,
1501 | unet,
1502 | args,
1503 | accelerator,
1504 | weight_dtype,
1505 | global_step,
1506 | val_dataloader,
1507 | noise_scheduler,
1508 | )
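# Tracks whether the text encoder has already been switched to training mode after freeze_text_encoder_steps.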
1509 | text_train_active = False
1510 | for epoch in range(first_epoch, args.num_train_epochs):
1511 | unet.train()
1512 | if args.train_text_encoder and global_step > args.freeze_text_encoder_steps:
1513 | text_encoder.train()
1514 | train_loss = 0.0
1515 | print("epoch: ", epoch)
1516 | for step, batch in enumerate(train_dataloader):
1517 | if args.train_text_encoder and global_step > args.freeze_text_encoder_steps and not text_train_active:
1518 | text_encoder.train()
1519 | text_train_active = True
1520 | print("Text encoder training started at {} steps".format(global_step))
1521 |
1522 | if args.device == "hpu":
1523 | if t0 is None and global_step == args.throughput_warmup_steps:
1524 | t0 = time.perf_counter()
1525 |
1526 | with accelerator.accumulate(unet):
1527 | # Convert images to latent space
1528 | latents = vae.encode(batch["pixel_values"].to(weight_dtype)).latent_dist.sample()
1529 | latents = latents * vae.config.scaling_factor
1530 |
1531 | # Sample noise that we'll add to the latents
1532 | noise = torch.randn_like(latents)
1533 | if args.noise_offset:
1534 | # https://www.crosslabs.org//blog/diffusion-with-offset-noise
1535 | noise += args.noise_offset * torch.randn(
1536 | (latents.shape[0], latents.shape[1], 1, 1), device=latents.device
1537 | )
1538 | if args.input_perturbation:
1539 | new_noise = noise + args.input_perturbation * torch.randn_like(noise)
1540 | bsz = latents.shape[0]
1541 | # Sample a random timestep for each image
1542 | timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, (bsz,), device=latents.device)
1543 | timesteps = timesteps.long()
1544 |
1545 | # Add noise to the latents according to the noise magnitude at each timestep
1546 | # (this is the forward diffusion process)
1547 | if args.input_perturbation:
1548 | noisy_latents = noise_scheduler.add_noise(latents, new_noise, timesteps)
1549 | else:
1550 | noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)
1551 |
1552 | # Get the text embedding for conditioning
1553 | encoder_hidden_states = text_encoder(batch["input_ids"])[0]
1554 |
1555 | # Get the target for loss depending on the prediction type
1556 | if args.prediction_type is not None:
1557 | # set prediction_type of scheduler if defined
1558 | noise_scheduler.register_to_config(prediction_type=args.prediction_type)
1559 |
1560 | if noise_scheduler.config.prediction_type == "epsilon":
1561 | target = noise
1562 | elif noise_scheduler.config.prediction_type == "v_prediction":
1563 | target = noise_scheduler.get_velocity(latents, noise, timesteps)
1564 | else:
1565 | raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}")
1566 |
1567 | # write prediction_type to run_config.json
1568 | if accelerator.is_main_process and epoch == 0 and step == 0:
1569 | with open(os.path.join(args.output_dir, "run_config.jsonl"), 'a') as f:
1570 | f.write(json.dumps({"prediction_type": noise_scheduler.config.prediction_type}) + '\n')
1571 |
1572 | # Predict the noise residual and compute loss
1573 | model_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample
1574 |
1575 | if args.snr_gamma is None:
1576 | loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean")
1577 | else:
1578 | # Compute loss-weights as per Section 3.4 of https://arxiv.org/abs/2303.09556.
1579 | # Since we predict the noise instead of x_0, the original formulation is slightly changed.
1580 | # This is discussed in Section 4.2 of the same paper.
1581 | snr = compute_snr(noise_scheduler, timesteps)
1582 | if noise_scheduler.config.prediction_type == "v_prediction":
1583 | # Velocity objective requires that we add one to SNR values before we divide by them.
1584 | snr = snr + 1
1585 | mse_loss_weights = (
1586 | torch.stack([snr, args.snr_gamma * torch.ones_like(timesteps)], dim=1).min(dim=1)[0] / snr
1587 | )
1588 |
1589 | loss = F.mse_loss(model_pred.float(), target.float(), reduction="none")
1590 | loss = loss.mean(dim=list(range(1, len(loss.shape)))) * mse_loss_weights
1591 | loss = loss.mean()
1592 |
1593 | # Gather the losses across all processes for logging (if we use distributed training).
1594 | avg_loss = accelerator.gather(loss.repeat(args.train_batch_size)).mean()
1595 | train_loss += avg_loss.item() / args.gradient_accumulation_steps
1596 |
1597 | # Backpropagate
1598 | accelerator.backward(loss)
1599 | if accelerator.sync_gradients:
1600 | params_to_clip = (
1601 | itertools.chain(unet.parameters(), text_encoder.parameters())
1602 | if args.train_text_encoder and global_step > args.freeze_text_encoder_steps
1603 | else unet.parameters()
1604 | )
1605 | accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm)
1606 | optimizer.step()
1607 | lr_scheduler.step()
1608 |
1609 | if args.device == "hpu":
1610 | optimizer.zero_grad(set_to_none=True)
1611 | htcore.mark_step()
1612 | else:
1613 | optimizer.zero_grad()
1614 | # Checks if the accelerator has performed an optimization step behind the scenes
1615 | if accelerator.sync_gradients:
1616 | if args.use_ema:
1617 | ema_unet.step(unet.parameters())
1618 | progress_bar.update(1)
1619 | global_step += 1
1620 | accelerator.log({"training/train_loss": train_loss}, step=global_step)
1621 | accelerator.log({"hyperparameters/batch_size": args.train_batch_size}, step=global_step)
1622 | accelerator.log({"hyperparameters/effective_batch_size": total_batch_size}, step=global_step)
1623 | accelerator.log({"hyperparameters/learning_rate": lr_scheduler.get_last_lr()[0]}, step=global_step)
1624 | train_loss = 0.0
1625 |
1626 | if global_step % args.checkpointing_steps == 0:
1627 | if accelerator.is_main_process:
1628 | # _before_ saving state, check if this save would set us over the `checkpoints_total_limit`
1629 | if args.checkpoints_total_limit is not None:
1630 | checkpoints = os.listdir(args.output_dir)
1631 | checkpoints = [d for d in checkpoints if d.startswith("checkpoint")]
1632 | checkpoints = sorted(checkpoints, key=lambda x: int(x.split("-")[1]))
1633 |
1634 | # before we save the new checkpoint, we need to have at _most_ `checkpoints_total_limit - 1` checkpoints
1635 | if len(checkpoints) >= args.checkpoints_total_limit:
1636 | num_to_remove = len(checkpoints) - args.checkpoints_total_limit + 1
1637 | removing_checkpoints = checkpoints[0:num_to_remove]
1638 |
1639 | logger.info(
1640 | f"{len(checkpoints)} checkpoints already exist, removing {len(removing_checkpoints)} checkpoints"
1641 | )
1642 | logger.info(f"removing checkpoints: {', '.join(removing_checkpoints)}")
1643 |
1644 | for removing_checkpoint in removing_checkpoints:
1645 | removing_checkpoint = os.path.join(args.output_dir, removing_checkpoint)
1646 | shutil.rmtree(removing_checkpoint)
1647 |
1648 | save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}")
1649 | accelerator.save_state(save_path)
1650 | logger.info(f"Saved state to {save_path}")
1651 |
1652 | logs = {"step_loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0]}
1653 | progress_bar.set_postfix(**logs)
1654 |
1655 | if global_step >= args.max_train_steps:
1656 | break
1657 |
1658 | if args.device == "hpu":
1659 | # gaudi addition
1660 | duration = time.perf_counter() - t0
1661 | throughput = args.max_train_steps * total_batch_size / duration
1662 |
1663 | if accelerator.is_main_process:
1664 | if args.device == "hpu":
1665 | logger.info(f"Throughput = {throughput} samples/s")
1666 | logger.info(f"Train runtime = {duration} seconds")
1667 | metrics = {
1668 | "train_samples_per_second": throughput,
1669 | "train_runtime": duration,
1670 | }
1671 | with open(f"{args.output_dir}/speed_metrics.json", mode="w") as file:
1672 | json.dump(metrics, file)
1673 |
1674 | if epoch % args.validation_epochs == 0:
1675 | if args.use_ema:
1676 | # Store the UNet parameters temporarily and load the EMA parameters to perform inference.
1677 | ema_unet.store(unet.parameters())
1678 | ema_unet.copy_to(unet.parameters())
1679 | log_validation(
1680 | vae,
1681 | text_encoder,
1682 | tokenizer,
1683 | unet,
1684 | args,
1685 | accelerator,
1686 | weight_dtype,
1687 | global_step,
1688 | val_dataloader,
1689 | noise_scheduler,
1690 | )
1691 | if args.use_ema:
1692 | # Switch back to the original UNet parameters.
1693 | ema_unet.restore(unet.parameters())
1694 |
1695 | # Create the pipeline using the trained modules and save it.
1696 | accelerator.wait_for_everyone()
1697 | if accelerator.is_main_process:
1698 | if args.device == "hpu":
1699 | logger.info(f"Throughput = {throughput} samples/s")
1700 | logger.info(f"Train runtime = {duration} seconds")
1701 | metrics = {
1702 | "train_samples_per_second": throughput,
1703 | "train_runtime": duration,
1704 | }
1705 | with open(f"{args.output_dir}/speed_metrics.json", mode="w") as file:
1706 | json.dump(metrics, file)
1707 |
1708 | unet = accelerator.unwrap_model(unet)
1709 | if args.train_text_encoder:
1710 | text_encoder = unwrap_model(text_encoder)
1711 | if args.use_ema:
1712 | ema_unet.copy_to(unet.parameters())
1713 |
1714 | if args.device == "hpu":
1715 | pipeline = GaudiStableDiffusionPipeline.from_pretrained(
1716 | args.pretrained_model_name_or_path,
1717 | text_encoder=text_encoder,
1718 | vae=vae,
1719 | unet=unet,
1720 | revision=args.revision,
1721 | scheduler=noise_scheduler,
1722 | )
1723 | else:
1724 | pipeline = StableDiffusionPipeline.from_pretrained(
1725 | args.pretrained_model_name_or_path,
1726 | text_encoder=text_encoder,
1727 | vae=vae,
1728 | unet=unet,
1729 | revision=args.revision,
1730 | )
1731 | pipeline.save_pretrained(args.output_dir)
1732 | if args.use_ema:
1733 | ema_unet.save_pretrained(os.path.join(args.output_dir, "unet_ema"))
1734 |
1735 | # Run a final round of inference.
1736 | images = []
1737 | if args.validation_prompts is not None:
1738 | logger.info("Running inference for collecting generated images...")
1739 | pipeline = pipeline.to(accelerator.device)
1740 | pipeline.torch_dtype = weight_dtype
1741 | pipeline.set_progress_bar_config(disable=True)
1742 |
1743 | if args.enable_xformers_memory_efficient_attention:
1744 | pipeline.enable_xformers_memory_efficient_attention()
1745 |
1746 | if args.seed is None:
1747 | generator = None
1748 | else:
1749 | generator = torch.Generator(device=accelerator.device).manual_seed(args.seed)
1750 |
1751 | for i in range(len(args.validation_prompts)):
1752 | if args.device == "hpu":
1753 | image = pipeline(args.validation_prompts[i], num_inference_steps=20, generator=generator).images[0]
1754 | else:
1755 | with torch.autocast("cuda"):
1756 | image = pipeline(args.validation_prompts[i], num_inference_steps=20, generator=generator).images[0]
1757 | images.append(image)
1758 |
1759 | # name of last directory in output path
1760 | run_name = os.path.basename(os.path.normpath(args.output_dir))
1761 | #save_model_card(args, None, images, repo_folder=args.output_dir)
1762 | if args.push_to_hub:
1763 | #save_model_card(args, repo_id, images, repo_folder=args.output_dir)
1764 | upload_folder(
1765 | repo_id=repo_id,
1766 | folder_path=args.output_dir,
1767 | path_in_repo=f"gpu_runs/{run_name}",
1768 | commit_message="End of training",
1769 | ignore_patterns=["step_*", "epoch_*", "checkpoint-*"],
1770 | allow_patterns=["checkpoint-15000"]
1771 | )
1772 |
1773 | accelerator.end_training()
1774 |
1775 |
1776 | if __name__ == "__main__":
1777 | main()
1778 |
--------------------------------------------------------------------------------